auth_service.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. """
JWT Authentication Service
Provides JWT token generation and validation, authentication and role
decorators for routes, and role-based access helpers for resources,
credentials, and reports.
  4. """
  5. import jwt
  6. from datetime import datetime, timedelta, timezone
  7. from functools import wraps
  8. from typing import Optional, Tuple, Dict, Any
  9. from flask import request, current_app, g
  10. from app import db
  11. from app.models import User
  12. from app.errors import AuthenticationError, AuthorizationError
  13. class AuthService:
  14. """Service for handling JWT authentication"""
  15. @staticmethod
  16. def generate_access_token(user: User) -> str:
  17. """Generate a JWT access token for a user"""
  18. now = datetime.now(timezone.utc)
  19. expires = now + current_app.config['JWT_ACCESS_TOKEN_EXPIRES']
  20. payload = {
  21. 'user_id': user.id,
  22. 'username': user.username,
  23. 'role': user.role,
  24. 'type': 'access',
  25. 'exp': expires,
  26. 'iat': now
  27. }
  28. return jwt.encode(
  29. payload,
  30. current_app.config['JWT_SECRET_KEY'],
  31. algorithm='HS256'
  32. )
  33. @staticmethod
  34. def generate_refresh_token(user: User) -> str:
  35. """Generate a JWT refresh token for a user"""
  36. now = datetime.now(timezone.utc)
  37. expires = now + current_app.config['JWT_REFRESH_TOKEN_EXPIRES']
  38. payload = {
  39. 'user_id': user.id,
  40. 'type': 'refresh',
  41. 'exp': expires,
  42. 'iat': now
  43. }
  44. return jwt.encode(
  45. payload,
  46. current_app.config['JWT_SECRET_KEY'],
  47. algorithm='HS256'
  48. )
  49. @staticmethod
  50. def generate_tokens(user: User) -> Dict[str, str]:
  51. """Generate both access and refresh tokens"""
  52. return {
  53. 'access_token': AuthService.generate_access_token(user),
  54. 'refresh_token': AuthService.generate_refresh_token(user)
  55. }
  56. @staticmethod
  57. def decode_token(token: str) -> Dict[str, Any]:
  58. """
  59. Decode and validate a JWT token
  60. Raises:
  61. AuthenticationError: If token is invalid or expired
  62. """
  63. try:
  64. payload = jwt.decode(
  65. token,
  66. current_app.config['JWT_SECRET_KEY'],
  67. algorithms=['HS256']
  68. )
  69. return payload
  70. except jwt.ExpiredSignatureError:
  71. raise AuthenticationError(
  72. message="Token has expired",
  73. details={"reason": "token_expired"}
  74. )
  75. except jwt.InvalidTokenError as e:
  76. raise AuthenticationError(
  77. message="Invalid token",
  78. details={"reason": "invalid_token"}
  79. )
  80. @staticmethod
  81. def get_token_from_header() -> Optional[str]:
  82. """Extract JWT token from Authorization header"""
  83. auth_header = request.headers.get('Authorization', '')
  84. if auth_header.startswith('Bearer '):
  85. return auth_header[7:]
  86. return None
  87. @staticmethod
  88. def get_current_user() -> Optional[User]:
  89. """Get the current authenticated user from the request"""
  90. token = AuthService.get_token_from_header()
  91. if not token:
  92. return None
  93. try:
  94. payload = AuthService.decode_token(token)
  95. if payload.get('type') != 'access':
  96. return None
  97. user = db.session.get(User, payload['user_id'])
  98. if user and user.is_active:
  99. return user
  100. return None
  101. except AuthenticationError:
  102. return None
  103. @staticmethod
  104. def authenticate(username: str, password: str) -> Tuple[User, Dict[str, str]]:
  105. """
  106. Authenticate a user with username and password
  107. Returns:
  108. Tuple of (User, tokens dict)
  109. Raises:
  110. AuthenticationError: If credentials are invalid
  111. """
  112. user = User.query.filter_by(username=username).first()
  113. if not user or not user.check_password(password):
  114. raise AuthenticationError(
  115. message="Invalid username or password",
  116. details={"reason": "invalid_credentials"}
  117. )
  118. if not user.is_active:
  119. raise AuthenticationError(
  120. message="User account is disabled",
  121. details={"reason": "account_disabled"}
  122. )
  123. tokens = AuthService.generate_tokens(user)
  124. return user, tokens
  125. @staticmethod
  126. def refresh_access_token(refresh_token: str) -> Dict[str, str]:
  127. """
  128. Generate a new access token using a refresh token
  129. Returns:
  130. Dict with new access_token
  131. Raises:
  132. AuthenticationError: If refresh token is invalid
  133. """
  134. payload = AuthService.decode_token(refresh_token)
  135. if payload.get('type') != 'refresh':
  136. raise AuthenticationError(
  137. message="Invalid token type",
  138. details={"reason": "not_refresh_token"}
  139. )
  140. user = db.session.get(User, payload['user_id'])
  141. if not user or not user.is_active:
  142. raise AuthenticationError(
  143. message="User not found or inactive",
  144. details={"reason": "user_invalid"}
  145. )
  146. return {
  147. 'access_token': AuthService.generate_access_token(user)
  148. }
  149. def login_required(f):
  150. """Decorator to require authentication for a route"""
  151. @wraps(f)
  152. def decorated_function(*args, **kwargs):
  153. token = AuthService.get_token_from_header()
  154. if not token:
  155. raise AuthenticationError(
  156. message="Authentication required",
  157. details={"reason": "missing_token"}
  158. )
  159. payload = AuthService.decode_token(token)
  160. if payload.get('type') != 'access':
  161. raise AuthenticationError(
  162. message="Invalid token type",
  163. details={"reason": "not_access_token"}
  164. )
  165. user = db.session.get(User, payload['user_id'])
  166. if not user:
  167. raise AuthenticationError(
  168. message="User not found",
  169. details={"reason": "user_not_found"}
  170. )
  171. if not user.is_active:
  172. raise AuthenticationError(
  173. message="User account is disabled",
  174. details={"reason": "account_disabled"}
  175. )
  176. # Store user in flask g object for access in route
  177. g.current_user = user
  178. return f(*args, **kwargs)
  179. return decorated_function
  180. def get_current_user_from_context() -> User:
  181. """Get the current user from Flask g context (use after login_required)"""
  182. return getattr(g, 'current_user', None)
  183. def admin_required(f):
  184. """
  185. Decorator to require admin role for a route.
  186. Must be used after @login_required.
  187. """
  188. @wraps(f)
  189. def decorated_function(*args, **kwargs):
  190. token = AuthService.get_token_from_header()
  191. if not token:
  192. raise AuthenticationError(
  193. message="Authentication required",
  194. details={"reason": "missing_token"}
  195. )
  196. payload = AuthService.decode_token(token)
  197. if payload.get('type') != 'access':
  198. raise AuthenticationError(
  199. message="Invalid token type",
  200. details={"reason": "not_access_token"}
  201. )
  202. user = db.session.get(User, payload['user_id'])
  203. if not user:
  204. raise AuthenticationError(
  205. message="User not found",
  206. details={"reason": "user_not_found"}
  207. )
  208. if not user.is_active:
  209. raise AuthenticationError(
  210. message="User account is disabled",
  211. details={"reason": "account_disabled"}
  212. )
  213. if user.role != 'admin':
  214. raise AuthorizationError(
  215. message="Admin access required",
  216. details={"reason": "insufficient_permissions", "required_role": "admin"}
  217. )
  218. g.current_user = user
  219. return f(*args, **kwargs)
  220. return decorated_function
  221. def power_user_required(f):
  222. """
  223. Decorator to require power_user or admin role for a route.
  224. Must be used after @login_required.
  225. """
  226. @wraps(f)
  227. def decorated_function(*args, **kwargs):
  228. token = AuthService.get_token_from_header()
  229. if not token:
  230. raise AuthenticationError(
  231. message="Authentication required",
  232. details={"reason": "missing_token"}
  233. )
  234. payload = AuthService.decode_token(token)
  235. if payload.get('type') != 'access':
  236. raise AuthenticationError(
  237. message="Invalid token type",
  238. details={"reason": "not_access_token"}
  239. )
  240. user = db.session.get(User, payload['user_id'])
  241. if not user:
  242. raise AuthenticationError(
  243. message="User not found",
  244. details={"reason": "user_not_found"}
  245. )
  246. if not user.is_active:
  247. raise AuthenticationError(
  248. message="User account is disabled",
  249. details={"reason": "account_disabled"}
  250. )
  251. if user.role not in ['admin', 'power_user']:
  252. raise AuthorizationError(
  253. message="Power user or admin access required",
  254. details={"reason": "insufficient_permissions", "required_role": "power_user"}
  255. )
  256. g.current_user = user
  257. return f(*args, **kwargs)
  258. return decorated_function
  259. def check_resource_access(user: User, resource_owner_id: int) -> bool:
  260. """
  261. Check if a user has access to a resource based on ownership.
  262. - Admin: Can access all resources
  263. - Power User: Can access all resources
  264. - User: Can only access their own resources
  265. Args:
  266. user: The current user
  267. resource_owner_id: The ID of the resource owner
  268. Returns:
  269. True if access is allowed, False otherwise
  270. """
  271. if user.role in ['admin', 'power_user']:
  272. return True
  273. return user.id == resource_owner_id
  274. def check_credential_access(user: User, credential_id: int) -> bool:
  275. """
  276. Check if a user has access to a specific credential.
  277. - Admin: Can access all credentials
  278. - Power User: Can access all credentials
  279. - User: Can only access assigned credentials
  280. Args:
  281. user: The current user
  282. credential_id: The ID of the credential
  283. Returns:
  284. True if access is allowed, False otherwise
  285. """
  286. if user.role in ['admin', 'power_user']:
  287. return True
  288. # Check if credential is assigned to user
  289. from app.models import UserCredential
  290. assignment = UserCredential.query.filter_by(
  291. user_id=user.id,
  292. credential_id=credential_id
  293. ).first()
  294. return assignment is not None
  295. def get_accessible_credentials(user: User):
  296. """
  297. Get list of credentials accessible to a user.
  298. - Admin/Power User: All credentials
  299. - User: Only assigned credentials
  300. Args:
  301. user: The current user
  302. Returns:
  303. Query for accessible credentials
  304. """
  305. from app.models import AWSCredential, UserCredential
  306. if user.role in ['admin', 'power_user']:
  307. return AWSCredential.query.filter_by(is_active=True)
  308. # Get only assigned credentials for regular users
  309. return AWSCredential.query.join(UserCredential).filter(
  310. UserCredential.user_id == user.id,
  311. AWSCredential.is_active == True
  312. )
  313. def get_accessible_reports(user: User):
  314. """
  315. Get list of reports accessible to a user.
  316. - Admin/Power User: All reports
  317. - User: Only their own reports
  318. Args:
  319. user: The current user
  320. Returns:
  321. Query for accessible reports
  322. """
  323. from app.models import Report, Task
  324. if user.role in ['admin', 'power_user']:
  325. return Report.query
  326. # Get only reports from user's own tasks
  327. return Report.query.join(Task).filter(Task.created_by == user.id)