test_cloudshell_scanner.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309
  1. #!/usr/bin/env python3
  2. """
  3. Unit tests for CloudShell Scanner - Region Filtering and Error Handling
  4. Tests for Task 1.6:
  5. - Region list retrieval and filtering logic
  6. - Error capture and continue scanning logic
  7. - Retry mechanism with exponential backoff
  8. Requirements tested:
  9. - 1.3: Scan only specified regions when provided
  10. - 1.4: Scan all available regions when not specified
  11. - 1.8: Record errors and continue scanning other resources
  12. """
  13. import time
  14. import unittest
  15. from unittest.mock import MagicMock, patch, PropertyMock
  16. import pytest
  17. from botocore.exceptions import ClientError, BotoCoreError
  18. # Import the module under test
  19. from cloudshell_scanner import (
  20. CloudShellScanner,
  21. retry_with_exponential_backoff,
  22. is_retryable_error,
  23. RETRYABLE_ERROR_CODES,
  24. RETRYABLE_EXCEPTIONS,
  25. )
  26. class TestRetryWithExponentialBackoff(unittest.TestCase):
  27. """Tests for the retry_with_exponential_backoff decorator."""
  28. def test_successful_call_no_retry(self):
  29. """Test that successful calls don't trigger retries."""
  30. call_count = 0
  31. @retry_with_exponential_backoff(max_retries=3, base_delay=0.01)
  32. def successful_func():
  33. nonlocal call_count
  34. call_count += 1
  35. return "success"
  36. result = successful_func()
  37. self.assertEqual(result, "success")
  38. self.assertEqual(call_count, 1)
  39. def test_retry_on_throttling_error(self):
  40. """Test that throttling errors trigger retries."""
  41. call_count = 0
  42. @retry_with_exponential_backoff(max_retries=2, base_delay=0.01)
  43. def throttled_func():
  44. nonlocal call_count
  45. call_count += 1
  46. if call_count < 3:
  47. error_response = {
  48. "Error": {
  49. "Code": "Throttling",
  50. "Message": "Rate exceeded"
  51. }
  52. }
  53. raise ClientError(error_response, "TestOperation")
  54. return "success"
  55. result = throttled_func()
  56. self.assertEqual(result, "success")
  57. self.assertEqual(call_count, 3)
  58. def test_no_retry_on_non_retryable_error(self):
  59. """Test that non-retryable errors are raised immediately."""
  60. call_count = 0
  61. @retry_with_exponential_backoff(max_retries=3, base_delay=0.01)
  62. def access_denied_func():
  63. nonlocal call_count
  64. call_count += 1
  65. error_response = {
  66. "Error": {
  67. "Code": "AccessDenied",
  68. "Message": "Access Denied"
  69. }
  70. }
  71. raise ClientError(error_response, "TestOperation")
  72. with self.assertRaises(ClientError):
  73. access_denied_func()
  74. # Should only be called once since AccessDenied is not retryable
  75. self.assertEqual(call_count, 1)
  76. def test_max_retries_exhausted(self):
  77. """Test that exception is raised after max retries."""
  78. call_count = 0
  79. @retry_with_exponential_backoff(max_retries=2, base_delay=0.01)
  80. def always_fails():
  81. nonlocal call_count
  82. call_count += 1
  83. error_response = {
  84. "Error": {
  85. "Code": "ServiceUnavailable",
  86. "Message": "Service unavailable"
  87. }
  88. }
  89. raise ClientError(error_response, "TestOperation")
  90. with self.assertRaises(ClientError):
  91. always_fails()
  92. # Should be called max_retries + 1 times
  93. self.assertEqual(call_count, 3)
  94. def test_exponential_backoff_timing(self):
  95. """Test that delays increase exponentially."""
  96. call_times = []
  97. @retry_with_exponential_backoff(max_retries=2, base_delay=0.1, exponential_base=2.0)
  98. def timed_func():
  99. call_times.append(time.time())
  100. if len(call_times) < 3:
  101. error_response = {
  102. "Error": {
  103. "Code": "Throttling",
  104. "Message": "Rate exceeded"
  105. }
  106. }
  107. raise ClientError(error_response, "TestOperation")
  108. return "success"
  109. timed_func()
  110. # Check that delays are approximately exponential
  111. # First delay should be ~0.1s, second should be ~0.2s
  112. if len(call_times) >= 2:
  113. first_delay = call_times[1] - call_times[0]
  114. self.assertGreater(first_delay, 0.05) # At least half the base delay
  115. if len(call_times) >= 3:
  116. second_delay = call_times[2] - call_times[1]
  117. self.assertGreater(second_delay, first_delay * 0.8) # Second delay should be larger
  118. class TestIsRetryableError(unittest.TestCase):
  119. """Tests for the is_retryable_error function."""
  120. def test_throttling_is_retryable(self):
  121. """Test that throttling errors are retryable."""
  122. error_response = {
  123. "Error": {
  124. "Code": "Throttling",
  125. "Message": "Rate exceeded"
  126. }
  127. }
  128. error = ClientError(error_response, "TestOperation")
  129. self.assertTrue(is_retryable_error(error))
  130. def test_service_unavailable_is_retryable(self):
  131. """Test that service unavailable errors are retryable."""
  132. error_response = {
  133. "Error": {
  134. "Code": "ServiceUnavailable",
  135. "Message": "Service unavailable"
  136. }
  137. }
  138. error = ClientError(error_response, "TestOperation")
  139. self.assertTrue(is_retryable_error(error))
  140. def test_access_denied_not_retryable(self):
  141. """Test that access denied errors are not retryable."""
  142. error_response = {
  143. "Error": {
  144. "Code": "AccessDenied",
  145. "Message": "Access Denied"
  146. }
  147. }
  148. error = ClientError(error_response, "TestOperation")
  149. self.assertFalse(is_retryable_error(error))
  150. def test_connection_error_is_retryable(self):
  151. """Test that connection errors are retryable."""
  152. error = ConnectionError("Connection refused")
  153. self.assertTrue(is_retryable_error(error))
  154. def test_timeout_error_is_retryable(self):
  155. """Test that timeout errors are retryable."""
  156. error = TimeoutError("Request timed out")
  157. self.assertTrue(is_retryable_error(error))
  158. class TestRegionFiltering(unittest.TestCase):
  159. """Tests for region filtering functionality."""
  160. @patch('cloudshell_scanner.boto3.Session')
  161. def setUp(self, mock_session):
  162. """Set up test fixtures."""
  163. # Mock the boto3 session
  164. self.mock_session = MagicMock()
  165. mock_session.return_value = self.mock_session
  166. # Mock STS client for get_account_id
  167. self.mock_sts = MagicMock()
  168. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  169. # Mock EC2 client for list_regions
  170. self.mock_ec2 = MagicMock()
  171. self.mock_ec2.describe_regions.return_value = {
  172. "Regions": [
  173. {"RegionName": "us-east-1"},
  174. {"RegionName": "us-west-2"},
  175. {"RegionName": "eu-west-1"},
  176. {"RegionName": "ap-northeast-1"},
  177. ]
  178. }
  179. def get_client(service, **kwargs):
  180. if service == "sts":
  181. return self.mock_sts
  182. elif service == "ec2":
  183. return self.mock_ec2
  184. return MagicMock()
  185. self.mock_session.client.side_effect = get_client
  186. self.scanner = CloudShellScanner()
  187. def test_list_regions_returns_available_regions(self):
  188. """Test that list_regions returns available regions from AWS."""
  189. regions = self.scanner.list_regions()
  190. self.assertEqual(len(regions), 4)
  191. self.assertIn("us-east-1", regions)
  192. self.assertIn("us-west-2", regions)
  193. self.assertIn("eu-west-1", regions)
  194. self.assertIn("ap-northeast-1", regions)
  195. def test_list_regions_fallback_on_error(self):
  196. """Test that list_regions falls back to defaults on error."""
  197. self.mock_ec2.describe_regions.side_effect = Exception("API Error")
  198. regions = self.scanner.list_regions()
  199. # Should return default regions
  200. self.assertIn("us-east-1", regions)
  201. self.assertIn("us-west-2", regions)
  202. self.assertGreater(len(regions), 0)
  203. def test_filter_regions_with_valid_regions(self):
  204. """Test filtering with valid regions."""
  205. # Validates: Requirements 1.3
  206. requested = ["us-east-1", "us-west-2"]
  207. filtered = self.scanner.filter_regions(requested)
  208. self.assertEqual(len(filtered), 2)
  209. self.assertIn("us-east-1", filtered)
  210. self.assertIn("us-west-2", filtered)
  211. def test_filter_regions_with_invalid_regions(self):
  212. """Test filtering removes invalid regions."""
  213. # Validates: Requirements 1.3
  214. requested = ["us-east-1", "invalid-region", "us-west-2"]
  215. filtered = self.scanner.filter_regions(requested)
  216. self.assertEqual(len(filtered), 2)
  217. self.assertIn("us-east-1", filtered)
  218. self.assertIn("us-west-2", filtered)
  219. self.assertNotIn("invalid-region", filtered)
  220. def test_filter_regions_none_returns_all(self):
  221. """Test that None returns all available regions."""
  222. # Validates: Requirements 1.4
  223. filtered = self.scanner.filter_regions(None)
  224. self.assertEqual(len(filtered), 4)
  225. def test_filter_regions_all_invalid_falls_back(self):
  226. """Test that all invalid regions falls back to all available."""
  227. requested = ["invalid-1", "invalid-2"]
  228. filtered = self.scanner.filter_regions(requested)
  229. # Should fall back to all available regions
  230. self.assertEqual(len(filtered), 4)
  231. def test_filter_regions_normalizes_input(self):
  232. """Test that region names are normalized (whitespace, case)."""
  233. requested = [" US-EAST-1 ", "us-west-2"]
  234. filtered = self.scanner.filter_regions(requested)
  235. self.assertEqual(len(filtered), 2)
  236. self.assertIn("us-east-1", filtered)
  237. self.assertIn("us-west-2", filtered)
  238. def test_validate_region_valid(self):
  239. """Test validate_region with valid region."""
  240. self.assertTrue(self.scanner.validate_region("us-east-1"))
  241. def test_validate_region_invalid(self):
  242. """Test validate_region with invalid region."""
  243. self.assertFalse(self.scanner.validate_region("invalid-region"))
  244. class TestErrorHandling(unittest.TestCase):
  245. """Tests for error handling functionality."""
  246. @patch('cloudshell_scanner.boto3.Session')
  247. def setUp(self, mock_session):
  248. """Set up test fixtures."""
  249. self.mock_session = MagicMock()
  250. mock_session.return_value = self.mock_session
  251. # Mock STS client
  252. self.mock_sts = MagicMock()
  253. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  254. # Mock EC2 client
  255. self.mock_ec2 = MagicMock()
  256. self.mock_ec2.describe_regions.return_value = {
  257. "Regions": [{"RegionName": "us-east-1"}]
  258. }
  259. def get_client(service, **kwargs):
  260. if service == "sts":
  261. return self.mock_sts
  262. elif service == "ec2":
  263. return self.mock_ec2
  264. return MagicMock()
  265. self.mock_session.client.side_effect = get_client
  266. self.scanner = CloudShellScanner()
  267. def test_create_error_info_client_error(self):
  268. """Test error info creation for ClientError."""
  269. error_response = {
  270. "Error": {
  271. "Code": "AccessDenied",
  272. "Message": "User is not authorized"
  273. }
  274. }
  275. exception = ClientError(error_response, "DescribeInstances")
  276. error_info = self.scanner._create_error_info(
  277. service="ec2",
  278. region="us-east-1",
  279. exception=exception,
  280. )
  281. self.assertEqual(error_info["service"], "ec2")
  282. self.assertEqual(error_info["region"], "us-east-1")
  283. self.assertEqual(error_info["error_type"], "ClientError")
  284. self.assertIsNotNone(error_info["details"])
  285. self.assertEqual(error_info["details"]["error_code"], "AccessDenied")
  286. self.assertIn("permission_hint", error_info["details"])
  287. def test_create_error_info_generic_exception(self):
  288. """Test error info creation for generic exceptions."""
  289. exception = ValueError("Invalid value")
  290. error_info = self.scanner._create_error_info(
  291. service="vpc",
  292. region="eu-west-1",
  293. exception=exception,
  294. )
  295. self.assertEqual(error_info["service"], "vpc")
  296. self.assertEqual(error_info["region"], "eu-west-1")
  297. self.assertEqual(error_info["error_type"], "ValueError")
  298. self.assertIn("Invalid value", error_info["error"])
  299. def test_scan_continues_after_error(self):
  300. """Test that scanning continues after encountering an error."""
  301. # Validates: Requirements 1.8
  302. # Mock _scan_service to fail for one service but succeed for another
  303. call_count = {"vpc": 0, "ec2": 0}
  304. def mock_scan_service(account_id, region, service):
  305. call_count[service] = call_count.get(service, 0) + 1
  306. if service == "vpc":
  307. raise Exception("VPC scan failed")
  308. return [{"resource_id": "i-123", "service": service}]
  309. self.scanner._scan_service = mock_scan_service
  310. # Scan with both services
  311. result = self.scanner.scan_resources(
  312. regions=["us-east-1"],
  313. services=["vpc", "ec2"],
  314. )
  315. # Both services should have been attempted
  316. self.assertEqual(call_count["vpc"], 1)
  317. self.assertEqual(call_count["ec2"], 1)
  318. # Should have one error and one successful resource
  319. self.assertEqual(len(result["errors"]), 1)
  320. self.assertEqual(result["errors"][0]["service"], "vpc")
  321. self.assertIn("ec2", result["resources"])
  322. def test_error_info_includes_region(self):
  323. """Test that error info includes the correct region."""
  324. # Validates: Requirements 1.8
  325. def mock_scan_service(account_id, region, service):
  326. raise Exception(f"Error in {region}")
  327. self.scanner._scan_service = mock_scan_service
  328. result = self.scanner.scan_resources(
  329. regions=["us-east-1"],
  330. services=["vpc"],
  331. )
  332. self.assertEqual(len(result["errors"]), 1)
  333. self.assertEqual(result["errors"][0]["region"], "us-east-1")
  334. class TestCallWithRetry(unittest.TestCase):
  335. """Tests for the _call_with_retry method."""
  336. @patch('cloudshell_scanner.boto3.Session')
  337. def setUp(self, mock_session):
  338. """Set up test fixtures."""
  339. self.mock_session = MagicMock()
  340. mock_session.return_value = self.mock_session
  341. self.mock_sts = MagicMock()
  342. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  343. self.mock_session.client.return_value = self.mock_sts
  344. self.scanner = CloudShellScanner()
  345. def test_call_with_retry_success(self):
  346. """Test successful call without retries."""
  347. def success_func():
  348. return "result"
  349. result = self.scanner._call_with_retry(success_func, max_retries=3, base_delay=0.01)
  350. self.assertEqual(result, "result")
  351. def test_call_with_retry_eventual_success(self):
  352. """Test call that succeeds after retries."""
  353. call_count = 0
  354. def eventual_success():
  355. nonlocal call_count
  356. call_count += 1
  357. if call_count < 3:
  358. error_response = {
  359. "Error": {
  360. "Code": "Throttling",
  361. "Message": "Rate exceeded"
  362. }
  363. }
  364. raise ClientError(error_response, "TestOperation")
  365. return "success"
  366. result = self.scanner._call_with_retry(
  367. eventual_success,
  368. max_retries=3,
  369. base_delay=0.01,
  370. )
  371. self.assertEqual(result, "success")
  372. self.assertEqual(call_count, 3)
  373. def test_call_with_retry_exhausted(self):
  374. """Test call that exhausts all retries."""
  375. def always_fails():
  376. error_response = {
  377. "Error": {
  378. "Code": "ServiceUnavailable",
  379. "Message": "Service unavailable"
  380. }
  381. }
  382. raise ClientError(error_response, "TestOperation")
  383. with self.assertRaises(ClientError):
  384. self.scanner._call_with_retry(
  385. always_fails,
  386. max_retries=2,
  387. base_delay=0.01,
  388. )
  389. class TestScanResourcesIntegration(unittest.TestCase):
  390. """Integration tests for scan_resources with region filtering."""
  391. @patch('cloudshell_scanner.boto3.Session')
  392. def setUp(self, mock_session):
  393. """Set up test fixtures."""
  394. self.mock_session = MagicMock()
  395. mock_session.return_value = self.mock_session
  396. self.mock_sts = MagicMock()
  397. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  398. self.mock_ec2 = MagicMock()
  399. self.mock_ec2.describe_regions.return_value = {
  400. "Regions": [
  401. {"RegionName": "us-east-1"},
  402. {"RegionName": "us-west-2"},
  403. {"RegionName": "eu-west-1"},
  404. ]
  405. }
  406. def get_client(service, **kwargs):
  407. if service == "sts":
  408. return self.mock_sts
  409. elif service == "ec2":
  410. return self.mock_ec2
  411. return MagicMock()
  412. self.mock_session.client.side_effect = get_client
  413. self.scanner = CloudShellScanner()
  414. def test_scan_resources_with_specified_regions(self):
  415. """Test scan_resources only scans specified regions."""
  416. # Validates: Requirements 1.3
  417. scanned_regions = set()
  418. def mock_scan_service(account_id, region, service):
  419. if region != "global":
  420. scanned_regions.add(region)
  421. return []
  422. self.scanner._scan_service = mock_scan_service
  423. result = self.scanner.scan_resources(
  424. regions=["us-east-1", "us-west-2"],
  425. services=["vpc"],
  426. )
  427. # Only specified regions should be scanned
  428. self.assertEqual(scanned_regions, {"us-east-1", "us-west-2"})
  429. self.assertEqual(set(result["metadata"]["regions_scanned"]), {"us-east-1", "us-west-2"})
  430. def test_scan_resources_with_no_regions_scans_all(self):
  431. """Test scan_resources scans all regions when none specified."""
  432. # Validates: Requirements 1.4
  433. scanned_regions = set()
  434. def mock_scan_service(account_id, region, service):
  435. if region != "global":
  436. scanned_regions.add(region)
  437. return []
  438. self.scanner._scan_service = mock_scan_service
  439. result = self.scanner.scan_resources(
  440. regions=None,
  441. services=["vpc"],
  442. )
  443. # All available regions should be scanned
  444. self.assertEqual(scanned_regions, {"us-east-1", "us-west-2", "eu-west-1"})
  445. def test_scan_resources_filters_invalid_regions(self):
  446. """Test scan_resources filters out invalid regions."""
  447. scanned_regions = set()
  448. def mock_scan_service(account_id, region, service):
  449. if region != "global":
  450. scanned_regions.add(region)
  451. return []
  452. self.scanner._scan_service = mock_scan_service
  453. result = self.scanner.scan_resources(
  454. regions=["us-east-1", "invalid-region", "us-west-2"],
  455. services=["vpc"],
  456. )
  457. # Invalid region should be filtered out
  458. self.assertNotIn("invalid-region", scanned_regions)
  459. self.assertEqual(scanned_regions, {"us-east-1", "us-west-2"})
# NOTE(review): this entry-point guard sits in the middle of the module. When the
# file is run directly, unittest.main() executes here and, by default, exits the
# process. That happens before the JSON export test classes further down are
# defined, so those tests never run this way; under pytest the guard is inert.
# Consider moving it to the very end of the file.
if __name__ == "__main__":
    unittest.main()
  462. # =========================================================================
  463. # JSON Export Tests (Task 1.7)
  464. # =========================================================================
  465. import json
  466. import os
  467. import tempfile
  468. from datetime import datetime, timezone, date
  469. class TestJsonExport(unittest.TestCase):
  470. """Tests for JSON export functionality (Task 1.7)."""
  471. @patch('cloudshell_scanner.boto3.Session')
  472. def setUp(self, mock_session):
  473. """Set up test fixtures."""
  474. self.mock_session = MagicMock()
  475. mock_session.return_value = self.mock_session
  476. self.mock_sts = MagicMock()
  477. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  478. self.mock_session.client.return_value = self.mock_sts
  479. self.scanner = CloudShellScanner()
  480. # Create a temporary directory for test files
  481. self.temp_dir = tempfile.mkdtemp()
  482. def tearDown(self):
  483. """Clean up temporary files."""
  484. import shutil
  485. shutil.rmtree(self.temp_dir, ignore_errors=True)
  486. def _create_valid_scan_data(self) -> dict:
  487. """Create a valid scan data structure for testing."""
  488. return {
  489. "metadata": {
  490. "account_id": "123456789012",
  491. "scan_timestamp": "2024-01-15T10:30:00Z",
  492. "regions_scanned": ["us-east-1", "us-west-2"],
  493. "services_scanned": ["vpc", "ec2"],
  494. "scanner_version": "1.0.0",
  495. "total_resources": 5,
  496. "total_errors": 1,
  497. },
  498. "resources": {
  499. "vpc": [
  500. {
  501. "account_id": "123456789012",
  502. "region": "us-east-1",
  503. "service": "vpc",
  504. "resource_type": "VPC",
  505. "resource_id": "vpc-12345",
  506. "name": "main-vpc",
  507. "attributes": {"CIDR": "10.0.0.0/16"},
  508. }
  509. ],
  510. "ec2": [
  511. {
  512. "account_id": "123456789012",
  513. "region": "us-east-1",
  514. "service": "ec2",
  515. "resource_type": "Instance",
  516. "resource_id": "i-12345",
  517. "name": "web-server",
  518. "attributes": {"InstanceType": "t3.micro"},
  519. }
  520. ],
  521. },
  522. "errors": [
  523. {
  524. "service": "rds",
  525. "region": "us-west-2",
  526. "error": "Access denied",
  527. "error_type": "ClientError",
  528. "details": {"error_code": "AccessDenied"},
  529. }
  530. ],
  531. }
  532. def test_export_json_creates_file(self):
  533. """Test that export_json creates a JSON file."""
  534. # Validates: Requirements 1.6
  535. scan_data = self._create_valid_scan_data()
  536. output_path = os.path.join(self.temp_dir, "test_output.json")
  537. self.scanner.export_json(scan_data, output_path)
  538. self.assertTrue(os.path.exists(output_path))
  539. def test_export_json_valid_json_format(self):
  540. """Test that exported file contains valid JSON."""
  541. # Validates: Requirements 2.4
  542. scan_data = self._create_valid_scan_data()
  543. output_path = os.path.join(self.temp_dir, "test_output.json")
  544. self.scanner.export_json(scan_data, output_path)
  545. with open(output_path, "r", encoding="utf-8") as f:
  546. loaded_data = json.load(f)
  547. self.assertIsInstance(loaded_data, dict)
  548. def test_export_json_contains_metadata(self):
  549. """Test that exported JSON contains all required metadata fields."""
  550. # Validates: Requirements 2.1
  551. scan_data = self._create_valid_scan_data()
  552. output_path = os.path.join(self.temp_dir, "test_output.json")
  553. self.scanner.export_json(scan_data, output_path)
  554. with open(output_path, "r", encoding="utf-8") as f:
  555. loaded_data = json.load(f)
  556. metadata = loaded_data["metadata"]
  557. self.assertIn("account_id", metadata)
  558. self.assertIn("scan_timestamp", metadata)
  559. self.assertIn("regions_scanned", metadata)
  560. self.assertIn("services_scanned", metadata)
  561. self.assertIn("scanner_version", metadata)
  562. self.assertIn("total_resources", metadata)
  563. self.assertIn("total_errors", metadata)
  564. def test_export_json_contains_resources(self):
  565. """Test that exported JSON contains resources organized by service."""
  566. # Validates: Requirements 2.2
  567. scan_data = self._create_valid_scan_data()
  568. output_path = os.path.join(self.temp_dir, "test_output.json")
  569. self.scanner.export_json(scan_data, output_path)
  570. with open(output_path, "r", encoding="utf-8") as f:
  571. loaded_data = json.load(f)
  572. self.assertIn("resources", loaded_data)
  573. self.assertIsInstance(loaded_data["resources"], dict)
  574. self.assertIn("vpc", loaded_data["resources"])
  575. self.assertIn("ec2", loaded_data["resources"])
  576. def test_export_json_contains_errors(self):
  577. """Test that exported JSON contains errors field."""
  578. # Validates: Requirements 2.3
  579. scan_data = self._create_valid_scan_data()
  580. output_path = os.path.join(self.temp_dir, "test_output.json")
  581. self.scanner.export_json(scan_data, output_path)
  582. with open(output_path, "r", encoding="utf-8") as f:
  583. loaded_data = json.load(f)
  584. self.assertIn("errors", loaded_data)
  585. self.assertIsInstance(loaded_data["errors"], list)
  586. self.assertEqual(len(loaded_data["errors"]), 1)
  587. def test_export_json_preserves_data_integrity(self):
  588. """Test that exported data matches original data."""
  589. # Validates: Requirements 2.4, 2.5 (round-trip consistency)
  590. scan_data = self._create_valid_scan_data()
  591. output_path = os.path.join(self.temp_dir, "test_output.json")
  592. self.scanner.export_json(scan_data, output_path)
  593. with open(output_path, "r", encoding="utf-8") as f:
  594. loaded_data = json.load(f)
  595. # Check metadata values
  596. self.assertEqual(loaded_data["metadata"]["account_id"], "123456789012")
  597. self.assertEqual(loaded_data["metadata"]["regions_scanned"], ["us-east-1", "us-west-2"])
  598. self.assertEqual(loaded_data["metadata"]["services_scanned"], ["vpc", "ec2"])
  599. # Check resources
  600. self.assertEqual(len(loaded_data["resources"]["vpc"]), 1)
  601. self.assertEqual(loaded_data["resources"]["vpc"][0]["resource_id"], "vpc-12345")
  602. def test_export_json_handles_unicode(self):
  603. """Test that export handles Unicode characters correctly."""
  604. scan_data = self._create_valid_scan_data()
  605. scan_data["resources"]["vpc"][0]["name"] = "测试VPC-日本語"
  606. output_path = os.path.join(self.temp_dir, "test_unicode.json")
  607. self.scanner.export_json(scan_data, output_path)
  608. with open(output_path, "r", encoding="utf-8") as f:
  609. loaded_data = json.load(f)
  610. self.assertEqual(loaded_data["resources"]["vpc"][0]["name"], "测试VPC-日本語")
  611. def test_export_json_raises_on_invalid_path(self):
  612. """Test that export raises error for invalid file path."""
  613. scan_data = self._create_valid_scan_data()
  614. invalid_path = "/nonexistent/directory/output.json"
  615. with self.assertRaises((IOError, OSError)):
  616. self.scanner.export_json(scan_data, invalid_path)
  617. class TestJsonSerializer(unittest.TestCase):
  618. """Tests for the custom JSON serializer."""
  619. @patch('cloudshell_scanner.boto3.Session')
  620. def setUp(self, mock_session):
  621. """Set up test fixtures."""
  622. self.mock_session = MagicMock()
  623. mock_session.return_value = self.mock_session
  624. self.mock_sts = MagicMock()
  625. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  626. self.mock_session.client.return_value = self.mock_sts
  627. self.scanner = CloudShellScanner()
  628. def test_serializer_handles_datetime(self):
  629. """Test that serializer converts datetime to ISO 8601 format."""
  630. dt = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)
  631. result = self.scanner._json_serializer(dt)
  632. self.assertEqual(result, "2024-01-15T10:30:00Z")
  633. def test_serializer_handles_naive_datetime(self):
  634. """Test that serializer handles naive datetime (no timezone)."""
  635. dt = datetime(2024, 1, 15, 10, 30, 0)
  636. result = self.scanner._json_serializer(dt)
  637. # Should add UTC timezone
  638. self.assertIn("2024-01-15T10:30:00", result)
  639. def test_serializer_handles_date(self):
  640. """Test that serializer converts date to ISO format."""
  641. d = date(2024, 1, 15)
  642. result = self.scanner._json_serializer(d)
  643. self.assertEqual(result, "2024-01-15")
  644. def test_serializer_handles_bytes(self):
  645. """Test that serializer converts bytes to string."""
  646. b = b"test bytes"
  647. result = self.scanner._json_serializer(b)
  648. self.assertEqual(result, "test bytes")
  649. def test_serializer_handles_set(self):
  650. """Test that serializer converts set to list."""
  651. s = {"a", "b", "c"}
  652. result = self.scanner._json_serializer(s)
  653. self.assertIsInstance(result, list)
  654. self.assertEqual(set(result), {"a", "b", "c"})
  655. def test_serializer_handles_frozenset(self):
  656. """Test that serializer converts frozenset to list."""
  657. fs = frozenset(["x", "y", "z"])
  658. result = self.scanner._json_serializer(fs)
  659. self.assertIsInstance(result, list)
  660. self.assertEqual(set(result), {"x", "y", "z"})
  661. def test_serializer_fallback_to_string(self):
  662. """Test that serializer falls back to string for unknown types."""
  663. class CustomObject:
  664. __slots__ = [] # No __dict__ attribute
  665. def __str__(self):
  666. return "custom_object_str"
  667. obj = CustomObject()
  668. result = self.scanner._json_serializer(obj)
  669. self.assertEqual(result, "custom_object_str")
  670. def test_serializer_handles_object_with_dict(self):
  671. """Test that serializer converts objects with __dict__ to dict."""
  672. class DataObject:
  673. def __init__(self):
  674. self.name = "test"
  675. self.value = 42
  676. obj = DataObject()
  677. result = self.scanner._json_serializer(obj)
  678. self.assertIsInstance(result, dict)
  679. self.assertEqual(result["name"], "test")
  680. self.assertEqual(result["value"], 42)
  681. class TestValidateScanDataStructure(unittest.TestCase):
  682. """Tests for scan data structure validation."""
  683. @patch('cloudshell_scanner.boto3.Session')
  684. def setUp(self, mock_session):
  685. """Set up test fixtures."""
  686. self.mock_session = MagicMock()
  687. mock_session.return_value = self.mock_session
  688. self.mock_sts = MagicMock()
  689. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  690. self.mock_session.client.return_value = self.mock_sts
  691. self.scanner = CloudShellScanner()
  692. def _create_valid_scan_data(self) -> dict:
  693. """Create a valid scan data structure."""
  694. return {
  695. "metadata": {
  696. "account_id": "123456789012",
  697. "scan_timestamp": "2024-01-15T10:30:00Z",
  698. "regions_scanned": ["us-east-1"],
  699. "services_scanned": ["vpc"],
  700. "scanner_version": "1.0.0",
  701. "total_resources": 0,
  702. "total_errors": 0,
  703. },
  704. "resources": {},
  705. "errors": [],
  706. }
  707. def test_valid_structure_passes(self):
  708. """Test that valid structure passes validation."""
  709. data = self._create_valid_scan_data()
  710. # Should not raise
  711. self.scanner._validate_scan_data_structure(data)
  712. def test_missing_metadata_raises(self):
  713. """Test that missing metadata field raises ValueError."""
  714. data = self._create_valid_scan_data()
  715. del data["metadata"]
  716. with self.assertRaises(ValueError) as context:
  717. self.scanner._validate_scan_data_structure(data)
  718. self.assertIn("metadata", str(context.exception))
  719. def test_missing_resources_raises(self):
  720. """Test that missing resources field raises ValueError."""
  721. data = self._create_valid_scan_data()
  722. del data["resources"]
  723. with self.assertRaises(ValueError) as context:
  724. self.scanner._validate_scan_data_structure(data)
  725. self.assertIn("resources", str(context.exception))
  726. def test_missing_errors_raises(self):
  727. """Test that missing errors field raises ValueError."""
  728. data = self._create_valid_scan_data()
  729. del data["errors"]
  730. with self.assertRaises(ValueError) as context:
  731. self.scanner._validate_scan_data_structure(data)
  732. self.assertIn("errors", str(context.exception))
  733. def test_missing_metadata_field_raises(self):
  734. """Test that missing metadata sub-field raises ValueError."""
  735. # Validates: Requirements 2.1
  736. data = self._create_valid_scan_data()
  737. del data["metadata"]["account_id"]
  738. with self.assertRaises(ValueError) as context:
  739. self.scanner._validate_scan_data_structure(data)
  740. self.assertIn("account_id", str(context.exception))
  741. def test_invalid_account_id_type_raises(self):
  742. """Test that non-string account_id raises ValueError."""
  743. data = self._create_valid_scan_data()
  744. data["metadata"]["account_id"] = 123456789012 # Should be string
  745. with self.assertRaises(ValueError) as context:
  746. self.scanner._validate_scan_data_structure(data)
  747. self.assertIn("account_id", str(context.exception))
  748. def test_invalid_regions_type_raises(self):
  749. """Test that non-list regions_scanned raises ValueError."""
  750. data = self._create_valid_scan_data()
  751. data["metadata"]["regions_scanned"] = "us-east-1" # Should be list
  752. with self.assertRaises(ValueError) as context:
  753. self.scanner._validate_scan_data_structure(data)
  754. self.assertIn("regions_scanned", str(context.exception))
  755. def test_invalid_resources_type_raises(self):
  756. """Test that non-dict resources raises ValueError."""
  757. data = self._create_valid_scan_data()
  758. data["resources"] = [] # Should be dict
  759. with self.assertRaises(ValueError) as context:
  760. self.scanner._validate_scan_data_structure(data)
  761. self.assertIn("resources", str(context.exception))
  762. def test_invalid_errors_type_raises(self):
  763. """Test that non-list errors raises ValueError."""
  764. data = self._create_valid_scan_data()
  765. data["errors"] = {} # Should be list
  766. with self.assertRaises(ValueError) as context:
  767. self.scanner._validate_scan_data_structure(data)
  768. self.assertIn("errors", str(context.exception))
  769. class TestCreateScanData(unittest.TestCase):
  770. """Tests for the create_scan_data factory method."""
  771. def test_create_scan_data_structure(self):
  772. """Test that create_scan_data creates correct structure."""
  773. # Validates: Requirements 2.1, 2.2, 2.3
  774. resources = {
  775. "vpc": [{"resource_id": "vpc-123"}],
  776. "ec2": [{"resource_id": "i-123"}, {"resource_id": "i-456"}],
  777. }
  778. errors = [{"service": "rds", "error": "Access denied"}]
  779. result = CloudShellScanner.create_scan_data(
  780. account_id="123456789012",
  781. regions_scanned=["us-east-1", "us-west-2"],
  782. services_scanned=["vpc", "ec2", "rds"],
  783. resources=resources,
  784. errors=errors,
  785. )
  786. # Check structure
  787. self.assertIn("metadata", result)
  788. self.assertIn("resources", result)
  789. self.assertIn("errors", result)
  790. # Check metadata
  791. self.assertEqual(result["metadata"]["account_id"], "123456789012")
  792. self.assertEqual(result["metadata"]["regions_scanned"], ["us-east-1", "us-west-2"])
  793. self.assertEqual(result["metadata"]["services_scanned"], ["vpc", "ec2", "rds"])
  794. self.assertEqual(result["metadata"]["total_resources"], 3)
  795. self.assertEqual(result["metadata"]["total_errors"], 1)
  796. def test_create_scan_data_with_custom_timestamp(self):
  797. """Test that create_scan_data accepts custom timestamp."""
  798. custom_timestamp = "2024-01-15T10:30:00Z"
  799. result = CloudShellScanner.create_scan_data(
  800. account_id="123456789012",
  801. regions_scanned=["us-east-1"],
  802. services_scanned=["vpc"],
  803. resources={},
  804. errors=[],
  805. scan_timestamp=custom_timestamp,
  806. )
  807. self.assertEqual(result["metadata"]["scan_timestamp"], custom_timestamp)
  808. def test_create_scan_data_auto_timestamp(self):
  809. """Test that create_scan_data generates timestamp if not provided."""
  810. result = CloudShellScanner.create_scan_data(
  811. account_id="123456789012",
  812. regions_scanned=["us-east-1"],
  813. services_scanned=["vpc"],
  814. resources={},
  815. errors=[],
  816. )
  817. # Should have a timestamp in ISO 8601 format
  818. timestamp = result["metadata"]["scan_timestamp"]
  819. self.assertIsInstance(timestamp, str)
  820. self.assertIn("T", timestamp)
  821. self.assertTrue(timestamp.endswith("Z"))
  822. def test_create_scan_data_includes_version(self):
  823. """Test that create_scan_data includes scanner version."""
  824. result = CloudShellScanner.create_scan_data(
  825. account_id="123456789012",
  826. regions_scanned=["us-east-1"],
  827. services_scanned=["vpc"],
  828. resources={},
  829. errors=[],
  830. )
  831. self.assertIn("scanner_version", result["metadata"])
  832. self.assertIsInstance(result["metadata"]["scanner_version"], str)
  833. class TestLoadScanData(unittest.TestCase):
  834. """Tests for the load_scan_data method."""
  835. def setUp(self):
  836. """Set up test fixtures."""
  837. self.temp_dir = tempfile.mkdtemp()
  838. def tearDown(self):
  839. """Clean up temporary files."""
  840. import shutil
  841. shutil.rmtree(self.temp_dir, ignore_errors=True)
  842. def _create_valid_scan_data(self) -> dict:
  843. """Create a valid scan data structure."""
  844. return {
  845. "metadata": {
  846. "account_id": "123456789012",
  847. "scan_timestamp": "2024-01-15T10:30:00Z",
  848. "regions_scanned": ["us-east-1"],
  849. "services_scanned": ["vpc"],
  850. "scanner_version": "1.0.0",
  851. "total_resources": 0,
  852. "total_errors": 0,
  853. },
  854. "resources": {},
  855. "errors": [],
  856. }
  857. def test_load_scan_data_success(self):
  858. """Test loading valid scan data from file."""
  859. # Validates: Requirements 2.5 (round-trip consistency)
  860. data = self._create_valid_scan_data()
  861. file_path = os.path.join(self.temp_dir, "test_load.json")
  862. with open(file_path, "w", encoding="utf-8") as f:
  863. json.dump(data, f)
  864. loaded = CloudShellScanner.load_scan_data(file_path)
  865. self.assertEqual(loaded["metadata"]["account_id"], "123456789012")
  866. def test_load_scan_data_file_not_found(self):
  867. """Test that loading non-existent file raises FileNotFoundError."""
  868. with self.assertRaises(FileNotFoundError):
  869. CloudShellScanner.load_scan_data("/nonexistent/file.json")
  870. def test_load_scan_data_invalid_json(self):
  871. """Test that loading invalid JSON raises JSONDecodeError."""
  872. file_path = os.path.join(self.temp_dir, "invalid.json")
  873. with open(file_path, "w") as f:
  874. f.write("not valid json {{{")
  875. with self.assertRaises(json.JSONDecodeError):
  876. CloudShellScanner.load_scan_data(file_path)
  877. def test_load_scan_data_invalid_structure(self):
  878. """Test that loading JSON with invalid structure raises ValueError."""
  879. file_path = os.path.join(self.temp_dir, "invalid_structure.json")
  880. with open(file_path, "w") as f:
  881. json.dump({"invalid": "structure"}, f)
  882. with self.assertRaises(ValueError):
  883. CloudShellScanner.load_scan_data(file_path)
  884. class TestJsonRoundTrip(unittest.TestCase):
  885. """Tests for JSON round-trip consistency (Property 1)."""
  886. @patch('cloudshell_scanner.boto3.Session')
  887. def setUp(self, mock_session):
  888. """Set up test fixtures."""
  889. self.mock_session = MagicMock()
  890. mock_session.return_value = self.mock_session
  891. self.mock_sts = MagicMock()
  892. self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
  893. self.mock_session.client.return_value = self.mock_sts
  894. self.scanner = CloudShellScanner()
  895. self.temp_dir = tempfile.mkdtemp()
  896. def tearDown(self):
  897. """Clean up temporary files."""
  898. import shutil
  899. shutil.rmtree(self.temp_dir, ignore_errors=True)
  900. def test_round_trip_preserves_data(self):
  901. """Test that export and load preserves all data."""
  902. # Validates: Requirements 2.4, 2.5 (Property 1: JSON round-trip consistency)
  903. original_data = {
  904. "metadata": {
  905. "account_id": "123456789012",
  906. "scan_timestamp": "2024-01-15T10:30:00Z",
  907. "regions_scanned": ["us-east-1", "us-west-2", "eu-west-1"],
  908. "services_scanned": ["vpc", "ec2", "rds", "s3"],
  909. "scanner_version": "1.0.0",
  910. "total_resources": 10,
  911. "total_errors": 2,
  912. },
  913. "resources": {
  914. "vpc": [
  915. {
  916. "account_id": "123456789012",
  917. "region": "us-east-1",
  918. "service": "vpc",
  919. "resource_type": "VPC",
  920. "resource_id": "vpc-12345",
  921. "name": "main-vpc",
  922. "attributes": {
  923. "CIDR": "10.0.0.0/16",
  924. "IsDefault": False,
  925. "Tags": [{"Key": "Name", "Value": "main-vpc"}],
  926. },
  927. }
  928. ],
  929. "ec2": [
  930. {
  931. "account_id": "123456789012",
  932. "region": "us-east-1",
  933. "service": "ec2",
  934. "resource_type": "Instance",
  935. "resource_id": "i-12345",
  936. "name": "web-server",
  937. "attributes": {
  938. "InstanceType": "t3.micro",
  939. "State": "running",
  940. },
  941. }
  942. ],
  943. },
  944. "errors": [
  945. {
  946. "service": "rds",
  947. "region": "us-west-2",
  948. "error": "Access denied",
  949. "error_type": "ClientError",
  950. "details": {
  951. "error_code": "AccessDenied",
  952. "error_message": "User is not authorized",
  953. },
  954. }
  955. ],
  956. }
  957. file_path = os.path.join(self.temp_dir, "round_trip.json")
  958. # Export
  959. self.scanner.export_json(original_data, file_path)
  960. # Load
  961. loaded_data = CloudShellScanner.load_scan_data(file_path)
  962. # Verify all data is preserved
  963. self.assertEqual(loaded_data["metadata"], original_data["metadata"])
  964. self.assertEqual(loaded_data["resources"], original_data["resources"])
  965. self.assertEqual(loaded_data["errors"], original_data["errors"])
  966. def test_round_trip_with_empty_resources(self):
  967. """Test round-trip with empty resources."""
  968. original_data = {
  969. "metadata": {
  970. "account_id": "123456789012",
  971. "scan_timestamp": "2024-01-15T10:30:00Z",
  972. "regions_scanned": [],
  973. "services_scanned": [],
  974. "scanner_version": "1.0.0",
  975. "total_resources": 0,
  976. "total_errors": 0,
  977. },
  978. "resources": {},
  979. "errors": [],
  980. }
  981. file_path = os.path.join(self.temp_dir, "empty_round_trip.json")
  982. self.scanner.export_json(original_data, file_path)
  983. loaded_data = CloudShellScanner.load_scan_data(file_path)
  984. self.assertEqual(loaded_data, original_data)
  985. def test_round_trip_with_special_characters(self):
  986. """Test round-trip with special characters in data."""
  987. original_data = {
  988. "metadata": {
  989. "account_id": "123456789012",
  990. "scan_timestamp": "2024-01-15T10:30:00Z",
  991. "regions_scanned": ["us-east-1"],
  992. "services_scanned": ["vpc"],
  993. "scanner_version": "1.0.0",
  994. "total_resources": 1,
  995. "total_errors": 0,
  996. },
  997. "resources": {
  998. "vpc": [
  999. {
  1000. "account_id": "123456789012",
  1001. "region": "us-east-1",
  1002. "service": "vpc",
  1003. "resource_type": "VPC",
  1004. "resource_id": "vpc-12345",
  1005. "name": "测试VPC-日本語-émoji-🚀",
  1006. "attributes": {
  1007. "Description": "Special chars: <>&\"'",
  1008. },
  1009. }
  1010. ],
  1011. },
  1012. "errors": [],
  1013. }
  1014. file_path = os.path.join(self.temp_dir, "special_chars.json")
  1015. self.scanner.export_json(original_data, file_path)
  1016. loaded_data = CloudShellScanner.load_scan_data(file_path)
  1017. self.assertEqual(
  1018. loaded_data["resources"]["vpc"][0]["name"],
  1019. "测试VPC-日本語-émoji-🚀"
  1020. )