| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771 |
- """
- Celery Tasks for AWS Resource Scanning
- This module contains Celery tasks for executing AWS resource scans
- and generating reports.
- Requirements:
- - 4.1: Dispatch tasks to Celery queue for Worker processing
- - 4.9: Report progress updates to Redis
- - 4.10: Retry up to 3 times with exponential backoff
- - 8.2: Record error details in task record
- """
- from datetime import datetime, timedelta, timezone
- from typing import List, Dict, Any
- from celery.exceptions import SoftTimeLimitExceeded
- import traceback
- from app import db
- from app.celery_app import celery_app
- from app.models import Task, TaskLog, Report
- from app.services.report_generator import ReportGenerator, generate_report_filename
- from app.scanners import AWSScanner, create_credential_provider_from_model
- from app.errors import (
- TaskError, ScanError, ReportGenerationError, TaskErrorLogger,
- ErrorCode, log_error
- )
def update_task_status(task_id: int, status: str, **kwargs) -> None:
    """Persist a new status (plus optional extra fields) for a task.

    Stamps ``started_at`` on the first transition to 'running' and
    ``completed_at`` on terminal states. Any extra keyword argument whose
    name matches a Task attribute is copied onto the record.
    """
    task = Task.query.get(task_id)
    if task is None:
        # Nothing to update; silently ignore unknown task ids.
        return
    task.status = status
    if status == 'running' and not task.started_at:
        task.started_at = datetime.now(timezone.utc)
    if status in ('completed', 'failed'):
        task.completed_at = datetime.now(timezone.utc)
    for attr_name, attr_value in kwargs.items():
        if hasattr(task, attr_name):
            setattr(task, attr_name, attr_value)
    db.session.commit()
def update_task_progress(task_id: int, progress: int) -> None:
    """Persist the progress percentage for a task, if the task exists."""
    task = Task.query.get(task_id)
    if task is None:
        return
    task.progress = progress
    db.session.commit()
def log_task_message(task_id: int, level: str, message: str, details: Any = None) -> None:
    """
    Log a message for a task.

    Requirements:
    - 8.2: Record error details in task record
    """
    import json

    # Normalize the optional details payload into a JSON string:
    # dicts are serialized directly, strings are wrapped under 'info',
    # anything else is stringified under 'data'.
    details_json = None
    if details is not None:
        if isinstance(details, dict):
            payload = details
        elif isinstance(details, str):
            payload = {'info': details}
        else:
            payload = {'data': str(details)}
        details_json = json.dumps(payload)

    entry = TaskLog(
        task_id=task_id,
        level=level,
        message=message,
        details=details_json
    )
    db.session.add(entry)
    db.session.commit()
def log_task_error_with_trace(
    task_id: int,
    error: Exception,
    service: str = None,
    region: str = None,
    context: Dict[str, Any] = None
) -> None:
    """
    Log an error for a task with full stack trace.

    Requirements:
    - 8.1: Log errors with timestamp, context, and stack trace
    - 8.2: Record error details in task record
    """
    import json

    # Assemble a structured description of the failure.
    error_details = {
        'error_type': type(error).__name__,
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'stack_trace': traceback.format_exc()
    }
    # Only attach optional fields that were actually provided.
    for field_name, field_value in (('service', service), ('region', region), ('context', context)):
        if field_value:
            error_details[field_name] = field_value

    try:
        # Discard any half-finished transaction so the log write starts clean.
        db.session.rollback()

        db.session.add(TaskLog(
            task_id=task_id,
            level='error',
            message=str(error)[:500],  # Truncate long error messages
            details=json.dumps(error_details)
        ))
        db.session.commit()
    except Exception as log_error_exc:
        # Logging must never mask the original failure; fall back to stdout.
        print(f"Failed to log error to database: {log_error_exc}")
        db.session.rollback()

    # Mirror the error into the application logger as well.
    log_error(error, context={'task_id': task_id, 'service': service, 'region': region})
@celery_app.task(bind=True, max_retries=3, default_retry_delay=60)
def scan_aws_resources(
    self,
    task_id: int,
    credential_ids: List[int],
    regions: List[str],
    project_metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Execute AWS resource scanning task

    Requirements:
    - 4.1: Dispatch tasks to Celery queue for Worker processing
    - 4.9: Report progress updates to Redis
    - 4.10: Retry up to 3 times with exponential backoff
    - 8.2: Record error details in task record

    Args:
        task_id: Database task ID
        credential_ids: List of AWS credential IDs to use
        regions: List of regions to scan
        project_metadata: Project metadata for report generation

    Returns:
        Scan results and report path
    """
    import os
    from flask import current_app
    from app.models import AWSCredential, BaseAssumeRoleConfig

    try:
        # Update task status to running
        update_task_status(task_id, 'running')
        log_task_message(task_id, 'info', 'Task started', {
            'credential_ids': credential_ids,
            'regions': regions
        })

        # Store Celery task ID so the DB record can be correlated with Celery
        task = Task.query.get(task_id)
        if task:
            task.celery_task_id = self.request.id
            db.session.commit()

        # Get base assume role config if needed
        base_config = BaseAssumeRoleConfig.query.first()

        # Collect all scan results, keyed by service
        all_results = {}
        total_steps = len(credential_ids) * len(regions)
        current_step = 0
        scan_errors = []

        # Track which credentials already had their global services scanned,
        # so global services are scanned only once per credential.
        global_services_scanned = set()

        # Check if us-east-1 is in the selected regions
        us_east_1_selected = 'us-east-1' in regions

        for cred_id in credential_ids:
            credential = AWSCredential.query.get(cred_id)
            if not credential:
                log_task_message(task_id, 'warning', f'Credential {cred_id} not found, skipping', {
                    'credential_id': cred_id
                })
                continue

            try:
                # Build a scanner backed by this credential
                cred_provider = create_credential_provider_from_model(credential, base_config)
                scanner = AWSScanner(cred_provider)

                # Scan global services only once per credential
                if cred_id not in global_services_scanned:
                    try:
                        # When us-east-1 is not selected, also scan ACM as a global service:
                        # ACM certificates in us-east-1 are used by CloudFront (global).
                        services_to_scan_globally = list(scanner.global_services)
                        if not us_east_1_selected and 'acm' not in services_to_scan_globally:
                            services_to_scan_globally.append('acm')
                            log_task_message(task_id, 'info', f'Adding ACM to global scan (us-east-1 not selected)', {
                                'account_id': credential.account_id
                            })

                        log_task_message(task_id, 'info', f'Scanning global services for account {credential.account_id}', {
                            'account_id': credential.account_id,
                            'services': services_to_scan_globally
                        })

                        # Scan global services (and ACM if us-east-1 not selected)
                        global_scan_result = scanner.scan_resources(
                            regions=['us-east-1'],  # Global services use us-east-1
                            services=services_to_scan_globally
                        )

                        # Merge global results, normalizing each resource to a dict
                        for service_key, resources in global_scan_result.resources.items():
                            if service_key not in all_results:
                                all_results[service_key] = []
                            for resource in resources:
                                if hasattr(resource, 'to_dict'):
                                    resource_dict = resource.to_dict()
                                elif isinstance(resource, dict):
                                    resource_dict = resource.copy()
                                else:
                                    # Fallback: read well-known attributes directly
                                    resource_dict = {
                                        'account_id': getattr(resource, 'account_id', credential.account_id),
                                        'region': getattr(resource, 'region', 'global'),
                                        'service': getattr(resource, 'service', service_key),
                                        'resource_type': getattr(resource, 'resource_type', ''),
                                        'resource_id': getattr(resource, 'resource_id', ''),
                                        'name': getattr(resource, 'name', ''),
                                        'attributes': getattr(resource, 'attributes', {})
                                    }
                                resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id
                                all_results[service_key].append(resource_dict)

                        for error in global_scan_result.errors:
                            error_msg = error.get('error', 'Unknown error')
                            error_service = error.get('service', 'unknown')
                            log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", {
                                'service': error_service,
                                'error': error_msg
                            })
                            scan_errors.append(error)

                        global_services_scanned.add(cred_id)

                    except Exception as e:
                        log_task_error_with_trace(
                            task_id=task_id,
                            error=e,
                            service='global_services',
                            context={'account_id': credential.account_id}
                        )
                        scan_errors.append({
                            'service': 'global_services',
                            'error': str(e)
                        })

                # Get regional services (exclude global services).
                # Also exclude ACM if us-east-1 is not selected (already scanned as global).
                regional_services = [s for s in scanner.supported_services if s not in scanner.global_services]
                if not us_east_1_selected:
                    regional_services = [s for s in regional_services if s != 'acm']

                for region in regions:
                    try:
                        # Scan resources in this region (only regional services)
                        log_task_message(task_id, 'info', f'Scanning region {region} for account {credential.account_id}', {
                            'region': region,
                            'account_id': credential.account_id
                        })

                        # Use scan_resources for regional services only
                        scan_result = scanner.scan_resources(
                            regions=[region],
                            services=regional_services
                        )

                        # Merge results, converting ResourceData to dict
                        for service_key, resources in scan_result.resources.items():
                            if service_key not in all_results:
                                all_results[service_key] = []
                            for resource in resources:
                                # Convert ResourceData to dict using to_dict() method
                                if hasattr(resource, 'to_dict'):
                                    resource_dict = resource.to_dict()
                                elif isinstance(resource, dict):
                                    resource_dict = resource.copy()
                                else:
                                    # Fallback: try to access attributes directly
                                    resource_dict = {
                                        'account_id': getattr(resource, 'account_id', credential.account_id),
                                        'region': getattr(resource, 'region', region),
                                        'service': getattr(resource, 'service', service_key),
                                        'resource_type': getattr(resource, 'resource_type', ''),
                                        'resource_id': getattr(resource, 'resource_id', ''),
                                        'name': getattr(resource, 'name', ''),
                                        'attributes': getattr(resource, 'attributes', {})
                                    }

                                # Ensure account_id and region are set
                                resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id
                                resource_dict['region'] = resource_dict.get('region') or region
                                all_results[service_key].append(resource_dict)

                        # Log any errors from the scan (Requirements 8.2)
                        for error in scan_result.errors:
                            error_msg = error.get('error', 'Unknown error')
                            error_service = error.get('service', 'unknown')
                            error_region = error.get('region', region)

                            log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", {
                                'service': error_service,
                                'region': error_region,
                                'error': error_msg
                            })
                            scan_errors.append(error)

                    except Exception as e:
                        # Log error with full stack trace (Requirements 8.1, 8.2)
                        log_task_error_with_trace(
                            task_id=task_id,
                            error=e,
                            service='region_scan',
                            region=region,
                            context={'account_id': credential.account_id}
                        )
                        scan_errors.append({
                            'service': 'region_scan',
                            'region': region,
                            'error': str(e)
                        })

                    # Update progress (Requirements 4.9)
                    current_step += 1
                    progress = int((current_step / total_steps) * 90)  # Reserve 10% for report generation
                    self.update_state(
                        state='PROGRESS',
                        meta={'progress': progress, 'current': current_step, 'total': total_steps}
                    )
                    update_task_progress(task_id, progress)

            except Exception as e:
                # Log credential-level error with full stack trace
                log_task_error_with_trace(
                    task_id=task_id,
                    error=e,
                    service='credential',
                    context={'credential_id': cred_id, 'account_id': credential.account_id if credential else None}
                )
                scan_errors.append({
                    'service': 'credential',
                    'credential_id': cred_id,
                    'error': str(e)
                })

        # Generate report
        log_task_message(task_id, 'info', 'Generating report...', {
            'total_services': len(all_results),
            'total_errors': len(scan_errors)
        })
        update_task_progress(task_id, 95)

        report_path = None
        try:
            # Check if report already exists for this task (in case of retry)
            existing_report = Report.query.filter_by(task_id=task_id).first()
            if existing_report:
                log_task_message(task_id, 'info', 'Report already exists, skipping generation', {
                    'file_name': existing_report.file_name
                })
                report_path = existing_report.file_path
            else:
                # Get reports folder - use absolute path to ensure consistency
                reports_folder = current_app.config.get('REPORTS_FOLDER', 'reports')
                if not os.path.isabs(reports_folder):
                    # Convert to absolute path relative to the app root
                    reports_folder = os.path.abspath(reports_folder)
                os.makedirs(reports_folder, exist_ok=True)

                # Generate filename and path
                filename = generate_report_filename(project_metadata)
                report_path = os.path.join(reports_folder, filename)

                # Get network diagram path if provided
                network_diagram_path = project_metadata.get('network_diagram_path')

                # Generate the report
                generator = ReportGenerator()
                result = generator.generate_report(
                    scan_results=all_results,
                    project_metadata=project_metadata,
                    output_path=report_path,
                    network_diagram_path=network_diagram_path,
                    regions=regions
                )

                # Create report record in database
                report = Report(
                    task_id=task_id,
                    file_name=result['file_name'],
                    file_path=result['file_path'],
                    file_size=result['file_size']
                )
                db.session.add(report)
                db.session.commit()

                # FIX: the log message previously contained the literal text
                # "(unknown)" instead of interpolating the generated filename.
                log_task_message(task_id, 'info', f'Report generated: {filename}', {
                    'file_name': filename,
                    'file_size': result['file_size']
                })

        except Exception as e:
            # Log report generation error with full stack trace
            db.session.rollback()  # Rollback any pending transaction
            log_task_error_with_trace(
                task_id=task_id,
                error=e,
                service='report_generation',
                context={'project_metadata': project_metadata}
            )
            report_path = None

        # Update task status to completed
        update_task_status(task_id, 'completed', progress=100)
        log_task_message(task_id, 'info', 'Task completed successfully', {
            'total_resources': sum(len(r) for r in all_results.values()),
            'total_errors': len(scan_errors),
            'report_generated': report_path is not None
        })

        return {
            'status': 'success',
            'report_path': report_path,
            'total_resources': sum(len(r) for r in all_results.values()),
            'total_errors': len(scan_errors)
        }

    except SoftTimeLimitExceeded as e:
        # Log timeout error with full context; do not retry on timeout
        log_task_error_with_trace(
            task_id=task_id,
            error=e,
            service='task_execution',
            context={'error_type': 'timeout'}
        )
        update_task_status(task_id, 'failed')
        raise

    except Exception as e:
        # Log error with full stack trace (Requirements 8.1)
        log_task_error_with_trace(
            task_id=task_id,
            error=e,
            service='task_execution',
            context={'retry_count': self.request.retries}
        )
        update_task_status(task_id, 'failed')

        # Retry with exponential backoff (Requirements 4.10)
        retry_count = self.request.retries
        if retry_count < self.max_retries:
            countdown = 60 * (2 ** retry_count)  # Exponential backoff
            log_task_message(task_id, 'warning', f'Retrying task in {countdown} seconds', {
                'attempt': retry_count + 1,
                'max_retries': self.max_retries,
                'countdown': countdown
            })
            raise self.retry(exc=e, countdown=countdown)

        raise
@celery_app.task(bind=True, max_retries=3, default_retry_delay=60)
def process_uploaded_scan(
    self,
    task_id: int,
    scan_data_path: str,
    project_metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Process uploaded CloudShell scan data and generate report.

    This task processes JSON scan data uploaded from CloudShell scanner
    and generates a report using the existing ReportGenerator.

    Requirements:
    - 5.1: Generate reports in the same format as existing scan tasks
    - 5.2: Use account_id from uploaded data as report identifier
    - 5.3: Update task status to completed when done
    - 5.5: Record error and update task status to failed on error

    Args:
        task_id: Database task ID
        scan_data_path: Path to the uploaded JSON scan data file
        project_metadata: Project metadata for report generation

    Returns:
        Processing results including report path
    """
    import os
    import json
    from flask import current_app
    from app.services.scan_data_processor import ScanDataProcessor

    try:
        # Update task status to running
        update_task_status(task_id, 'running')
        log_task_message(task_id, 'info', 'Processing uploaded scan data', {
            'scan_data_path': scan_data_path
        })

        # Store Celery task ID so the DB record can be correlated with Celery
        task = Task.query.get(task_id)
        if task:
            task.celery_task_id = self.request.id
            db.session.commit()

        # Load scan data from file
        if not os.path.exists(scan_data_path):
            raise FileNotFoundError(f"Scan data file not found: {scan_data_path}")

        with open(scan_data_path, 'r', encoding='utf-8') as f:
            scan_data = json.load(f)

        log_task_message(task_id, 'info', 'Scan data loaded successfully', {
            'file_size': os.path.getsize(scan_data_path)
        })
        update_task_progress(task_id, 20)

        # Validate and convert scan data
        processor = ScanDataProcessor()
        is_valid, validation_errors = processor.validate_scan_data(scan_data)

        if not is_valid:
            raise ValueError(f"Invalid scan data: {', '.join(validation_errors)}")

        log_task_message(task_id, 'info', 'Scan data validation passed')
        update_task_progress(task_id, 40)

        # Convert to ScanResult format
        scan_result = processor.convert_to_scan_result(scan_data)

        # Convert resources to dict format for report generator
        all_results = {}
        for service_key, resources in scan_result.resources.items():
            all_results[service_key] = []
            for resource in resources:
                if hasattr(resource, 'to_dict'):
                    all_results[service_key].append(resource.to_dict())
                elif isinstance(resource, dict):
                    all_results[service_key].append(resource)
                else:
                    # Fallback: read well-known attributes directly
                    all_results[service_key].append({
                        'account_id': getattr(resource, 'account_id', ''),
                        'region': getattr(resource, 'region', ''),
                        'service': getattr(resource, 'service', service_key),
                        'resource_type': getattr(resource, 'resource_type', ''),
                        'resource_id': getattr(resource, 'resource_id', ''),
                        'name': getattr(resource, 'name', ''),
                        'attributes': getattr(resource, 'attributes', {})
                    })

        # Get metadata (regions scanned come from the uploaded payload)
        metadata = scan_data.get('metadata', {})
        regions = metadata.get('regions_scanned', [])

        log_task_message(task_id, 'info', 'Scan data converted successfully', {
            'total_services': len(all_results),
            'total_resources': sum(len(r) for r in all_results.values()),
            'regions': regions
        })
        update_task_progress(task_id, 60)

        # Generate report
        log_task_message(task_id, 'info', 'Generating report...')
        update_task_progress(task_id, 70)

        report_path = None
        try:
            # Check if report already exists for this task (in case of retry)
            existing_report = Report.query.filter_by(task_id=task_id).first()
            if existing_report:
                log_task_message(task_id, 'info', 'Report already exists, skipping generation', {
                    'file_name': existing_report.file_name
                })
                report_path = existing_report.file_path
            else:
                # Get reports folder
                reports_folder = current_app.config.get('REPORTS_FOLDER', 'reports')
                if not os.path.isabs(reports_folder):
                    reports_folder = os.path.abspath(reports_folder)
                os.makedirs(reports_folder, exist_ok=True)

                # Generate filename and path
                filename = generate_report_filename(project_metadata)
                report_path = os.path.join(reports_folder, filename)

                # Get network diagram path if provided
                network_diagram_path = project_metadata.get('network_diagram_path')

                # Generate the report
                generator = ReportGenerator()
                result = generator.generate_report(
                    scan_results=all_results,
                    project_metadata=project_metadata,
                    output_path=report_path,
                    network_diagram_path=network_diagram_path,
                    regions=regions
                )

                # Create report record in database
                report = Report(
                    task_id=task_id,
                    file_name=result['file_name'],
                    file_path=result['file_path'],
                    file_size=result['file_size']
                )
                db.session.add(report)
                db.session.commit()

                # FIX: the log message previously contained the literal text
                # "(unknown)" instead of interpolating the generated filename.
                log_task_message(task_id, 'info', f'Report generated: {filename}', {
                    'file_name': filename,
                    'file_size': result['file_size']
                })

        except Exception as e:
            # Log report generation error with full stack trace
            db.session.rollback()
            log_task_error_with_trace(
                task_id=task_id,
                error=e,
                service='report_generation',
                context={'project_metadata': project_metadata}
            )
            report_path = None

        # Update task status to completed
        update_task_status(task_id, 'completed', progress=100)

        total_resources = sum(len(r) for r in all_results.values())
        total_errors = len(scan_result.errors)

        log_task_message(task_id, 'info', 'Task completed successfully', {
            'total_resources': total_resources,
            'total_errors': total_errors,
            'report_generated': report_path is not None
        })

        return {
            'status': 'success',
            'report_path': report_path,
            'total_resources': total_resources,
            'total_errors': total_errors
        }

    except SoftTimeLimitExceeded as e:
        # Timeout: record it and fail without retrying
        log_task_error_with_trace(
            task_id=task_id,
            error=e,
            service='task_execution',
            context={'error_type': 'timeout'}
        )
        update_task_status(task_id, 'failed')
        raise

    except Exception as e:
        log_task_error_with_trace(
            task_id=task_id,
            error=e,
            service='task_execution',
            context={'retry_count': self.request.retries}
        )
        update_task_status(task_id, 'failed')

        # Retry with exponential backoff
        retry_count = self.request.retries
        if retry_count < self.max_retries:
            countdown = 60 * (2 ** retry_count)
            log_task_message(task_id, 'warning', f'Retrying task in {countdown} seconds', {
                'attempt': retry_count + 1,
                'max_retries': self.max_retries,
                'countdown': countdown
            })
            raise self.retry(exc=e, countdown=countdown)

        raise
@celery_app.task
def cleanup_old_reports(days: int = 30) -> Dict[str, Any]:
    """
    Clean up reports older than specified days

    Args:
        days: Number of days to keep reports

    Returns:
        Cleanup statistics
    """
    import os

    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
    stale_reports = Report.query.filter(Report.created_at < cutoff_date).all()

    deleted_count = 0
    for stale in stale_reports:
        try:
            # Remove the file from disk if it is still there
            if os.path.exists(stale.file_path):
                os.remove(stale.file_path)
            # Remove the database record
            db.session.delete(stale)
            deleted_count += 1
        except Exception as e:
            # Best-effort cleanup: log the failure and move on
            print(f"Error deleting report {stale.id}: {e}")

    db.session.commit()

    return {
        'deleted_count': deleted_count,
        'cutoff_date': cutoff_date.isoformat()
    }
@celery_app.task
def validate_credentials(credential_id: int) -> Dict[str, Any]:
    """
    Validate AWS credentials

    Args:
        credential_id: ID of the credential to validate

    Returns:
        Validation result
    """
    from app.models import AWSCredential

    credential = AWSCredential.query.get(credential_id)
    if credential is None:
        return {'valid': False, 'message': 'Credential not found'}

    # TODO: Implement actual AWS credential validation in Task 6
    # This is a placeholder
    try:
        # Placeholder for the real check, e.g.:
        # session = get_aws_session(credential)
        # sts = session.client('sts')
        # sts.get_caller_identity()
        return {'valid': True, 'message': 'Credential is valid'}
    except Exception as e:
        return {'valid': False, 'message': str(e)}
|