tasks.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
"""
Task Management API endpoints.

Provides endpoints for:
- GET /api/tasks - Get paginated list of tasks with status filtering
- POST /api/tasks/create - Create a new scan task
- GET /api/tasks/detail - Get task details
- POST /api/tasks/delete - Delete a task
- GET /api/tasks/logs - Get task logs with pagination
- GET /api/tasks/errors - Get error-level task logs with pagination

Requirements: 3.1, 3.4
"""
  11. import os
  12. from flask import jsonify, request, current_app
  13. from werkzeug.utils import secure_filename
  14. from app import db
  15. from app.api import api_bp
  16. from app.models import Task, TaskLog, AWSCredential, UserCredential
  17. from app.services import login_required, admin_required, get_current_user_from_context, check_credential_access
  18. from app.errors import ValidationError, NotFoundError, AuthorizationError
  19. ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp'}
  20. def allowed_file(filename: str) -> bool:
  21. """Check if file extension is allowed for network diagram"""
  22. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS
  23. @api_bp.route('/tasks', methods=['GET'])
  24. @login_required
  25. def get_tasks():
  26. """
  27. Get paginated list of tasks with optional status filtering.
  28. Query Parameters:
  29. page: Page number (default: 1)
  30. page_size: Items per page (default: 20, max: 100)
  31. status: Optional filter by status (pending, running, completed, failed)
  32. Returns:
  33. JSON with 'data' array and 'pagination' object
  34. """
  35. current_user = get_current_user_from_context()
  36. # Get pagination parameters
  37. page = request.args.get('page', 1, type=int)
  38. # Support both pageSize (frontend) and page_size (backend convention)
  39. page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20
  40. page_size = min(page_size, 100)
  41. status = request.args.get('status', type=str)
  42. # Validate pagination
  43. if page < 1:
  44. page = 1
  45. if page_size < 1:
  46. page_size = 20
  47. # Build query based on user role
  48. if current_user.role in ['admin', 'power_user']:
  49. query = Task.query
  50. else:
  51. # Regular users can only see their own tasks
  52. query = Task.query.filter_by(created_by=current_user.id)
  53. # Apply status filter if provided
  54. if status and status in ['pending', 'running', 'completed', 'failed']:
  55. query = query.filter_by(status=status)
  56. # Order by created_at descending
  57. query = query.order_by(Task.created_at.desc())
  58. # Get total count
  59. total = query.count()
  60. total_pages = (total + page_size - 1) // page_size if total > 0 else 1
  61. # Apply pagination
  62. tasks = query.offset((page - 1) * page_size).limit(page_size).all()
  63. return jsonify({
  64. 'data': [task.to_dict() for task in tasks],
  65. 'pagination': {
  66. 'page': page,
  67. 'page_size': page_size,
  68. 'total': total,
  69. 'total_pages': total_pages
  70. }
  71. }), 200
  72. @api_bp.route('/tasks/create', methods=['POST'])
  73. @login_required
  74. def create_task():
  75. """
  76. Create a new scan task.
  77. Request Body (JSON or multipart/form-data):
  78. name: Task name (required)
  79. credential_ids: List of credential IDs to use (required)
  80. regions: List of AWS regions to scan (required)
  81. project_metadata: Project metadata object (required)
  82. - clientName: Client name (required)
  83. - projectName: Project name (required)
  84. - bdManager: BD Manager name (optional)
  85. - bdManagerEmail: BD Manager email (optional)
  86. - solutionsArchitect: Solutions Architect name (optional)
  87. - solutionsArchitectEmail: Solutions Architect email (optional)
  88. - cloudEngineer: Cloud Engineer name (optional)
  89. - cloudEngineerEmail: Cloud Engineer email (optional)
  90. network_diagram: Network diagram image file (optional, multipart only)
  91. Returns:
  92. JSON with created task details and task_id
  93. """
  94. current_user = get_current_user_from_context()
  95. # Handle both JSON and multipart/form-data
  96. if request.content_type and 'multipart/form-data' in request.content_type:
  97. data = request.form.to_dict()
  98. # Parse JSON fields from form data
  99. import json
  100. if 'credential_ids' in data:
  101. data['credential_ids'] = json.loads(data['credential_ids'])
  102. if 'regions' in data:
  103. data['regions'] = json.loads(data['regions'])
  104. if 'project_metadata' in data:
  105. data['project_metadata'] = json.loads(data['project_metadata'])
  106. network_diagram = request.files.get('network_diagram')
  107. else:
  108. data = request.get_json() or {}
  109. network_diagram = None
  110. # Validate required fields
  111. if not data.get('name'):
  112. raise ValidationError(
  113. message="Task name is required",
  114. details={"missing_fields": ["name"]}
  115. )
  116. credential_ids = data.get('credential_ids', [])
  117. if not credential_ids or not isinstance(credential_ids, list) or len(credential_ids) == 0:
  118. raise ValidationError(
  119. message="At least one credential must be selected",
  120. details={"missing_fields": ["credential_ids"]}
  121. )
  122. regions = data.get('regions', [])
  123. if not regions or not isinstance(regions, list) or len(regions) == 0:
  124. raise ValidationError(
  125. message="At least one region must be selected",
  126. details={"missing_fields": ["regions"]}
  127. )
  128. project_metadata = data.get('project_metadata', {})
  129. if not isinstance(project_metadata, dict):
  130. raise ValidationError(
  131. message="Project metadata must be an object",
  132. details={"field": "project_metadata", "reason": "invalid_type"}
  133. )
  134. # Validate required project metadata fields
  135. required_metadata = ['clientName', 'projectName']
  136. missing_metadata = [field for field in required_metadata if not project_metadata.get(field)]
  137. if missing_metadata:
  138. raise ValidationError(
  139. message="Missing required project metadata fields",
  140. details={"missing_fields": missing_metadata}
  141. )
  142. # Validate credential access for regular users
  143. for cred_id in credential_ids:
  144. if not check_credential_access(current_user, cred_id):
  145. raise AuthorizationError(
  146. message=f"Access denied to credential {cred_id}",
  147. details={"credential_id": cred_id, "reason": "not_assigned"}
  148. )
  149. # Verify credential exists and is active
  150. credential = db.session.get(AWSCredential, cred_id)
  151. if not credential:
  152. raise NotFoundError(
  153. message=f"Credential {cred_id} not found",
  154. details={"credential_id": cred_id}
  155. )
  156. if not credential.is_active:
  157. raise ValidationError(
  158. message=f"Credential {cred_id} is not active",
  159. details={"credential_id": cred_id, "reason": "inactive"}
  160. )
  161. # Handle network diagram upload
  162. network_diagram_path = None
  163. if network_diagram and network_diagram.filename:
  164. if not allowed_file(network_diagram.filename):
  165. raise ValidationError(
  166. message="Invalid file type for network diagram. Allowed: png, jpg, jpeg, gif, bmp",
  167. details={"field": "network_diagram", "reason": "invalid_file_type"}
  168. )
  169. # Save the file
  170. uploads_folder = current_app.config.get('UPLOAD_FOLDER', 'uploads')
  171. os.makedirs(uploads_folder, exist_ok=True)
  172. filename = secure_filename(network_diagram.filename)
  173. # Add timestamp to avoid conflicts
  174. import time
  175. filename = f"{int(time.time())}_{filename}"
  176. network_diagram_path = os.path.join(uploads_folder, filename)
  177. network_diagram.save(network_diagram_path)
  178. # Store path in project metadata
  179. project_metadata['network_diagram_path'] = network_diagram_path
  180. # Create task
  181. task = Task(
  182. name=data['name'].strip(),
  183. status='pending',
  184. progress=0,
  185. created_by=current_user.id
  186. )
  187. task.credential_ids = credential_ids
  188. task.regions = regions
  189. task.project_metadata = project_metadata
  190. db.session.add(task)
  191. db.session.commit()
  192. # Dispatch to Celery
  193. celery_task = None
  194. use_mock = False
  195. try:
  196. # 尝试使用真实的Celery (延迟导入)
  197. print("🔍 尝试导入Celery任务模块...")
  198. # 先测试Redis连接
  199. import redis
  200. r = redis.Redis(host='localhost', port=6379, db=0)
  201. r.ping()
  202. print("✅ Redis连接测试通过")
  203. # 导入并初始化Celery应用
  204. from app.celery_app import celery_app, init_celery
  205. # 确保Celery使用正确的broker配置
  206. init_celery(current_app._get_current_object())
  207. print(f"✅ Celery broker配置: {celery_app.conf.broker_url}")
  208. # 导入Celery任务
  209. from app.tasks.scan_tasks import scan_aws_resources
  210. print("✅ Celery任务模块导入成功")
  211. # 提交任务
  212. print("🔍 提交任务到Celery队列...")
  213. celery_task = scan_aws_resources.delay(
  214. task_id=task.id,
  215. credential_ids=credential_ids,
  216. regions=regions,
  217. project_metadata=project_metadata
  218. )
  219. print(f"✅ 任务已提交到Celery队列: {celery_task.id}")
  220. except Exception as e:
  221. # 详细的错误信息
  222. error_str = str(e)
  223. error_type = type(e).__name__
  224. print(f"❌ Celery任务提交失败:")
  225. print(f" 错误类型: {error_type}")
  226. print(f" 错误信息: {error_str}")
  227. use_mock = True
  228. # 如果Celery失败,使用Mock模式
  229. if use_mock:
  230. try:
  231. print("🔄 切换到Mock模式")
  232. from app.tasks.mock_tasks import scan_aws_resources
  233. celery_task = scan_aws_resources.delay(
  234. task_id=task.id,
  235. credential_ids=credential_ids,
  236. regions=regions,
  237. project_metadata=project_metadata
  238. )
  239. print(f"🔄 任务已提交到Mock队列: {celery_task.id}")
  240. except Exception as e:
  241. print(f"❌ Mock模式也失败: {e}")
  242. raise ValidationError(
  243. message="Failed to submit task to both Celery and Mock mode",
  244. details={"celery_error": str(e)}
  245. )
  246. # Store Celery task ID
  247. task.celery_task_id = celery_task.id
  248. db.session.commit()
  249. return jsonify({
  250. 'message': 'Task created successfully',
  251. 'task': task.to_dict(),
  252. 'celery_task_id': celery_task.id
  253. }), 201
  254. @api_bp.route('/tasks/detail', methods=['GET'])
  255. @login_required
  256. def get_task_detail():
  257. """
  258. Get task details including current status and progress.
  259. Query Parameters:
  260. id: Task ID (required)
  261. Returns:
  262. JSON with task details
  263. """
  264. current_user = get_current_user_from_context()
  265. task_id = request.args.get('id', type=int)
  266. if not task_id:
  267. raise ValidationError(
  268. message="Task ID is required",
  269. details={"missing_fields": ["id"]}
  270. )
  271. task = db.session.get(Task, task_id)
  272. if not task:
  273. raise NotFoundError(
  274. message="Task not found",
  275. details={"task_id": task_id}
  276. )
  277. # Check access for regular users
  278. if current_user.role == 'user' and task.created_by != current_user.id:
  279. raise AuthorizationError(
  280. message="Access denied",
  281. details={"reason": "not_owner"}
  282. )
  283. # Get task details with additional info
  284. task_dict = task.to_dict()
  285. # Add report info if available
  286. if task.report:
  287. task_dict['report'] = task.report.to_dict()
  288. # Add error count
  289. error_count = TaskLog.query.filter_by(task_id=task_id, level='error').count()
  290. task_dict['error_count'] = error_count
  291. # Get Celery task status if running
  292. if task.status == 'running' and task.celery_task_id:
  293. from celery.result import AsyncResult
  294. from app.celery_app import celery_app
  295. result = AsyncResult(task.celery_task_id, app=celery_app)
  296. if result.state == 'PROGRESS':
  297. task_dict['celery_progress'] = result.info
  298. return jsonify(task_dict), 200
  299. @api_bp.route('/tasks/delete', methods=['POST'])
  300. @login_required
  301. def delete_task():
  302. """
  303. Delete a task and its associated logs and report.
  304. Request Body:
  305. id: Task ID (required)
  306. Returns:
  307. JSON with success message
  308. """
  309. current_user = get_current_user_from_context()
  310. data = request.get_json() or {}
  311. task_id = data.get('id')
  312. if not task_id:
  313. raise ValidationError(
  314. message="Task ID is required",
  315. details={"missing_fields": ["id"]}
  316. )
  317. task = db.session.get(Task, task_id)
  318. if not task:
  319. raise NotFoundError(
  320. message="Task not found",
  321. details={"task_id": task_id}
  322. )
  323. # Check access - only admin or task owner can delete
  324. if current_user.role != 'admin' and task.created_by != current_user.id:
  325. raise AuthorizationError(
  326. message="Access denied",
  327. details={"reason": "not_owner_or_admin"}
  328. )
  329. # Cannot delete running tasks
  330. if task.status == 'running':
  331. raise ValidationError(
  332. message="Cannot delete a running task",
  333. details={"task_id": task_id, "status": task.status}
  334. )
  335. # Delete associated report file if exists
  336. if task.report and task.report.file_path:
  337. try:
  338. if os.path.exists(task.report.file_path):
  339. os.remove(task.report.file_path)
  340. except OSError:
  341. pass # File may already be deleted
  342. # Delete task (cascade will handle logs and report)
  343. db.session.delete(task)
  344. db.session.commit()
  345. return jsonify({
  346. 'message': 'Task deleted successfully'
  347. }), 200
  348. @api_bp.route('/tasks/logs', methods=['GET'])
  349. @login_required
  350. def get_task_logs():
  351. """
  352. Get paginated task logs.
  353. Query Parameters:
  354. id: Task ID (required)
  355. page: Page number (default: 1)
  356. page_size: Items per page (default: 20, max: 100)
  357. level: Optional filter by log level (info, warning, error)
  358. Returns:
  359. JSON with 'data' array and 'pagination' object
  360. Requirements:
  361. - 8.3: Display error logs associated with task
  362. """
  363. current_user = get_current_user_from_context()
  364. task_id = request.args.get('id', type=int)
  365. if not task_id:
  366. raise ValidationError(
  367. message="Task ID is required",
  368. details={"missing_fields": ["id"]}
  369. )
  370. task = db.session.get(Task, task_id)
  371. if not task:
  372. raise NotFoundError(
  373. message="Task not found",
  374. details={"task_id": task_id}
  375. )
  376. # Check access for regular users
  377. if current_user.role == 'user' and task.created_by != current_user.id:
  378. raise AuthorizationError(
  379. message="Access denied",
  380. details={"reason": "not_owner"}
  381. )
  382. # Get pagination parameters
  383. page = request.args.get('page', 1, type=int)
  384. # Support both pageSize (frontend) and page_size (backend convention)
  385. page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20
  386. page_size = min(page_size, 100)
  387. level = request.args.get('level', type=str)
  388. # Validate pagination
  389. if page < 1:
  390. page = 1
  391. if page_size < 1:
  392. page_size = 20
  393. # Build query
  394. query = TaskLog.query.filter_by(task_id=task_id)
  395. # Apply level filter if provided
  396. if level and level in ['info', 'warning', 'error']:
  397. query = query.filter_by(level=level)
  398. # Order by created_at descending
  399. query = query.order_by(TaskLog.created_at.desc())
  400. # Get total count
  401. total = query.count()
  402. total_pages = (total + page_size - 1) // page_size if total > 0 else 1
  403. # Apply pagination
  404. logs = query.offset((page - 1) * page_size).limit(page_size).all()
  405. return jsonify({
  406. 'data': [log.to_dict() for log in logs],
  407. 'pagination': {
  408. 'page': page,
  409. 'page_size': page_size,
  410. 'total': total,
  411. 'total_pages': total_pages
  412. }
  413. }), 200
  414. @api_bp.route('/tasks/errors', methods=['GET'])
  415. @login_required
  416. def get_task_errors():
  417. """
  418. Get error logs for a specific task.
  419. This is a convenience endpoint that returns only error-level logs
  420. with full details including stack traces.
  421. Query Parameters:
  422. id: Task ID (required)
  423. page: Page number (default: 1)
  424. page_size: Items per page (default: 20, max: 100)
  425. Returns:
  426. JSON with 'data' array containing error logs and 'pagination' object
  427. Requirements:
  428. - 8.2: Record error details in task record
  429. - 8.3: Display error logs associated with task
  430. """
  431. current_user = get_current_user_from_context()
  432. task_id = request.args.get('id', type=int)
  433. if not task_id:
  434. raise ValidationError(
  435. message="Task ID is required",
  436. details={"missing_fields": ["id"]}
  437. )
  438. task = db.session.get(Task, task_id)
  439. if not task:
  440. raise NotFoundError(
  441. message="Task not found",
  442. details={"task_id": task_id}
  443. )
  444. # Check access for regular users
  445. if current_user.role == 'user' and task.created_by != current_user.id:
  446. raise AuthorizationError(
  447. message="Access denied",
  448. details={"reason": "not_owner"}
  449. )
  450. # Get pagination parameters
  451. page = request.args.get('page', 1, type=int)
  452. # Support both pageSize (frontend) and page_size (backend convention)
  453. page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20
  454. page_size = min(page_size, 100)
  455. # Validate pagination
  456. if page < 1:
  457. page = 1
  458. if page_size < 1:
  459. page_size = 20
  460. # Build query for error logs only
  461. query = TaskLog.query.filter_by(task_id=task_id, level='error')
  462. # Order by created_at descending
  463. query = query.order_by(TaskLog.created_at.desc())
  464. # Get total count
  465. total = query.count()
  466. total_pages = (total + page_size - 1) // page_size if total > 0 else 1
  467. # Apply pagination
  468. logs = query.offset((page - 1) * page_size).limit(page_size).all()
  469. # Build response with full error details
  470. error_data = []
  471. for log in logs:
  472. log_dict = log.to_dict()
  473. # Ensure details are included for error analysis
  474. error_data.append(log_dict)
  475. return jsonify({
  476. 'data': error_data,
  477. 'pagination': {
  478. 'page': page,
  479. 'page_size': page_size,
  480. 'total': total,
  481. 'total_pages': total_pages
  482. },
  483. 'summary': {
  484. 'total_errors': total,
  485. 'task_status': task.status
  486. }
  487. }), 200