scan_tasks.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. """
  2. Celery Tasks for AWS Resource Scanning
  3. This module contains Celery tasks for executing AWS resource scans
  4. and generating reports.
  5. Requirements:
  6. - 4.1: Dispatch tasks to Celery queue for Worker processing
  7. - 4.9: Report progress updates to Redis
  8. - 4.10: Retry up to 3 times with exponential backoff
  9. - 8.2: Record error details in task record
  10. """
  11. from datetime import datetime, timedelta, timezone
  12. from typing import List, Dict, Any
  13. from celery.exceptions import SoftTimeLimitExceeded
  14. import traceback
  15. from app import db
  16. from app.celery_app import celery_app
  17. from app.models import Task, TaskLog, Report
  18. from app.services.report_generator import ReportGenerator, generate_report_filename
  19. from app.scanners import AWSScanner, create_credential_provider_from_model
  20. from app.errors import (
  21. TaskError, ScanError, ReportGenerationError, TaskErrorLogger,
  22. ErrorCode, log_error
  23. )
  24. def update_task_status(task_id: int, status: str, **kwargs) -> None:
  25. """Update task status in database"""
  26. task = Task.query.get(task_id)
  27. if task:
  28. task.status = status
  29. if status == 'running' and not task.started_at:
  30. task.started_at = datetime.now(timezone.utc)
  31. if status in ('completed', 'failed'):
  32. task.completed_at = datetime.now(timezone.utc)
  33. for key, value in kwargs.items():
  34. if hasattr(task, key):
  35. setattr(task, key, value)
  36. db.session.commit()
  37. def update_task_progress(task_id: int, progress: int) -> None:
  38. """Update task progress in database"""
  39. task = Task.query.get(task_id)
  40. if task:
  41. task.progress = progress
  42. db.session.commit()
  43. def log_task_message(task_id: int, level: str, message: str, details: Any = None) -> None:
  44. """
  45. Log a message for a task.
  46. Requirements:
  47. - 8.2: Record error details in task record
  48. """
  49. import json
  50. # Handle details serialization
  51. if details is not None:
  52. if isinstance(details, dict):
  53. details_json = json.dumps(details)
  54. elif isinstance(details, str):
  55. details_json = json.dumps({'info': details})
  56. else:
  57. details_json = json.dumps({'data': str(details)})
  58. else:
  59. details_json = None
  60. log = TaskLog(
  61. task_id=task_id,
  62. level=level,
  63. message=message,
  64. details=details_json
  65. )
  66. db.session.add(log)
  67. db.session.commit()
  68. def log_task_error_with_trace(
  69. task_id: int,
  70. error: Exception,
  71. service: str = None,
  72. region: str = None,
  73. context: Dict[str, Any] = None
  74. ) -> None:
  75. """
  76. Log an error for a task with full stack trace.
  77. Requirements:
  78. - 8.1: Log errors with timestamp, context, and stack trace
  79. - 8.2: Record error details in task record
  80. """
  81. import json
  82. # Build comprehensive error details
  83. error_details = {
  84. 'error_type': type(error).__name__,
  85. 'timestamp': datetime.now(timezone.utc).isoformat(),
  86. 'stack_trace': traceback.format_exc()
  87. }
  88. if service:
  89. error_details['service'] = service
  90. if region:
  91. error_details['region'] = region
  92. if context:
  93. error_details['context'] = context
  94. try:
  95. # Rollback any pending transaction first
  96. db.session.rollback()
  97. # Create task log entry
  98. log = TaskLog(
  99. task_id=task_id,
  100. level='error',
  101. message=str(error)[:500], # Truncate long error messages
  102. details=json.dumps(error_details)
  103. )
  104. db.session.add(log)
  105. db.session.commit()
  106. except Exception as log_error_exc:
  107. # If logging fails, just print to console
  108. print(f"Failed to log error to database: {log_error_exc}")
  109. db.session.rollback()
  110. # Also log to application logger
  111. log_error(error, context={'task_id': task_id, 'service': service, 'region': region})
  112. @celery_app.task(bind=True, max_retries=3, default_retry_delay=60)
  113. def scan_aws_resources(
  114. self,
  115. task_id: int,
  116. credential_ids: List[int],
  117. regions: List[str],
  118. project_metadata: Dict[str, Any]
  119. ) -> Dict[str, Any]:
  120. """
  121. Execute AWS resource scanning task
  122. Requirements:
  123. - 4.1: Dispatch tasks to Celery queue for Worker processing
  124. - 4.9: Report progress updates to Redis
  125. - 4.10: Retry up to 3 times with exponential backoff
  126. - 8.2: Record error details in task record
  127. Args:
  128. task_id: Database task ID
  129. credential_ids: List of AWS credential IDs to use
  130. regions: List of regions to scan
  131. project_metadata: Project metadata for report generation
  132. Returns:
  133. Scan results and report path
  134. """
  135. import os
  136. from flask import current_app
  137. from app.models import AWSCredential, BaseAssumeRoleConfig
  138. try:
  139. # Update task status to running
  140. update_task_status(task_id, 'running')
  141. log_task_message(task_id, 'info', 'Task started', {
  142. 'credential_ids': credential_ids,
  143. 'regions': regions
  144. })
  145. # Store Celery task ID
  146. task = Task.query.get(task_id)
  147. if task:
  148. task.celery_task_id = self.request.id
  149. db.session.commit()
  150. # Get base assume role config if needed
  151. base_config = BaseAssumeRoleConfig.query.first()
  152. # Collect all scan results
  153. all_results = {}
  154. total_steps = len(credential_ids) * len(regions)
  155. current_step = 0
  156. scan_errors = []
  157. # Track if global services have been scanned for each credential
  158. global_services_scanned = set()
  159. # Check if us-east-1 is in the selected regions
  160. us_east_1_selected = 'us-east-1' in regions
  161. for cred_id in credential_ids:
  162. credential = AWSCredential.query.get(cred_id)
  163. if not credential:
  164. log_task_message(task_id, 'warning', f'Credential {cred_id} not found, skipping', {
  165. 'credential_id': cred_id
  166. })
  167. continue
  168. try:
  169. # Get AWS credentials
  170. cred_provider = create_credential_provider_from_model(credential, base_config)
  171. scanner = AWSScanner(cred_provider)
  172. # Scan global services only once per credential
  173. if cred_id not in global_services_scanned:
  174. try:
  175. # When us-east-1 is not selected, also scan ACM as a global service
  176. # ACM certificates in us-east-1 are used by CloudFront (global)
  177. services_to_scan_globally = list(scanner.global_services)
  178. if not us_east_1_selected and 'acm' not in services_to_scan_globally:
  179. services_to_scan_globally.append('acm')
  180. log_task_message(task_id, 'info', f'Adding ACM to global scan (us-east-1 not selected)', {
  181. 'account_id': credential.account_id
  182. })
  183. log_task_message(task_id, 'info', f'Scanning global services for account {credential.account_id}', {
  184. 'account_id': credential.account_id,
  185. 'services': services_to_scan_globally
  186. })
  187. # Scan global services (and ACM if us-east-1 not selected)
  188. global_scan_result = scanner.scan_resources(
  189. regions=['us-east-1'], # Global services use us-east-1
  190. services=services_to_scan_globally
  191. )
  192. # Merge global results
  193. for service_key, resources in global_scan_result.resources.items():
  194. if service_key not in all_results:
  195. all_results[service_key] = []
  196. for resource in resources:
  197. if hasattr(resource, 'to_dict'):
  198. resource_dict = resource.to_dict()
  199. elif isinstance(resource, dict):
  200. resource_dict = resource.copy()
  201. else:
  202. resource_dict = {
  203. 'account_id': getattr(resource, 'account_id', credential.account_id),
  204. 'region': getattr(resource, 'region', 'global'),
  205. 'service': getattr(resource, 'service', service_key),
  206. 'resource_type': getattr(resource, 'resource_type', ''),
  207. 'resource_id': getattr(resource, 'resource_id', ''),
  208. 'name': getattr(resource, 'name', ''),
  209. 'attributes': getattr(resource, 'attributes', {})
  210. }
  211. resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id
  212. all_results[service_key].append(resource_dict)
  213. for error in global_scan_result.errors:
  214. error_msg = error.get('error', 'Unknown error')
  215. error_service = error.get('service', 'unknown')
  216. log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", {
  217. 'service': error_service,
  218. 'error': error_msg
  219. })
  220. scan_errors.append(error)
  221. global_services_scanned.add(cred_id)
  222. except Exception as e:
  223. log_task_error_with_trace(
  224. task_id=task_id,
  225. error=e,
  226. service='global_services',
  227. context={'account_id': credential.account_id}
  228. )
  229. scan_errors.append({
  230. 'service': 'global_services',
  231. 'error': str(e)
  232. })
  233. # Get regional services (exclude global services)
  234. # Also exclude ACM if us-east-1 is not selected (already scanned as global)
  235. regional_services = [s for s in scanner.supported_services if s not in scanner.global_services]
  236. if not us_east_1_selected:
  237. regional_services = [s for s in regional_services if s != 'acm']
  238. for region in regions:
  239. try:
  240. # Scan resources in this region (only regional services)
  241. log_task_message(task_id, 'info', f'Scanning region {region} for account {credential.account_id}', {
  242. 'region': region,
  243. 'account_id': credential.account_id
  244. })
  245. # Use scan_resources for regional services only
  246. scan_result = scanner.scan_resources(
  247. regions=[region],
  248. services=regional_services
  249. )
  250. # Merge results, converting ResourceData to dict
  251. for service_key, resources in scan_result.resources.items():
  252. if service_key not in all_results:
  253. all_results[service_key] = []
  254. for resource in resources:
  255. # Convert ResourceData to dict using to_dict() method
  256. if hasattr(resource, 'to_dict'):
  257. resource_dict = resource.to_dict()
  258. elif isinstance(resource, dict):
  259. resource_dict = resource.copy()
  260. else:
  261. # Fallback: try to access attributes directly
  262. resource_dict = {
  263. 'account_id': getattr(resource, 'account_id', credential.account_id),
  264. 'region': getattr(resource, 'region', region),
  265. 'service': getattr(resource, 'service', service_key),
  266. 'resource_type': getattr(resource, 'resource_type', ''),
  267. 'resource_id': getattr(resource, 'resource_id', ''),
  268. 'name': getattr(resource, 'name', ''),
  269. 'attributes': getattr(resource, 'attributes', {})
  270. }
  271. # Ensure account_id and region are set
  272. resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id
  273. resource_dict['region'] = resource_dict.get('region') or region
  274. all_results[service_key].append(resource_dict)
  275. # Log any errors from the scan (Requirements 8.2)
  276. for error in scan_result.errors:
  277. error_msg = error.get('error', 'Unknown error')
  278. error_service = error.get('service', 'unknown')
  279. error_region = error.get('region', region)
  280. log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", {
  281. 'service': error_service,
  282. 'region': error_region,
  283. 'error': error_msg
  284. })
  285. scan_errors.append(error)
  286. except Exception as e:
  287. # Log error with full stack trace (Requirements 8.1, 8.2)
  288. log_task_error_with_trace(
  289. task_id=task_id,
  290. error=e,
  291. service='region_scan',
  292. region=region,
  293. context={'account_id': credential.account_id}
  294. )
  295. scan_errors.append({
  296. 'service': 'region_scan',
  297. 'region': region,
  298. 'error': str(e)
  299. })
  300. # Update progress
  301. current_step += 1
  302. progress = int((current_step / total_steps) * 90) # Reserve 10% for report generation
  303. self.update_state(
  304. state='PROGRESS',
  305. meta={'progress': progress, 'current': current_step, 'total': total_steps}
  306. )
  307. update_task_progress(task_id, progress)
  308. except Exception as e:
  309. # Log credential-level error with full stack trace
  310. log_task_error_with_trace(
  311. task_id=task_id,
  312. error=e,
  313. service='credential',
  314. context={'credential_id': cred_id, 'account_id': credential.account_id if credential else None}
  315. )
  316. scan_errors.append({
  317. 'service': 'credential',
  318. 'credential_id': cred_id,
  319. 'error': str(e)
  320. })
  321. # Generate report
  322. log_task_message(task_id, 'info', 'Generating report...', {
  323. 'total_services': len(all_results),
  324. 'total_errors': len(scan_errors)
  325. })
  326. update_task_progress(task_id, 95)
  327. report_path = None
  328. try:
  329. # Check if report already exists for this task (in case of retry)
  330. existing_report = Report.query.filter_by(task_id=task_id).first()
  331. if existing_report:
  332. log_task_message(task_id, 'info', 'Report already exists, skipping generation', {
  333. 'file_name': existing_report.file_name
  334. })
  335. report_path = existing_report.file_path
  336. else:
  337. # Get reports folder - use absolute path to ensure consistency
  338. reports_folder = current_app.config.get('REPORTS_FOLDER', 'reports')
  339. if not os.path.isabs(reports_folder):
  340. # Convert to absolute path relative to the app root
  341. reports_folder = os.path.abspath(reports_folder)
  342. os.makedirs(reports_folder, exist_ok=True)
  343. # Generate filename and path
  344. filename = generate_report_filename(project_metadata)
  345. report_path = os.path.join(reports_folder, filename)
  346. # Get network diagram path if provided
  347. network_diagram_path = project_metadata.get('network_diagram_path')
  348. # Generate the report
  349. generator = ReportGenerator()
  350. result = generator.generate_report(
  351. scan_results=all_results,
  352. project_metadata=project_metadata,
  353. output_path=report_path,
  354. network_diagram_path=network_diagram_path,
  355. regions=regions
  356. )
  357. # Create report record in database
  358. report = Report(
  359. task_id=task_id,
  360. file_name=result['file_name'],
  361. file_path=result['file_path'],
  362. file_size=result['file_size']
  363. )
  364. db.session.add(report)
  365. db.session.commit()
  366. log_task_message(task_id, 'info', f'Report generated: {filename}', {
  367. 'file_name': filename,
  368. 'file_size': result['file_size']
  369. })
  370. except Exception as e:
  371. # Log report generation error with full stack trace
  372. db.session.rollback() # Rollback any pending transaction
  373. log_task_error_with_trace(
  374. task_id=task_id,
  375. error=e,
  376. service='report_generation',
  377. context={'project_metadata': project_metadata}
  378. )
  379. report_path = None
  380. # Update task status to completed
  381. update_task_status(task_id, 'completed', progress=100)
  382. log_task_message(task_id, 'info', 'Task completed successfully', {
  383. 'total_resources': sum(len(r) for r in all_results.values()),
  384. 'total_errors': len(scan_errors),
  385. 'report_generated': report_path is not None
  386. })
  387. return {
  388. 'status': 'success',
  389. 'report_path': report_path,
  390. 'total_resources': sum(len(r) for r in all_results.values()),
  391. 'total_errors': len(scan_errors)
  392. }
  393. except SoftTimeLimitExceeded as e:
  394. # Log timeout error with full context
  395. log_task_error_with_trace(
  396. task_id=task_id,
  397. error=e,
  398. service='task_execution',
  399. context={'error_type': 'timeout'}
  400. )
  401. update_task_status(task_id, 'failed')
  402. raise
  403. except Exception as e:
  404. # Log error with full stack trace (Requirements 8.1)
  405. log_task_error_with_trace(
  406. task_id=task_id,
  407. error=e,
  408. service='task_execution',
  409. context={'retry_count': self.request.retries}
  410. )
  411. update_task_status(task_id, 'failed')
  412. # Retry with exponential backoff (Requirements 4.10)
  413. retry_count = self.request.retries
  414. if retry_count < self.max_retries:
  415. countdown = 60 * (2 ** retry_count) # Exponential backoff
  416. log_task_message(task_id, 'warning', f'Retrying task in {countdown} seconds', {
  417. 'attempt': retry_count + 1,
  418. 'max_retries': self.max_retries,
  419. 'countdown': countdown
  420. })
  421. raise self.retry(exc=e, countdown=countdown)
  422. raise
  423. @celery_app.task
  424. def cleanup_old_reports(days: int = 30) -> Dict[str, Any]:
  425. """
  426. Clean up reports older than specified days
  427. Args:
  428. days: Number of days to keep reports
  429. Returns:
  430. Cleanup statistics
  431. """
  432. import os
  433. cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
  434. old_reports = Report.query.filter(Report.created_at < cutoff_date).all()
  435. deleted_count = 0
  436. for report in old_reports:
  437. try:
  438. # Delete file if exists
  439. if os.path.exists(report.file_path):
  440. os.remove(report.file_path)
  441. # Delete database record
  442. db.session.delete(report)
  443. deleted_count += 1
  444. except Exception as e:
  445. # Log error but continue with other reports
  446. print(f"Error deleting report {report.id}: {e}")
  447. db.session.commit()
  448. return {
  449. 'deleted_count': deleted_count,
  450. 'cutoff_date': cutoff_date.isoformat()
  451. }
  452. @celery_app.task
  453. def validate_credentials(credential_id: int) -> Dict[str, Any]:
  454. """
  455. Validate AWS credentials
  456. Args:
  457. credential_id: ID of the credential to validate
  458. Returns:
  459. Validation result
  460. """
  461. from app.models import AWSCredential
  462. credential = AWSCredential.query.get(credential_id)
  463. if not credential:
  464. return {'valid': False, 'message': 'Credential not found'}
  465. # TODO: Implement actual AWS credential validation in Task 6
  466. # This is a placeholder
  467. try:
  468. # Placeholder for actual validation
  469. # session = get_aws_session(credential)
  470. # sts = session.client('sts')
  471. # sts.get_caller_identity()
  472. return {'valid': True, 'message': 'Credential is valid'}
  473. except Exception as e:
  474. return {'valid': False, 'message': str(e)}