scan_tasks.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. """
  2. Celery Tasks for AWS Resource Scanning
  3. This module contains Celery tasks for executing AWS resource scans
  4. and generating reports.
  5. Requirements:
  6. - 4.1: Dispatch tasks to Celery queue for Worker processing
  7. - 4.9: Report progress updates to Redis
  8. - 4.10: Retry up to 3 times with exponential backoff
  9. - 8.2: Record error details in task record
  10. """
  11. from datetime import datetime, timedelta, timezone
  12. from typing import List, Dict, Any
  13. from celery.exceptions import SoftTimeLimitExceeded
  14. import traceback
  15. from app import db
  16. from app.celery_app import celery_app
  17. from app.models import Task, TaskLog, Report
  18. from app.services.report_generator import ReportGenerator, generate_report_filename
  19. from app.scanners import AWSScanner, create_credential_provider_from_model
  20. from app.errors import (
  21. TaskError, ScanError, ReportGenerationError, TaskErrorLogger,
  22. ErrorCode, log_error
  23. )
  24. def update_task_status(task_id: int, status: str, **kwargs) -> None:
  25. """Update task status in database"""
  26. task = Task.query.get(task_id)
  27. if task:
  28. task.status = status
  29. if status == 'running' and not task.started_at:
  30. task.started_at = datetime.now(timezone.utc)
  31. if status in ('completed', 'failed'):
  32. task.completed_at = datetime.now(timezone.utc)
  33. for key, value in kwargs.items():
  34. if hasattr(task, key):
  35. setattr(task, key, value)
  36. db.session.commit()
  37. def update_task_progress(task_id: int, progress: int) -> None:
  38. """Update task progress in database"""
  39. task = Task.query.get(task_id)
  40. if task:
  41. task.progress = progress
  42. db.session.commit()
  43. def log_task_message(task_id: int, level: str, message: str, details: Any = None) -> None:
  44. """
  45. Log a message for a task.
  46. Requirements:
  47. - 8.2: Record error details in task record
  48. """
  49. import json
  50. # Handle details serialization
  51. if details is not None:
  52. if isinstance(details, dict):
  53. details_json = json.dumps(details)
  54. elif isinstance(details, str):
  55. details_json = json.dumps({'info': details})
  56. else:
  57. details_json = json.dumps({'data': str(details)})
  58. else:
  59. details_json = None
  60. log = TaskLog(
  61. task_id=task_id,
  62. level=level,
  63. message=message,
  64. details=details_json
  65. )
  66. db.session.add(log)
  67. db.session.commit()
  68. def log_task_error_with_trace(
  69. task_id: int,
  70. error: Exception,
  71. service: str = None,
  72. region: str = None,
  73. context: Dict[str, Any] = None
  74. ) -> None:
  75. """
  76. Log an error for a task with full stack trace.
  77. Requirements:
  78. - 8.1: Log errors with timestamp, context, and stack trace
  79. - 8.2: Record error details in task record
  80. """
  81. import json
  82. # Build comprehensive error details
  83. error_details = {
  84. 'error_type': type(error).__name__,
  85. 'timestamp': datetime.now(timezone.utc).isoformat(),
  86. 'stack_trace': traceback.format_exc()
  87. }
  88. if service:
  89. error_details['service'] = service
  90. if region:
  91. error_details['region'] = region
  92. if context:
  93. error_details['context'] = context
  94. # Create task log entry
  95. log = TaskLog(
  96. task_id=task_id,
  97. level='error',
  98. message=str(error),
  99. details=json.dumps(error_details)
  100. )
  101. db.session.add(log)
  102. db.session.commit()
  103. # Also log to application logger
  104. log_error(error, context={'task_id': task_id, 'service': service, 'region': region})
  105. @celery_app.task(bind=True, max_retries=3, default_retry_delay=60)
  106. def scan_aws_resources(
  107. self,
  108. task_id: int,
  109. credential_ids: List[int],
  110. regions: List[str],
  111. project_metadata: Dict[str, Any]
  112. ) -> Dict[str, Any]:
  113. """
  114. Execute AWS resource scanning task
  115. Requirements:
  116. - 4.1: Dispatch tasks to Celery queue for Worker processing
  117. - 4.9: Report progress updates to Redis
  118. - 4.10: Retry up to 3 times with exponential backoff
  119. - 8.2: Record error details in task record
  120. Args:
  121. task_id: Database task ID
  122. credential_ids: List of AWS credential IDs to use
  123. regions: List of regions to scan
  124. project_metadata: Project metadata for report generation
  125. Returns:
  126. Scan results and report path
  127. """
  128. import os
  129. from flask import current_app
  130. from app.models import AWSCredential, BaseAssumeRoleConfig
  131. try:
  132. # Update task status to running
  133. update_task_status(task_id, 'running')
  134. log_task_message(task_id, 'info', 'Task started', {
  135. 'credential_ids': credential_ids,
  136. 'regions': regions
  137. })
  138. # Store Celery task ID
  139. task = Task.query.get(task_id)
  140. if task:
  141. task.celery_task_id = self.request.id
  142. db.session.commit()
  143. # Get base assume role config if needed
  144. base_config = BaseAssumeRoleConfig.query.first()
  145. # Collect all scan results
  146. all_results = {}
  147. total_steps = len(credential_ids) * len(regions)
  148. current_step = 0
  149. scan_errors = []
  150. # Track if global services have been scanned for each credential
  151. global_services_scanned = set()
  152. for cred_id in credential_ids:
  153. credential = AWSCredential.query.get(cred_id)
  154. if not credential:
  155. log_task_message(task_id, 'warning', f'Credential {cred_id} not found, skipping', {
  156. 'credential_id': cred_id
  157. })
  158. continue
  159. try:
  160. # Get AWS credentials
  161. cred_provider = create_credential_provider_from_model(credential, base_config)
  162. scanner = AWSScanner(cred_provider)
  163. # Scan global services only once per credential
  164. if cred_id not in global_services_scanned:
  165. try:
  166. log_task_message(task_id, 'info', f'Scanning global services for account {credential.account_id}', {
  167. 'account_id': credential.account_id
  168. })
  169. # Scan only global services
  170. global_scan_result = scanner.scan_resources(
  171. regions=['us-east-1'], # Global services use us-east-1
  172. services=scanner.global_services
  173. )
  174. # Merge global results
  175. for service_key, resources in global_scan_result.resources.items():
  176. if service_key not in all_results:
  177. all_results[service_key] = []
  178. for resource in resources:
  179. if hasattr(resource, 'to_dict'):
  180. resource_dict = resource.to_dict()
  181. elif isinstance(resource, dict):
  182. resource_dict = resource.copy()
  183. else:
  184. resource_dict = {
  185. 'account_id': getattr(resource, 'account_id', credential.account_id),
  186. 'region': getattr(resource, 'region', 'global'),
  187. 'service': getattr(resource, 'service', service_key),
  188. 'resource_type': getattr(resource, 'resource_type', ''),
  189. 'resource_id': getattr(resource, 'resource_id', ''),
  190. 'name': getattr(resource, 'name', ''),
  191. 'attributes': getattr(resource, 'attributes', {})
  192. }
  193. resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id
  194. all_results[service_key].append(resource_dict)
  195. for error in global_scan_result.errors:
  196. error_msg = error.get('error', 'Unknown error')
  197. error_service = error.get('service', 'unknown')
  198. log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", {
  199. 'service': error_service,
  200. 'error': error_msg
  201. })
  202. scan_errors.append(error)
  203. global_services_scanned.add(cred_id)
  204. except Exception as e:
  205. log_task_error_with_trace(
  206. task_id=task_id,
  207. error=e,
  208. service='global_services',
  209. context={'account_id': credential.account_id}
  210. )
  211. scan_errors.append({
  212. 'service': 'global_services',
  213. 'error': str(e)
  214. })
  215. # Get regional services (exclude global services)
  216. regional_services = [s for s in scanner.supported_services if s not in scanner.global_services]
  217. for region in regions:
  218. try:
  219. # Scan resources in this region (only regional services)
  220. log_task_message(task_id, 'info', f'Scanning region {region} for account {credential.account_id}', {
  221. 'region': region,
  222. 'account_id': credential.account_id
  223. })
  224. # Use scan_resources for regional services only
  225. scan_result = scanner.scan_resources(
  226. regions=[region],
  227. services=regional_services
  228. )
  229. # Merge results, converting ResourceData to dict
  230. for service_key, resources in scan_result.resources.items():
  231. if service_key not in all_results:
  232. all_results[service_key] = []
  233. for resource in resources:
  234. # Convert ResourceData to dict using to_dict() method
  235. if hasattr(resource, 'to_dict'):
  236. resource_dict = resource.to_dict()
  237. elif isinstance(resource, dict):
  238. resource_dict = resource.copy()
  239. else:
  240. # Fallback: try to access attributes directly
  241. resource_dict = {
  242. 'account_id': getattr(resource, 'account_id', credential.account_id),
  243. 'region': getattr(resource, 'region', region),
  244. 'service': getattr(resource, 'service', service_key),
  245. 'resource_type': getattr(resource, 'resource_type', ''),
  246. 'resource_id': getattr(resource, 'resource_id', ''),
  247. 'name': getattr(resource, 'name', ''),
  248. 'attributes': getattr(resource, 'attributes', {})
  249. }
  250. # Ensure account_id and region are set
  251. resource_dict['account_id'] = resource_dict.get('account_id') or credential.account_id
  252. resource_dict['region'] = resource_dict.get('region') or region
  253. all_results[service_key].append(resource_dict)
  254. # Log any errors from the scan (Requirements 8.2)
  255. for error in scan_result.errors:
  256. error_msg = error.get('error', 'Unknown error')
  257. error_service = error.get('service', 'unknown')
  258. error_region = error.get('region', region)
  259. log_task_message(task_id, 'warning', f"Scan error in {error_service}: {error_msg}", {
  260. 'service': error_service,
  261. 'region': error_region,
  262. 'error': error_msg
  263. })
  264. scan_errors.append(error)
  265. except Exception as e:
  266. # Log error with full stack trace (Requirements 8.1, 8.2)
  267. log_task_error_with_trace(
  268. task_id=task_id,
  269. error=e,
  270. service='region_scan',
  271. region=region,
  272. context={'account_id': credential.account_id}
  273. )
  274. scan_errors.append({
  275. 'service': 'region_scan',
  276. 'region': region,
  277. 'error': str(e)
  278. })
  279. # Update progress
  280. current_step += 1
  281. progress = int((current_step / total_steps) * 90) # Reserve 10% for report generation
  282. self.update_state(
  283. state='PROGRESS',
  284. meta={'progress': progress, 'current': current_step, 'total': total_steps}
  285. )
  286. update_task_progress(task_id, progress)
  287. except Exception as e:
  288. # Log credential-level error with full stack trace
  289. log_task_error_with_trace(
  290. task_id=task_id,
  291. error=e,
  292. service='credential',
  293. context={'credential_id': cred_id, 'account_id': credential.account_id if credential else None}
  294. )
  295. scan_errors.append({
  296. 'service': 'credential',
  297. 'credential_id': cred_id,
  298. 'error': str(e)
  299. })
  300. # Generate report
  301. log_task_message(task_id, 'info', 'Generating report...', {
  302. 'total_services': len(all_results),
  303. 'total_errors': len(scan_errors)
  304. })
  305. update_task_progress(task_id, 95)
  306. report_path = None
  307. try:
  308. # Get reports folder - use absolute path to ensure consistency
  309. reports_folder = current_app.config.get('REPORTS_FOLDER', 'reports')
  310. if not os.path.isabs(reports_folder):
  311. # Convert to absolute path relative to the app root
  312. reports_folder = os.path.abspath(reports_folder)
  313. os.makedirs(reports_folder, exist_ok=True)
  314. # Generate filename and path
  315. filename = generate_report_filename(project_metadata)
  316. report_path = os.path.join(reports_folder, filename)
  317. # Get network diagram path if provided
  318. network_diagram_path = project_metadata.get('network_diagram_path')
  319. # Generate the report
  320. generator = ReportGenerator()
  321. result = generator.generate_report(
  322. scan_results=all_results,
  323. project_metadata=project_metadata,
  324. output_path=report_path,
  325. network_diagram_path=network_diagram_path,
  326. regions=regions
  327. )
  328. # Create report record in database
  329. report = Report(
  330. task_id=task_id,
  331. file_name=result['file_name'],
  332. file_path=result['file_path'],
  333. file_size=result['file_size']
  334. )
  335. db.session.add(report)
  336. db.session.commit()
  337. log_task_message(task_id, 'info', f'Report generated: {filename}', {
  338. 'file_name': filename,
  339. 'file_size': result['file_size']
  340. })
  341. except Exception as e:
  342. # Log report generation error with full stack trace
  343. log_task_error_with_trace(
  344. task_id=task_id,
  345. error=e,
  346. service='report_generation',
  347. context={'project_metadata': project_metadata}
  348. )
  349. report_path = None
  350. # Update task status to completed
  351. update_task_status(task_id, 'completed', progress=100)
  352. log_task_message(task_id, 'info', 'Task completed successfully', {
  353. 'total_resources': sum(len(r) for r in all_results.values()),
  354. 'total_errors': len(scan_errors),
  355. 'report_generated': report_path is not None
  356. })
  357. return {
  358. 'status': 'success',
  359. 'report_path': report_path,
  360. 'total_resources': sum(len(r) for r in all_results.values()),
  361. 'total_errors': len(scan_errors)
  362. }
  363. except SoftTimeLimitExceeded as e:
  364. # Log timeout error with full context
  365. log_task_error_with_trace(
  366. task_id=task_id,
  367. error=e,
  368. service='task_execution',
  369. context={'error_type': 'timeout'}
  370. )
  371. update_task_status(task_id, 'failed')
  372. raise
  373. except Exception as e:
  374. # Log error with full stack trace (Requirements 8.1)
  375. log_task_error_with_trace(
  376. task_id=task_id,
  377. error=e,
  378. service='task_execution',
  379. context={'retry_count': self.request.retries}
  380. )
  381. update_task_status(task_id, 'failed')
  382. # Retry with exponential backoff (Requirements 4.10)
  383. retry_count = self.request.retries
  384. if retry_count < self.max_retries:
  385. countdown = 60 * (2 ** retry_count) # Exponential backoff
  386. log_task_message(task_id, 'warning', f'Retrying task in {countdown} seconds', {
  387. 'attempt': retry_count + 1,
  388. 'max_retries': self.max_retries,
  389. 'countdown': countdown
  390. })
  391. raise self.retry(exc=e, countdown=countdown)
  392. raise
  393. @celery_app.task
  394. def cleanup_old_reports(days: int = 30) -> Dict[str, Any]:
  395. """
  396. Clean up reports older than specified days
  397. Args:
  398. days: Number of days to keep reports
  399. Returns:
  400. Cleanup statistics
  401. """
  402. import os
  403. cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
  404. old_reports = Report.query.filter(Report.created_at < cutoff_date).all()
  405. deleted_count = 0
  406. for report in old_reports:
  407. try:
  408. # Delete file if exists
  409. if os.path.exists(report.file_path):
  410. os.remove(report.file_path)
  411. # Delete database record
  412. db.session.delete(report)
  413. deleted_count += 1
  414. except Exception as e:
  415. # Log error but continue with other reports
  416. print(f"Error deleting report {report.id}: {e}")
  417. db.session.commit()
  418. return {
  419. 'deleted_count': deleted_count,
  420. 'cutoff_date': cutoff_date.isoformat()
  421. }
  422. @celery_app.task
  423. def validate_credentials(credential_id: int) -> Dict[str, Any]:
  424. """
  425. Validate AWS credentials
  426. Args:
  427. credential_id: ID of the credential to validate
  428. Returns:
  429. Validation result
  430. """
  431. from app.models import AWSCredential
  432. credential = AWSCredential.query.get(credential_id)
  433. if not credential:
  434. return {'valid': False, 'message': 'Credential not found'}
  435. # TODO: Implement actual AWS credential validation in Task 6
  436. # This is a placeholder
  437. try:
  438. # Placeholder for actual validation
  439. # session = get_aws_session(credential)
  440. # sts = session.client('sts')
  441. # sts.get_caller_identity()
  442. return {'valid': True, 'message': 'Credential is valid'}
  443. except Exception as e:
  444. return {'valid': False, 'message': str(e)}