""" Task Management API endpoints Provides endpoints for: - GET /api/tasks - Get paginated list of tasks with status filtering - POST /api/tasks/create - Create a new scan task - GET /api/tasks/detail - Get task details - POST /api/tasks/delete - Delete a task - GET /api/tasks/logs - Get task logs with pagination Requirements: 3.1, 3.4 """ import os from flask import jsonify, request, current_app from werkzeug.utils import secure_filename from app import db from app.api import api_bp from app.models import Task, TaskLog, AWSCredential, UserCredential from app.services import login_required, admin_required, get_current_user_from_context, check_credential_access from app.errors import ValidationError, NotFoundError, AuthorizationError ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp'} def allowed_file(filename: str) -> bool: """Check if file extension is allowed for network diagram""" return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS @api_bp.route('/tasks', methods=['GET']) @login_required def get_tasks(): """ Get paginated list of tasks with optional status filtering. Query Parameters: page: Page number (default: 1) page_size: Items per page (default: 20, max: 100) status: Optional filter by status (pending, running, completed, failed) Returns: JSON with 'data' array and 'pagination' object """ current_user = get_current_user_from_context() # Get pagination parameters page = request.args.get('page', 1, type=int) # Support both pageSize (frontend) and page_size (backend convention) page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20 page_size = min(page_size, 100) status = request.args.get('status', type=str) # Validate pagination if page < 1: page = 1 if page_size < 1: page_size = 20 # Build query based on user role if current_user.role in ['admin', 'power_user']: query = Task.query else: # Regular users can only see their own tasks query = Task.query.filter_by(created_by=current_user.id) # Apply status filter if provided if status and status in ['pending', 'running', 'completed', 'failed']: query = query.filter_by(status=status) # Order by created_at descending query = query.order_by(Task.created_at.desc()) # Get total count total = query.count() total_pages = (total + page_size - 1) // page_size if total > 0 else 1 # Apply pagination tasks = query.offset((page - 1) * page_size).limit(page_size).all() return jsonify({ 'data': [task.to_dict() for task in tasks], 'pagination': { 'page': page, 'page_size': page_size, 'total': total, 'total_pages': total_pages } }), 200 @api_bp.route('/tasks/create', methods=['POST']) @login_required def create_task(): """ Create a new scan task. Request Body (JSON or multipart/form-data): name: Task name (required) credential_ids: List of credential IDs to use (required) regions: List of AWS regions to scan (required) project_metadata: Project metadata object (required) - clientName: Client name (required) - projectName: Project name (required) - bdManager: BD Manager name (optional) - bdManagerEmail: BD Manager email (optional) - solutionsArchitect: Solutions Architect name (optional) - solutionsArchitectEmail: Solutions Architect email (optional) - cloudEngineer: Cloud Engineer name (optional) - cloudEngineerEmail: Cloud Engineer email (optional) network_diagram: Network diagram image file (optional, multipart only) Returns: JSON with created task details and task_id """ current_user = get_current_user_from_context() # Handle both JSON and multipart/form-data if request.content_type and 'multipart/form-data' in request.content_type: data = request.form.to_dict() # Parse JSON fields from form data import json if 'credential_ids' in data: data['credential_ids'] = json.loads(data['credential_ids']) if 'regions' in data: data['regions'] = json.loads(data['regions']) if 'project_metadata' in data: data['project_metadata'] = json.loads(data['project_metadata']) network_diagram = request.files.get('network_diagram') else: data = request.get_json() or {} network_diagram = None # Validate required fields if not data.get('name'): raise ValidationError( message="Task name is required", details={"missing_fields": ["name"]} ) credential_ids = data.get('credential_ids', []) if not credential_ids or not isinstance(credential_ids, list) or len(credential_ids) == 0: raise ValidationError( message="At least one credential must be selected", details={"missing_fields": ["credential_ids"]} ) regions = data.get('regions', []) if not regions or not isinstance(regions, list) or len(regions) == 0: raise ValidationError( message="At least one region must be selected", details={"missing_fields": ["regions"]} ) project_metadata = data.get('project_metadata', {}) if not isinstance(project_metadata, dict): raise ValidationError( message="Project metadata must be an object", details={"field": "project_metadata", "reason": "invalid_type"} ) # Validate required project metadata fields required_metadata = ['clientName', 'projectName'] missing_metadata = [field for field in required_metadata if not project_metadata.get(field)] if missing_metadata: raise ValidationError( message="Missing required project metadata fields", details={"missing_fields": missing_metadata} ) # Validate credential access for regular users for cred_id in credential_ids: if not check_credential_access(current_user, cred_id): raise AuthorizationError( message=f"Access denied to credential {cred_id}", details={"credential_id": cred_id, "reason": "not_assigned"} ) # Verify credential exists and is active credential = db.session.get(AWSCredential, cred_id) if not credential: raise NotFoundError( message=f"Credential {cred_id} not found", details={"credential_id": cred_id} ) if not credential.is_active: raise ValidationError( message=f"Credential {cred_id} is not active", details={"credential_id": cred_id, "reason": "inactive"} ) # Handle network diagram upload network_diagram_path = None if network_diagram and network_diagram.filename: if not allowed_file(network_diagram.filename): raise ValidationError( message="Invalid file type for network diagram. Allowed: png, jpg, jpeg, gif, bmp", details={"field": "network_diagram", "reason": "invalid_file_type"} ) # Save the file uploads_folder = current_app.config.get('UPLOAD_FOLDER', 'uploads') os.makedirs(uploads_folder, exist_ok=True) filename = secure_filename(network_diagram.filename) # Add timestamp to avoid conflicts import time filename = f"{int(time.time())}_{filename}" network_diagram_path = os.path.join(uploads_folder, filename) network_diagram.save(network_diagram_path) # Store path in project metadata project_metadata['network_diagram_path'] = network_diagram_path # Create task task = Task( name=data['name'].strip(), status='pending', progress=0, created_by=current_user.id ) task.credential_ids = credential_ids task.regions = regions task.project_metadata = project_metadata db.session.add(task) db.session.commit() # Dispatch to Celery celery_task = None use_mock = False try: # 尝试使用真实的Celery (延迟导入) print("🔍 尝试导入Celery任务模块...") # 先测试Redis连接 import redis r = redis.Redis(host='localhost', port=6379, db=0) r.ping() print("✅ Redis连接测试通过") # 导入并初始化Celery应用 from app.celery_app import celery_app, init_celery # 确保Celery使用正确的broker配置 init_celery(current_app._get_current_object()) print(f"✅ Celery broker配置: {celery_app.conf.broker_url}") # 导入Celery任务 from app.tasks.scan_tasks import scan_aws_resources print("✅ Celery任务模块导入成功") # 提交任务 print("🔍 提交任务到Celery队列...") celery_task = scan_aws_resources.delay( task_id=task.id, credential_ids=credential_ids, regions=regions, project_metadata=project_metadata ) print(f"✅ 任务已提交到Celery队列: {celery_task.id}") except Exception as e: # 详细的错误信息 error_str = str(e) error_type = type(e).__name__ print(f"❌ Celery任务提交失败:") print(f" 错误类型: {error_type}") print(f" 错误信息: {error_str}") use_mock = True # 如果Celery失败,使用Mock模式 if use_mock: try: print("🔄 切换到Mock模式") from app.tasks.mock_tasks import scan_aws_resources celery_task = scan_aws_resources.delay( task_id=task.id, credential_ids=credential_ids, regions=regions, project_metadata=project_metadata ) print(f"🔄 任务已提交到Mock队列: {celery_task.id}") except Exception as e: print(f"❌ Mock模式也失败: {e}") raise ValidationError( message="Failed to submit task to both Celery and Mock mode", details={"celery_error": str(e)} ) # Store Celery task ID task.celery_task_id = celery_task.id db.session.commit() return jsonify({ 'message': 'Task created successfully', 'task': task.to_dict(), 'celery_task_id': celery_task.id }), 201 @api_bp.route('/tasks/detail', methods=['GET']) @login_required def get_task_detail(): """ Get task details including current status and progress. Query Parameters: id: Task ID (required) Returns: JSON with task details """ current_user = get_current_user_from_context() task_id = request.args.get('id', type=int) if not task_id: raise ValidationError( message="Task ID is required", details={"missing_fields": ["id"]} ) task = db.session.get(Task, task_id) if not task: raise NotFoundError( message="Task not found", details={"task_id": task_id} ) # Check access for regular users if current_user.role == 'user' and task.created_by != current_user.id: raise AuthorizationError( message="Access denied", details={"reason": "not_owner"} ) # Get task details with additional info task_dict = task.to_dict() # Add report info if available if task.report: task_dict['report'] = task.report.to_dict() # Add error count error_count = TaskLog.query.filter_by(task_id=task_id, level='error').count() task_dict['error_count'] = error_count # Get Celery task status if running if task.status == 'running' and task.celery_task_id: from celery.result import AsyncResult from app.celery_app import celery_app result = AsyncResult(task.celery_task_id, app=celery_app) if result.state == 'PROGRESS': task_dict['celery_progress'] = result.info return jsonify(task_dict), 200 @api_bp.route('/tasks/delete', methods=['POST']) @login_required def delete_task(): """ Delete a task and its associated logs and report. Request Body: id: Task ID (required) Returns: JSON with success message """ current_user = get_current_user_from_context() data = request.get_json() or {} task_id = data.get('id') if not task_id: raise ValidationError( message="Task ID is required", details={"missing_fields": ["id"]} ) task = db.session.get(Task, task_id) if not task: raise NotFoundError( message="Task not found", details={"task_id": task_id} ) # Check access - only admin or task owner can delete if current_user.role != 'admin' and task.created_by != current_user.id: raise AuthorizationError( message="Access denied", details={"reason": "not_owner_or_admin"} ) # Cannot delete running tasks if task.status == 'running': raise ValidationError( message="Cannot delete a running task", details={"task_id": task_id, "status": task.status} ) # Delete associated report file if exists if task.report and task.report.file_path: try: if os.path.exists(task.report.file_path): os.remove(task.report.file_path) except OSError: pass # File may already be deleted # Delete task (cascade will handle logs and report) db.session.delete(task) db.session.commit() return jsonify({ 'message': 'Task deleted successfully' }), 200 @api_bp.route('/tasks/logs', methods=['GET']) @login_required def get_task_logs(): """ Get paginated task logs. Query Parameters: id: Task ID (required) page: Page number (default: 1) page_size: Items per page (default: 20, max: 100) level: Optional filter by log level (info, warning, error) Returns: JSON with 'data' array and 'pagination' object Requirements: - 8.3: Display error logs associated with task """ current_user = get_current_user_from_context() task_id = request.args.get('id', type=int) if not task_id: raise ValidationError( message="Task ID is required", details={"missing_fields": ["id"]} ) task = db.session.get(Task, task_id) if not task: raise NotFoundError( message="Task not found", details={"task_id": task_id} ) # Check access for regular users if current_user.role == 'user' and task.created_by != current_user.id: raise AuthorizationError( message="Access denied", details={"reason": "not_owner"} ) # Get pagination parameters page = request.args.get('page', 1, type=int) # Support both pageSize (frontend) and page_size (backend convention) page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20 page_size = min(page_size, 100) level = request.args.get('level', type=str) # Validate pagination if page < 1: page = 1 if page_size < 1: page_size = 20 # Build query query = TaskLog.query.filter_by(task_id=task_id) # Apply level filter if provided if level and level in ['info', 'warning', 'error']: query = query.filter_by(level=level) # Order by created_at descending query = query.order_by(TaskLog.created_at.desc()) # Get total count total = query.count() total_pages = (total + page_size - 1) // page_size if total > 0 else 1 # Apply pagination logs = query.offset((page - 1) * page_size).limit(page_size).all() return jsonify({ 'data': [log.to_dict() for log in logs], 'pagination': { 'page': page, 'page_size': page_size, 'total': total, 'total_pages': total_pages } }), 200 @api_bp.route('/tasks/errors', methods=['GET']) @login_required def get_task_errors(): """ Get error logs for a specific task. This is a convenience endpoint that returns only error-level logs with full details including stack traces. Query Parameters: id: Task ID (required) page: Page number (default: 1) page_size: Items per page (default: 20, max: 100) Returns: JSON with 'data' array containing error logs and 'pagination' object Requirements: - 8.2: Record error details in task record - 8.3: Display error logs associated with task """ current_user = get_current_user_from_context() task_id = request.args.get('id', type=int) if not task_id: raise ValidationError( message="Task ID is required", details={"missing_fields": ["id"]} ) task = db.session.get(Task, task_id) if not task: raise NotFoundError( message="Task not found", details={"task_id": task_id} ) # Check access for regular users if current_user.role == 'user' and task.created_by != current_user.id: raise AuthorizationError( message="Access denied", details={"reason": "not_owner"} ) # Get pagination parameters page = request.args.get('page', 1, type=int) # Support both pageSize (frontend) and page_size (backend convention) page_size = request.args.get('pageSize', type=int) or request.args.get('page_size', type=int) or 20 page_size = min(page_size, 100) # Validate pagination if page < 1: page = 1 if page_size < 1: page_size = 20 # Build query for error logs only query = TaskLog.query.filter_by(task_id=task_id, level='error') # Order by created_at descending query = query.order_by(TaskLog.created_at.desc()) # Get total count total = query.count() total_pages = (total + page_size - 1) // page_size if total > 0 else 1 # Apply pagination logs = query.offset((page - 1) * page_size).limit(page_size).all() # Build response with full error details error_data = [] for log in logs: log_dict = log.to_dict() # Ensure details are included for error analysis error_data.append(log_dict) return jsonify({ 'data': error_data, 'pagination': { 'page': page, 'page_size': page_size, 'total': total, 'total_pages': total_pages }, 'summary': { 'total_errors': total, 'task_status': task.status } }), 200