""" Celery Tasks for AWS Resource Scanning This module contains Celery tasks for executing AWS resource scans and generating reports. Requirements: - 4.1: Dispatch tasks to Celery queue for Worker processing - 4.9: Report progress updates to Redis - 4.10: Retry up to 3 times with exponential backoff - 8.2: Record error details in task record """ from datetime import datetime, timedelta, timezone from typing import List, Dict, Any from celery.exceptions import SoftTimeLimitExceeded import traceback from app import db from app.celery_app import celery_app from app.models import Task, TaskLog, Report from app.services.report_generator import ReportGenerator, generate_report_filename from app.scanners import AWSScanner, create_credential_provider_from_model from app.errors import ( TaskError, ScanError, ReportGenerationError, TaskErrorLogger, ErrorCode, log_error ) def update_task_status(task_id: int, status: str, **kwargs) -> None: """Update task status in database""" task = Task.query.get(task_id) if task: task.status = status if status == 'running' and not task.started_at: task.started_at = datetime.now(timezone.utc) if status in ('completed', 'failed'): task.completed_at = datetime.now(timezone.utc) for key, value in kwargs.items(): if hasattr(task, key): setattr(task, key, value) db.session.commit() def update_task_progress(task_id: int, progress: int) -> None: """Update task progress in database""" task = Task.query.get(task_id) if task: task.progress = progress db.session.commit() def log_task_message(task_id: int, level: str, message: str, details: Any = None) -> None: """ Log a message for a task. Requirements: - 8.2: Record error details in task record """ import json # Handle details serialization if details is not None: if isinstance(details, dict): details_json = json.dumps(details) elif isinstance(details, str): details_json = json.dumps({'info': details}) else: details_json = json.dumps({'data': str(details)}) else: details_json = None log = TaskLog( task_id=task_id, level=level, message=message, details=details_json ) db.session.add(log) db.session.commit() def log_task_error_with_trace( task_id: int, error: Exception, service: str = None, region: str = None, context: Dict[str, Any] = None ) -> None: """ Log an error for a task with full stack trace. Requirements: - 8.1: Log errors with timestamp, context, and stack trace - 8.2: Record error details in task record """ import json # Build comprehensive error details error_details = { 'error_type': type(error).__name__, 'timestamp': datetime.now(timezone.utc).isoformat(), 'stack_trace': traceback.format_exc() } if service: error_details['service'] = service if region: error_details['region'] = region if context: error_details['context'] = context # Create task log entry log = TaskLog( task_id=task_id, level='error', message=str(error), details=json.dumps(error_details) ) db.session.add(log) db.session.commit() # Also log to application logger log_error(error, context={'task_id': task_id, 'service': service, 'region': region}) @celery_app.task(bind=True, max_retries=3, default_retry_delay=60) def scan_aws_resources( self, task_id: int, credential_ids: List[int], regions: List[str], project_metadata: Dict[str, Any] ) -> Dict[str, Any]: """ Execute AWS resource scanning task Requirements: - 4.1: Dispatch tasks to Celery queue for Worker processing - 4.9: Report progress updates to Redis - 4.10: Retry up to 3 times with exponential backoff - 8.2: Record error details in task record Args: task_id: Database task ID credential_ids: List of AWS credential IDs to use regions: List of regions to scan project_metadata: Project metadata for report generation Returns: Scan results and report path """ import os from flask import current_app from app.models import AWSCredential, BaseAssumeRoleConfig try: # Update task status to running update_task_status(task_id, 'running') log_task_message(task_id, 'info', 'Task started', { 'credential_ids': credential_ids, 'regions': regions }) # Store Celery task ID task = Task.query.get(task_id) if task: task.celery_task_id = self.request.id db.session.commit() # Get base assume role config if needed base_config = BaseAssumeRoleConfig.query.first() # Collect all scan results all_results = {} total_steps = len(credential_ids) * len(regions) current_step = 0 scan_errors = [] # Track if global services have been scanned for each credential global_services_scanned = set() for cred_id in credential_ids: credential = AWSCredential.query.get(cred_id) if not credential: log_task_message(task_id, 'warning', f'Credential {cred_id} not found, skipping', { 'credential_id': cred_id }) continue try: # Get AWS credentials cred_provider = create_credential_provider_from_model(credential, base_config) scanner = AWSScanner(cred_provider) # Scan global services only once per credential if cred_id not in global_services_scanned: try: log_task_message(task_id, 'info', f'Scanning global services for account {credential.account_id}', { 'account_id': credential.account_id }) # Scan only global services global_scan_result = scanner.scan_resources( regions=['us-east-1'], # Global services use us-east-1 services=scanner.global_services ) # Merge global results for service_key, resources in global_scan_result.resources.items(): if service_key not in all_results: all_results[service_key] = [] for resource in resources: if hasattr(resource, 'to_dict'): resource_dict = resource.to_dict() elif isinstance(resource, dict): resource_dict = resource.copy() else: resource_dict = { 'account_id': getattr(resource, 'account_id', credential.account_id), 'region': getattr(resource, 'region', 'global'), 'service': getattr(resource, 'service', service_key), 'resource_type': getattr(resource, 'resource_type', ''), 'resource_id': getattr(resource, 'resource_id', ''), 'name': getattr(resource, 'name', ''), 'attributes': getattr(resource, 'attributes', {}) } resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id all_results[service_key].append(resource_dict) for error in global_scan_result.errors: error_msg = error.get('error', 'Unknown error') error_service = error.get('service', 'unknown') log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", { 'service': error_service, 'error': error_msg }) scan_errors.append(error) global_services_scanned.add(cred_id) except Exception as e: log_task_error_with_trace( task_id=task_id, error=e, service='global_services', context={'account_id': credential.account_id} ) scan_errors.append({ 'service': 'global_services', 'error': str(e) }) # Get regional services (exclude global services) regional_services = [s for s in scanner.supported_services if s not in scanner.global_services] for region in regions: try: # Scan resources in this region (only regional services) log_task_message(task_id, 'info', f'Scanning region {region} for account {credential.account_id}', { 'region': region, 'account_id': credential.account_id }) # Use scan_resources for regional services only scan_result = scanner.scan_resources( regions=[region], services=regional_services ) # Merge results, converting ResourceData to dict for service_key, resources in scan_result.resources.items(): if service_key not in all_results: all_results[service_key] = [] for resource in resources: # Convert ResourceData to dict using to_dict() method if hasattr(resource, 'to_dict'): resource_dict = resource.to_dict() elif isinstance(resource, dict): resource_dict = resource.copy() else: # Fallback: try to access attributes directly resource_dict = { 'account_id': getattr(resource, 'account_id', credential.account_id), 'region': getattr(resource, 'region', region), 'service': getattr(resource, 'service', service_key), 'resource_type': getattr(resource, 'resource_type', ''), 'resource_id': getattr(resource, 'resource_id', ''), 'name': getattr(resource, 'name', ''), 'attributes': getattr(resource, 'attributes', {}) } # Ensure account_id and region are set resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id resource_dict['region'] = resource_dict.get('region') or region all_results[service_key].append(resource_dict) # Log any errors from the scan (Requirements 8.2) for error in scan_result.errors: error_msg = error.get('error', 'Unknown error') error_service = error.get('service', 'unknown') error_region = error.get('region', region) log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", { 'service': error_service, 'region': error_region, 'error': error_msg }) scan_errors.append(error) except Exception as e: # Log error with full stack trace (Requirements 8.1, 8.2) log_task_error_with_trace( task_id=task_id, error=e, service='region_scan', region=region, context={'account_id': credential.account_id} ) scan_errors.append({ 'service': 'region_scan', 'region': region, 'error': str(e) }) # Update progress current_step += 1 progress = int((current_step / total_steps) * 90) # Reserve 10% for report generation self.update_state( state='PROGRESS', meta={'progress': progress, 'current': current_step, 'total': total_steps} ) update_task_progress(task_id, progress) except Exception as e: # Log credential-level error with full stack trace log_task_error_with_trace( task_id=task_id, error=e, service='credential', context={'credential_id': cred_id, 'account_id': credential.account_id if credential else None} ) scan_errors.append({ 'service': 'credential', 'credential_id': cred_id, 'error': str(e) }) # Generate report log_task_message(task_id, 'info', 'Generating report...', { 'total_services': len(all_results), 'total_errors': len(scan_errors) }) update_task_progress(task_id, 95) report_path = None try: # Get reports folder - use absolute path to ensure consistency reports_folder = current_app.config.get('REPORTS_FOLDER', 'reports') if not os.path.isabs(reports_folder): # Convert to absolute path relative to the app root reports_folder = os.path.abspath(reports_folder) os.makedirs(reports_folder, exist_ok=True) # Generate filename and path filename = generate_report_filename(project_metadata) report_path = os.path.join(reports_folder, filename) # Get network diagram path if provided network_diagram_path = project_metadata.get('network_diagram_path') # Generate the report generator = ReportGenerator() result = generator.generate_report( scan_results=all_results, project_metadata=project_metadata, output_path=report_path, network_diagram_path=network_diagram_path, regions=regions ) # Create report record in database report = Report( task_id=task_id, file_name=result['file_name'], file_path=result['file_path'], file_size=result['file_size'] ) db.session.add(report) db.session.commit() log_task_message(task_id, 'info', f'Report generated: {filename}', { 'file_name': filename, 'file_size': result['file_size'] }) except Exception as e: # Log report generation error with full stack trace log_task_error_with_trace( task_id=task_id, error=e, service='report_generation', context={'project_metadata': project_metadata} ) report_path = None # Update task status to completed update_task_status(task_id, 'completed', progress=100) log_task_message(task_id, 'info', 'Task completed successfully', { 'total_resources': sum(len(r) for r in all_results.values()), 'total_errors': len(scan_errors), 'report_generated': report_path is not None }) return { 'status': 'success', 'report_path': report_path, 'total_resources': sum(len(r) for r in all_results.values()), 'total_errors': len(scan_errors) } except SoftTimeLimitExceeded as e: # Log timeout error with full context log_task_error_with_trace( task_id=task_id, error=e, service='task_execution', context={'error_type': 'timeout'} ) update_task_status(task_id, 'failed') raise except Exception as e: # Log error with full stack trace (Requirements 8.1) log_task_error_with_trace( task_id=task_id, error=e, service='task_execution', context={'retry_count': self.request.retries} ) update_task_status(task_id, 'failed') # Retry with exponential backoff (Requirements 4.10) retry_count = self.request.retries if retry_count < self.max_retries: countdown = 60 * (2 ** retry_count) # Exponential backoff log_task_message(task_id, 'warning', f'Retrying task in {countdown} seconds', { 'attempt': retry_count + 1, 'max_retries': self.max_retries, 'countdown': countdown }) raise self.retry(exc=e, countdown=countdown) raise @celery_app.task def cleanup_old_reports(days: int = 30) -> Dict[str, Any]: """ Clean up reports older than specified days Args: days: Number of days to keep reports Returns: Cleanup statistics """ import os cutoff_date = datetime.now(timezone.utc) - timedelta(days=days) old_reports = Report.query.filter(Report.created_at < cutoff_date).all() deleted_count = 0 for report in old_reports: try: # Delete file if exists if os.path.exists(report.file_path): os.remove(report.file_path) # Delete database record db.session.delete(report) deleted_count += 1 except Exception as e: # Log error but continue with other reports print(f"Error deleting report {report.id}: {e}") db.session.commit() return { 'deleted_count': deleted_count, 'cutoff_date': cutoff_date.isoformat() } @celery_app.task def validate_credentials(credential_id: int) -> Dict[str, Any]: """ Validate AWS credentials Args: credential_id: ID of the credential to validate Returns: Validation result """ from app.models import AWSCredential credential = AWSCredential.query.get(credential_id) if not credential: return {'valid': False, 'message': 'Credential not found'} # TODO: Implement actual AWS credential validation in Task 6 # This is a placeholder try: # Placeholder for actual validation # session = get_aws_session(credential) # sts = session.client('sts') # sts.get_caller_identity() return {'valid': True, 'message': 'Credential is valid'} except Exception as e: return {'valid': False, 'message': str(e)}