# tasks.py
"""
Task Management API endpoints.

Provides endpoints for:
- GET  /api/tasks         - Get a paginated list of tasks with optional status filtering
- POST /api/tasks/create  - Create a new scan task
- GET  /api/tasks/detail  - Get task details
- POST /api/tasks/delete  - Delete a task
- GET  /api/tasks/logs    - Get task logs with pagination
- GET  /api/tasks/errors  - Get error-level task logs with pagination

Requirements: 3.1, 3.4
"""
import json
import os
import re
import time
import uuid

from flask import jsonify, request, current_app
from werkzeug.utils import secure_filename

from app import db
from app.api import api_bp
from app.models import Task, TaskLog, AWSCredential, UserCredential
from app.services import login_required, admin_required, get_current_user_from_context, check_credential_access
from app.errors import ValidationError, NotFoundError, AuthorizationError
  19. ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp'}
  20. def allowed_file(filename: str) -> bool:
  21. """Check if file extension is allowed for network diagram"""
  22. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS
  23. @api_bp.route('/tasks', methods=['GET'])
  24. @login_required
  25. def get_tasks():
  26. """
  27. Get paginated list of tasks with optional status filtering.
  28. Query Parameters:
  29. page: Page number (default: 1)
  30. page_size: Items per page (default: 20, max: 100)
  31. status: Optional filter by status (pending, running, completed, failed)
  32. Returns:
  33. JSON with 'data' array and 'pagination' object
  34. """
  35. current_user = get_current_user_from_context()
  36. # Get pagination parameters
  37. page = request.args.get('page', 1, type=int)
  38. # Support both pageSize (frontend) and page_size (backend convention)
  39. page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20
  40. page_size = min(page_size, 100)
  41. status = request.args.get('status', type=str)
  42. # Validate pagination
  43. if page < 1:
  44. page = 1
  45. if page_size < 1:
  46. page_size = 20
  47. # Build query based on user role
  48. if current_user.role in ['admin', 'power_user']:
  49. query = Task.query
  50. else:
  51. # Regular users can only see their own tasks
  52. query = Task.query.filter_by(created_by=current_user.id)
  53. # Apply status filter if provided
  54. if status and status in ['pending', 'running', 'completed', 'failed']:
  55. query = query.filter_by(status=status)
  56. # Order by created_at descending
  57. query = query.order_by(Task.created_at.desc())
  58. # Get total count
  59. total = query.count()
  60. total_pages = (total + page_size - 1) // page_size if total > 0 else 1
  61. # Apply pagination
  62. tasks = query.offset((page - 1) * page_size).limit(page_size).all()
  63. return jsonify({
  64. 'data': [task.to_dict() for task in tasks],
  65. 'pagination': {
  66. 'page': page,
  67. 'page_size': page_size,
  68. 'total': total,
  69. 'total_pages': total_pages
  70. }
  71. }), 200
  72. @api_bp.route('/tasks/create', methods=['POST'])
  73. @login_required
  74. def create_task():
  75. """
  76. Create a new scan task.
  77. Request Body (JSON or multipart/form-data):
  78. name: Task name (required)
  79. credential_ids: List of credential IDs to use (required)
  80. regions: List of AWS regions to scan (required)
  81. project_metadata: Project metadata object (required)
  82. - clientName: Client name (required)
  83. - projectName: Project name (required)
  84. - bdManager: BD Manager name (optional)
  85. - bdManagerEmail: BD Manager email (optional)
  86. - solutionsArchitect: Solutions Architect name (optional)
  87. - solutionsArchitectEmail: Solutions Architect email (optional)
  88. - cloudEngineer: Cloud Engineer name (optional)
  89. - cloudEngineerEmail: Cloud Engineer email (optional)
  90. network_diagram: Network diagram image file (optional, multipart only)
  91. Returns:
  92. JSON with created task details and task_id
  93. """
  94. current_user = get_current_user_from_context()
  95. # Handle both JSON and multipart/form-data
  96. if request.content_type and 'multipart/form-data' in request.content_type:
  97. data = request.form.to_dict()
  98. # Parse JSON fields from form data
  99. import json
  100. if 'credential_ids' in data:
  101. data['credential_ids'] = json.loads(data['credential_ids'])
  102. if 'regions' in data:
  103. data['regions'] = json.loads(data['regions'])
  104. if 'project_metadata' in data:
  105. data['project_metadata'] = json.loads(data['project_metadata'])
  106. network_diagram = request.files.get('network_diagram')
  107. else:
  108. data = request.get_json() or {}
  109. network_diagram = None
  110. # Validate required fields
  111. if not data.get('name'):
  112. raise ValidationError(
  113. message="Task name is required",
  114. details={"missing_fields": ["name"]}
  115. )
  116. credential_ids = data.get('credential_ids', [])
  117. if not credential_ids or not isinstance(credential_ids, list) or len(credential_ids) == 0:
  118. raise ValidationError(
  119. message="At least one credential must be selected",
  120. details={"missing_fields": ["credential_ids"]}
  121. )
  122. regions = data.get('regions', [])
  123. if not regions or not isinstance(regions, list) or len(regions) == 0:
  124. raise ValidationError(
  125. message="At least one region must be selected",
  126. details={"missing_fields": ["regions"]}
  127. )
  128. project_metadata = data.get('project_metadata', {})
  129. if not isinstance(project_metadata, dict):
  130. raise ValidationError(
  131. message="Project metadata must be an object",
  132. details={"field": "project_metadata", "reason": "invalid_type"}
  133. )
  134. # Validate required project metadata fields
  135. required_metadata = ['clientName', 'projectName']
  136. missing_metadata = [field for field in required_metadata if not project_metadata.get(field)]
  137. if missing_metadata:
  138. raise ValidationError(
  139. message="Missing required project metadata fields",
  140. details={"missing_fields": missing_metadata}
  141. )
  142. # Validate clientName and projectName don't contain invalid filename characters
  143. import re
  144. invalid_chars_pattern = r'[<>\/\\|*:?"]'
  145. client_name = project_metadata.get('clientName', '')
  146. project_name = project_metadata.get('projectName', '')
  147. if re.search(invalid_chars_pattern, client_name):
  148. raise ValidationError(
  149. message="Client name contains invalid characters",
  150. details={"field": "clientName", "reason": "Cannot contain < > / \\ | * : ? \""}
  151. )
  152. if re.search(invalid_chars_pattern, project_name):
  153. raise ValidationError(
  154. message="Project name contains invalid characters",
  155. details={"field": "projectName", "reason": "Cannot contain < > / \\ | * : ? \""}
  156. )
  157. # Validate credential access for regular users
  158. for cred_id in credential_ids:
  159. if not check_credential_access(current_user, cred_id):
  160. raise AuthorizationError(
  161. message=f"Access denied to credential {cred_id}",
  162. details={"credential_id": cred_id, "reason": "not_assigned"}
  163. )
  164. # Verify credential exists and is active
  165. credential = db.session.get(AWSCredential, cred_id)
  166. if not credential:
  167. raise NotFoundError(
  168. message=f"Credential {cred_id} not found",
  169. details={"credential_id": cred_id}
  170. )
  171. if not credential.is_active:
  172. raise ValidationError(
  173. message=f"Credential {cred_id} is not active",
  174. details={"credential_id": cred_id, "reason": "inactive"}
  175. )
  176. # Handle network diagram upload
  177. network_diagram_path = None
  178. if network_diagram and network_diagram.filename:
  179. if not allowed_file(network_diagram.filename):
  180. raise ValidationError(
  181. message="Invalid file type for network diagram. Allowed: png, jpg, jpeg, gif, bmp",
  182. details={"field": "network_diagram", "reason": "invalid_file_type"}
  183. )
  184. # Save the file
  185. uploads_folder = current_app.config.get('UPLOAD_FOLDER', 'uploads')
  186. os.makedirs(uploads_folder, exist_ok=True)
  187. filename = secure_filename(network_diagram.filename)
  188. # Add timestamp to avoid conflicts
  189. import time
  190. filename = f"{int(time.time())}_{filename}"
  191. network_diagram_path = os.path.join(uploads_folder, filename)
  192. network_diagram.save(network_diagram_path)
  193. # Store path in project metadata
  194. project_metadata['network_diagram_path'] = network_diagram_path
  195. # Create task
  196. task = Task(
  197. name=data['name'].strip(),
  198. status='pending',
  199. progress=0,
  200. created_by=current_user.id
  201. )
  202. task.credential_ids = credential_ids
  203. task.regions = regions
  204. task.project_metadata = project_metadata
  205. db.session.add(task)
  206. db.session.commit()
  207. # Dispatch to Celery
  208. celery_task = None
  209. try:
  210. # 先测试Redis连接
  211. import redis
  212. print(f"🔍 测试Redis连接...")
  213. r = redis.Redis(host='localhost', port=6379, db=0)
  214. r.ping()
  215. print(f"✅ Redis连接成功")
  216. # 导入并初始化Celery应用
  217. from app.celery_app import celery_app, init_celery
  218. init_celery(current_app._get_current_object())
  219. print(f"✅ Celery初始化完成, broker: {celery_app.conf.broker_url}")
  220. # 导入Celery任务
  221. from app.tasks.scan_tasks import scan_aws_resources
  222. print(f"✅ 任务模块导入成功")
  223. # 提交任务
  224. print(f"🔍 提交任务到Celery队列...")
  225. celery_task = scan_aws_resources.delay(
  226. task_id=task.id,
  227. credential_ids=credential_ids,
  228. regions=regions,
  229. project_metadata=project_metadata
  230. )
  231. print(f"✅ 任务已提交: {celery_task.id}")
  232. except redis.ConnectionError as e:
  233. # Redis连接失败,删除已创建的任务并返回错误
  234. db.session.delete(task)
  235. db.session.commit()
  236. raise ValidationError(
  237. message="Redis服务不可用,无法创建任务。请确保Redis服务已启动。",
  238. details={"error": str(e)}
  239. )
  240. except Exception as e:
  241. # 其他错误
  242. db.session.delete(task)
  243. db.session.commit()
  244. raise ValidationError(
  245. message="任务提交失败",
  246. details={"error": str(e), "error_type": type(e).__name__}
  247. )
  248. # Store Celery task ID
  249. task.celery_task_id = celery_task.id
  250. db.session.commit()
  251. return jsonify({
  252. 'message': 'Task created successfully',
  253. 'task': task.to_dict(),
  254. 'celery_task_id': celery_task.id
  255. }), 201
  256. @api_bp.route('/tasks/detail', methods=['GET'])
  257. @login_required
  258. def get_task_detail():
  259. """
  260. Get task details including current status and progress.
  261. Query Parameters:
  262. id: Task ID (required)
  263. Returns:
  264. JSON with task details
  265. """
  266. current_user = get_current_user_from_context()
  267. task_id = request.args.get('id', type=int)
  268. if not task_id:
  269. raise ValidationError(
  270. message="Task ID is required",
  271. details={"missing_fields": ["id"]}
  272. )
  273. task = db.session.get(Task, task_id)
  274. if not task:
  275. raise NotFoundError(
  276. message="Task not found",
  277. details={"task_id": task_id}
  278. )
  279. # Check access for regular users
  280. if current_user.role == 'user' and task.created_by != current_user.id:
  281. raise AuthorizationError(
  282. message="Access denied",
  283. details={"reason": "not_owner"}
  284. )
  285. # Get task details with additional info
  286. task_dict = task.to_dict()
  287. # Add report info if available
  288. if task.report:
  289. task_dict['report'] = task.report.to_dict()
  290. # Add error count
  291. error_count = TaskLog.query.filter_by(task_id=task_id, level='error').count()
  292. task_dict['error_count'] = error_count
  293. # Get Celery task status if running
  294. if task.status == 'running' and task.celery_task_id:
  295. from celery.result import AsyncResult
  296. from app.celery_app import celery_app
  297. result = AsyncResult(task.celery_task_id, app=celery_app)
  298. if result.state == 'PROGRESS':
  299. task_dict['celery_progress'] = result.info
  300. return jsonify(task_dict), 200
  301. @api_bp.route('/tasks/delete', methods=['POST'])
  302. @login_required
  303. def delete_task():
  304. """
  305. Delete a task and its associated logs and report.
  306. Request Body:
  307. id: Task ID (required)
  308. Returns:
  309. JSON with success message
  310. """
  311. current_user = get_current_user_from_context()
  312. data = request.get_json() or {}
  313. task_id = data.get('id')
  314. if not task_id:
  315. raise ValidationError(
  316. message="Task ID is required",
  317. details={"missing_fields": ["id"]}
  318. )
  319. task = db.session.get(Task, task_id)
  320. if not task:
  321. raise NotFoundError(
  322. message="Task not found",
  323. details={"task_id": task_id}
  324. )
  325. # Check access - only admin or task owner can delete
  326. if current_user.role != 'admin' and task.created_by != current_user.id:
  327. raise AuthorizationError(
  328. message="Access denied",
  329. details={"reason": "not_owner_or_admin"}
  330. )
  331. # Cannot delete running tasks
  332. if task.status == 'running':
  333. raise ValidationError(
  334. message="Cannot delete a running task",
  335. details={"task_id": task_id, "status": task.status}
  336. )
  337. # Delete associated report file if exists
  338. if task.report and task.report.file_path:
  339. try:
  340. if os.path.exists(task.report.file_path):
  341. os.remove(task.report.file_path)
  342. except OSError:
  343. pass # File may already be deleted
  344. # Delete task (cascade will handle logs and report)
  345. db.session.delete(task)
  346. db.session.commit()
  347. return jsonify({
  348. 'message': 'Task deleted successfully'
  349. }), 200
  350. @api_bp.route('/tasks/logs', methods=['GET'])
  351. @login_required
  352. def get_task_logs():
  353. """
  354. Get paginated task logs.
  355. Query Parameters:
  356. id: Task ID (required)
  357. page: Page number (default: 1)
  358. page_size: Items per page (default: 20, max: 100)
  359. level: Optional filter by log level (info, warning, error)
  360. Returns:
  361. JSON with 'data' array and 'pagination' object
  362. Requirements:
  363. - 8.3: Display error logs associated with task
  364. """
  365. current_user = get_current_user_from_context()
  366. task_id = request.args.get('id', type=int)
  367. if not task_id:
  368. raise ValidationError(
  369. message="Task ID is required",
  370. details={"missing_fields": ["id"]}
  371. )
  372. task = db.session.get(Task, task_id)
  373. if not task:
  374. raise NotFoundError(
  375. message="Task not found",
  376. details={"task_id": task_id}
  377. )
  378. # Check access for regular users
  379. if current_user.role == 'user' and task.created_by != current_user.id:
  380. raise AuthorizationError(
  381. message="Access denied",
  382. details={"reason": "not_owner"}
  383. )
  384. # Get pagination parameters
  385. page = request.args.get('page', 1, type=int)
  386. # Support both pageSize (frontend) and page_size (backend convention)
  387. page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20
  388. page_size = min(page_size, 100)
  389. level = request.args.get('level', type=str)
  390. # Validate pagination
  391. if page < 1:
  392. page = 1
  393. if page_size < 1:
  394. page_size = 20
  395. # Build query
  396. query = TaskLog.query.filter_by(task_id=task_id)
  397. # Apply level filter if provided
  398. if level and level in ['info', 'warning', 'error']:
  399. query = query.filter_by(level=level)
  400. # Order by created_at descending
  401. query = query.order_by(TaskLog.created_at.desc())
  402. # Get total count
  403. total = query.count()
  404. total_pages = (total + page_size - 1) // page_size if total > 0 else 1
  405. # Apply pagination
  406. logs = query.offset((page - 1) * page_size).limit(page_size).all()
  407. return jsonify({
  408. 'data': [log.to_dict() for log in logs],
  409. 'pagination': {
  410. 'page': page,
  411. 'page_size': page_size,
  412. 'total': total,
  413. 'total_pages': total_pages
  414. }
  415. }), 200
  416. @api_bp.route('/tasks/errors', methods=['GET'])
  417. @login_required
  418. def get_task_errors():
  419. """
  420. Get error logs for a specific task.
  421. This is a convenience endpoint that returns only error-level logs
  422. with full details including stack traces.
  423. Query Parameters:
  424. id: Task ID (required)
  425. page: Page number (default: 1)
  426. page_size: Items per page (default: 20, max: 100)
  427. Returns:
  428. JSON with 'data' array containing error logs and 'pagination' object
  429. Requirements:
  430. - 8.2: Record error details in task record
  431. - 8.3: Display error logs associated with task
  432. """
  433. current_user = get_current_user_from_context()
  434. task_id = request.args.get('id', type=int)
  435. if not task_id:
  436. raise ValidationError(
  437. message="Task ID is required",
  438. details={"missing_fields": ["id"]}
  439. )
  440. task = db.session.get(Task, task_id)
  441. if not task:
  442. raise NotFoundError(
  443. message="Task not found",
  444. details={"task_id": task_id}
  445. )
  446. # Check access for regular users
  447. if current_user.role == 'user' and task.created_by != current_user.id:
  448. raise AuthorizationError(
  449. message="Access denied",
  450. details={"reason": "not_owner"}
  451. )
  452. # Get pagination parameters
  453. page = request.args.get('page', 1, type=int)
  454. # Support both pageSize (frontend) and page_size (backend convention)
  455. page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20
  456. page_size = min(page_size, 100)
  457. # Validate pagination
  458. if page < 1:
  459. page = 1
  460. if page_size < 1:
  461. page_size = 20
  462. # Build query for error logs only
  463. query = TaskLog.query.filter_by(task_id=task_id, level='error')
  464. # Order by created_at descending
  465. query = query.order_by(TaskLog.created_at.desc())
  466. # Get total count
  467. total = query.count()
  468. total_pages = (total + page_size - 1) // page_size if total > 0 else 1
  469. # Apply pagination
  470. logs = query.offset((page - 1) * page_size).limit(page_size).all()
  471. # Build response with full error details
  472. error_data = []
  473. for log in logs:
  474. log_dict = log.to_dict()
  475. # Ensure details are included for error analysis
  476. error_data.append(log_dict)
  477. return jsonify({
  478. 'data': error_data,
  479. 'pagination': {
  480. 'page': page,
  481. 'page_size': page_size,
  482. 'total': total,
  483. 'total_pages': total_pages
  484. },
  485. 'summary': {
  486. 'total_errors': total,
  487. 'task_status': task.status
  488. }
  489. }), 200