#!/usr/bin/env python3
"""
Unit tests for CloudShell Scanner.

Tests for Task 1.6 (region filtering and error handling):
- Region list retrieval and filtering logic
- Error capture and continue scanning logic
- Retry mechanism with exponential backoff

Tests for Task 1.7 (JSON export):
- JSON export and the custom JSON serializer
- Scan data structure validation and the create_scan_data factory
- Loading scan data and export/load round-trip consistency

Requirements tested:
- 1.3: Scan only specified regions when provided
- 1.4: Scan all available regions when not specified
- 1.6: Export scan results to a JSON file
- 1.8: Record errors and continue scanning other resources
- 2.1-2.5: JSON document structure (metadata, resources, errors),
  valid JSON output, and round-trip consistency
"""

import json
import os
import shutil
import tempfile
import time
import unittest
from datetime import date, datetime, timezone
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from botocore.exceptions import BotoCoreError, ClientError

# Import the module under test
from cloudshell_scanner import (
    CloudShellScanner,
    retry_with_exponential_backoff,
    is_retryable_error,
    RETRYABLE_ERROR_CODES,
    RETRYABLE_EXCEPTIONS,
)


def _client_error(code, message, operation="TestOperation"):
    """Build a botocore ``ClientError`` carrying the given error code and message."""
    return ClientError({"Error": {"Code": code, "Message": message}}, operation)


def _minimal_scan_data():
    """Return a minimal, structurally valid scan-data document with no resources or errors."""
    return {
        "metadata": {
            "account_id": "123456789012",
            "scan_timestamp": "2024-01-15T10:30:00Z",
            "regions_scanned": ["us-east-1"],
            "services_scanned": ["vpc"],
            "scanner_version": "1.0.0",
            "total_resources": 0,
            "total_errors": 0,
        },
        "resources": {},
        "errors": [],
    }


class TestRetryWithExponentialBackoff(unittest.TestCase):
    """Tests for the retry_with_exponential_backoff decorator."""

    def test_successful_call_no_retry(self):
        """Test that successful calls don't trigger retries."""
        call_count = 0

        @retry_with_exponential_backoff(max_retries=3, base_delay=0.01)
        def successful_func():
            nonlocal call_count
            call_count += 1
            return "success"

        result = successful_func()

        self.assertEqual(result, "success")
        self.assertEqual(call_count, 1)

    def test_retry_on_throttling_error(self):
        """Test that throttling errors trigger retries."""
        call_count = 0

        @retry_with_exponential_backoff(max_retries=2, base_delay=0.01)
        def throttled_func():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise _client_error("Throttling", "Rate exceeded")
            return "success"

        result = throttled_func()

        self.assertEqual(result, "success")
        self.assertEqual(call_count, 3)

    def test_no_retry_on_non_retryable_error(self):
        """Test that non-retryable errors are raised immediately."""
        call_count = 0

        @retry_with_exponential_backoff(max_retries=3, base_delay=0.01)
        def access_denied_func():
            nonlocal call_count
            call_count += 1
            raise _client_error("AccessDenied", "Access Denied")

        with self.assertRaises(ClientError):
            access_denied_func()

        # Should only be called once since AccessDenied is not retryable
        self.assertEqual(call_count, 1)

    def test_max_retries_exhausted(self):
        """Test that exception is raised after max retries."""
        call_count = 0

        @retry_with_exponential_backoff(max_retries=2, base_delay=0.01)
        def always_fails():
            nonlocal call_count
            call_count += 1
            raise _client_error("ServiceUnavailable", "Service unavailable")

        with self.assertRaises(ClientError):
            always_fails()

        # Should be called max_retries + 1 times
        self.assertEqual(call_count, 3)

    def test_exponential_backoff_timing(self):
        """Test that delays increase exponentially."""
        call_times = []

        @retry_with_exponential_backoff(max_retries=2, base_delay=0.1, exponential_base=2.0)
        def timed_func():
            # monotonic() cannot jump backwards/forwards with wall-clock changes.
            call_times.append(time.monotonic())
            if len(call_times) < 3:
                raise _client_error("Throttling", "Rate exceeded")
            return "success"

        self.assertEqual(timed_func(), "success")

        # Two failures followed by one success. Asserting the count up front
        # ensures the delay checks below always run instead of being skipped.
        self.assertEqual(len(call_times), 3)

        # Delays are expected to be roughly 0.1s then 0.2s; the bounds are
        # deliberately loose to tolerate timing noise on busy machines.
        first_delay = call_times[1] - call_times[0]
        second_delay = call_times[2] - call_times[1]
        self.assertGreater(first_delay, 0.05)  # At least half the base delay
        self.assertGreater(second_delay, first_delay * 0.8)  # Second delay should be larger


class TestIsRetryableError(unittest.TestCase):
    """Tests for the is_retryable_error function."""

    def test_throttling_is_retryable(self):
        """Test that throttling errors are retryable."""
        error = _client_error("Throttling", "Rate exceeded")
        self.assertTrue(is_retryable_error(error))

    def test_service_unavailable_is_retryable(self):
        """Test that service unavailable errors are retryable."""
        error = _client_error("ServiceUnavailable", "Service unavailable")
        self.assertTrue(is_retryable_error(error))

    def test_access_denied_not_retryable(self):
        """Test that access denied errors are not retryable."""
        error = _client_error("AccessDenied", "Access Denied")
        self.assertFalse(is_retryable_error(error))

    def test_connection_error_is_retryable(self):
        """Test that connection errors are retryable."""
        error = ConnectionError("Connection refused")
        self.assertTrue(is_retryable_error(error))

    def test_timeout_error_is_retryable(self):
        """Test that timeout errors are retryable."""
        error = TimeoutError("Request timed out")
        self.assertTrue(is_retryable_error(error))


class TestRegionFiltering(unittest.TestCase):
    """Tests for region filtering functionality."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        # Mock the boto3 session
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        # Mock STS client for get_account_id
        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}

        # Mock EC2 client for list_regions
        self.mock_ec2 = MagicMock()
        self.mock_ec2.describe_regions.return_value = {
            "Regions": [
                {"RegionName": "us-east-1"},
                {"RegionName": "us-west-2"},
                {"RegionName": "eu-west-1"},
                {"RegionName": "ap-northeast-1"},
            ]
        }

        def get_client(service, **kwargs):
            if service == "sts":
                return self.mock_sts
            if service == "ec2":
                return self.mock_ec2
            return MagicMock()

        self.mock_session.client.side_effect = get_client

        self.scanner = CloudShellScanner()

    def test_list_regions_returns_available_regions(self):
        """Test that list_regions returns available regions from AWS."""
        regions = self.scanner.list_regions()

        self.assertEqual(len(regions), 4)
        self.assertIn("us-east-1", regions)
        self.assertIn("us-west-2", regions)
        self.assertIn("eu-west-1", regions)
        self.assertIn("ap-northeast-1", regions)

    def test_list_regions_fallback_on_error(self):
        """Test that list_regions falls back to defaults on error."""
        self.mock_ec2.describe_regions.side_effect = Exception("API Error")

        regions = self.scanner.list_regions()

        # Should return default regions
        self.assertIn("us-east-1", regions)
        self.assertIn("us-west-2", regions)
        self.assertGreater(len(regions), 0)

    def test_filter_regions_with_valid_regions(self):
        """Test filtering with valid regions."""
        # Validates: Requirements 1.3
        requested = ["us-east-1", "us-west-2"]

        filtered = self.scanner.filter_regions(requested)

        self.assertEqual(len(filtered), 2)
        self.assertIn("us-east-1", filtered)
        self.assertIn("us-west-2", filtered)

    def test_filter_regions_with_invalid_regions(self):
        """Test filtering removes invalid regions."""
        # Validates: Requirements 1.3
        requested = ["us-east-1", "invalid-region", "us-west-2"]

        filtered = self.scanner.filter_regions(requested)

        self.assertEqual(len(filtered), 2)
        self.assertIn("us-east-1", filtered)
        self.assertIn("us-west-2", filtered)
        self.assertNotIn("invalid-region", filtered)

    def test_filter_regions_none_returns_all(self):
        """Test that None returns all available regions."""
        # Validates: Requirements 1.4
        filtered = self.scanner.filter_regions(None)

        self.assertEqual(len(filtered), 4)

    def test_filter_regions_all_invalid_falls_back(self):
        """Test that all invalid regions falls back to all available."""
        requested = ["invalid-1", "invalid-2"]

        filtered = self.scanner.filter_regions(requested)

        # Should fall back to all available regions
        self.assertEqual(len(filtered), 4)

    def test_filter_regions_normalizes_input(self):
        """Test that region names are normalized (whitespace, case)."""
        requested = ["  US-EAST-1  ", "us-west-2"]

        filtered = self.scanner.filter_regions(requested)

        self.assertEqual(len(filtered), 2)
        self.assertIn("us-east-1", filtered)
        self.assertIn("us-west-2", filtered)

    def test_validate_region_valid(self):
        """Test validate_region with valid region."""
        self.assertTrue(self.scanner.validate_region("us-east-1"))

    def test_validate_region_invalid(self):
        """Test validate_region with invalid region."""
        self.assertFalse(self.scanner.validate_region("invalid-region"))


class TestErrorHandling(unittest.TestCase):
    """Tests for error handling functionality."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        # Mock STS client
        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}

        # Mock EC2 client
        self.mock_ec2 = MagicMock()
        self.mock_ec2.describe_regions.return_value = {
            "Regions": [{"RegionName": "us-east-1"}]
        }

        def get_client(service, **kwargs):
            if service == "sts":
                return self.mock_sts
            if service == "ec2":
                return self.mock_ec2
            return MagicMock()

        self.mock_session.client.side_effect = get_client

        self.scanner = CloudShellScanner()

    def test_create_error_info_client_error(self):
        """Test error info creation for ClientError."""
        exception = _client_error(
            "AccessDenied", "User is not authorized", "DescribeInstances"
        )

        error_info = self.scanner._create_error_info(
            service="ec2",
            region="us-east-1",
            exception=exception,
        )

        self.assertEqual(error_info["service"], "ec2")
        self.assertEqual(error_info["region"], "us-east-1")
        self.assertEqual(error_info["error_type"], "ClientError")
        self.assertIsNotNone(error_info["details"])
        self.assertEqual(error_info["details"]["error_code"], "AccessDenied")
        self.assertIn("permission_hint", error_info["details"])

    def test_create_error_info_generic_exception(self):
        """Test error info creation for generic exceptions."""
        exception = ValueError("Invalid value")

        error_info = self.scanner._create_error_info(
            service="vpc",
            region="eu-west-1",
            exception=exception,
        )

        self.assertEqual(error_info["service"], "vpc")
        self.assertEqual(error_info["region"], "eu-west-1")
        self.assertEqual(error_info["error_type"], "ValueError")
        self.assertIn("Invalid value", error_info["error"])

    def test_scan_continues_after_error(self):
        """Test that scanning continues after encountering an error."""
        # Validates: Requirements 1.8
        # Mock _scan_service to fail for one service but succeed for another
        call_count = {"vpc": 0, "ec2": 0}

        def mock_scan_service(account_id, region, service):
            call_count[service] = call_count.get(service, 0) + 1
            if service == "vpc":
                raise Exception("VPC scan failed")
            return [{"resource_id": "i-123", "service": service}]

        # Instance attribute shadows the method for this scanner only.
        self.scanner._scan_service = mock_scan_service

        # Scan with both services
        result = self.scanner.scan_resources(
            regions=["us-east-1"],
            services=["vpc", "ec2"],
        )

        # Both services should have been attempted
        self.assertEqual(call_count["vpc"], 1)
        self.assertEqual(call_count["ec2"], 1)

        # Should have one error and one successful resource
        self.assertEqual(len(result["errors"]), 1)
        self.assertEqual(result["errors"][0]["service"], "vpc")
        self.assertIn("ec2", result["resources"])

    def test_error_info_includes_region(self):
        """Test that error info includes the correct region."""
        # Validates: Requirements 1.8
        def mock_scan_service(account_id, region, service):
            raise Exception(f"Error in {region}")

        self.scanner._scan_service = mock_scan_service

        result = self.scanner.scan_resources(
            regions=["us-east-1"],
            services=["vpc"],
        )

        self.assertEqual(len(result["errors"]), 1)
        self.assertEqual(result["errors"][0]["region"], "us-east-1")


class TestCallWithRetry(unittest.TestCase):
    """Tests for the _call_with_retry method."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
        self.mock_session.client.return_value = self.mock_sts

        self.scanner = CloudShellScanner()

    def test_call_with_retry_success(self):
        """Test successful call without retries."""
        def success_func():
            return "result"

        result = self.scanner._call_with_retry(success_func, max_retries=3, base_delay=0.01)

        self.assertEqual(result, "result")

    def test_call_with_retry_eventual_success(self):
        """Test call that succeeds after retries."""
        call_count = 0

        def eventual_success():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise _client_error("Throttling", "Rate exceeded")
            return "success"

        result = self.scanner._call_with_retry(
            eventual_success,
            max_retries=3,
            base_delay=0.01,
        )

        self.assertEqual(result, "success")
        self.assertEqual(call_count, 3)

    def test_call_with_retry_exhausted(self):
        """Test call that exhausts all retries."""
        def always_fails():
            raise _client_error("ServiceUnavailable", "Service unavailable")

        with self.assertRaises(ClientError):
            self.scanner._call_with_retry(
                always_fails,
                max_retries=2,
                base_delay=0.01,
            )


class TestScanResourcesIntegration(unittest.TestCase):
    """Integration tests for scan_resources with region filtering."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}

        self.mock_ec2 = MagicMock()
        self.mock_ec2.describe_regions.return_value = {
            "Regions": [
                {"RegionName": "us-east-1"},
                {"RegionName": "us-west-2"},
                {"RegionName": "eu-west-1"},
            ]
        }

        def get_client(service, **kwargs):
            if service == "sts":
                return self.mock_sts
            if service == "ec2":
                return self.mock_ec2
            return MagicMock()

        self.mock_session.client.side_effect = get_client

        self.scanner = CloudShellScanner()

    def test_scan_resources_with_specified_regions(self):
        """Test scan_resources only scans specified regions."""
        # Validates: Requirements 1.3
        scanned_regions = set()

        def mock_scan_service(account_id, region, service):
            if region != "global":
                scanned_regions.add(region)
            return []

        self.scanner._scan_service = mock_scan_service

        result = self.scanner.scan_resources(
            regions=["us-east-1", "us-west-2"],
            services=["vpc"],
        )

        # Only specified regions should be scanned
        self.assertEqual(scanned_regions, {"us-east-1", "us-west-2"})
        self.assertEqual(
            set(result["metadata"]["regions_scanned"]),
            {"us-east-1", "us-west-2"},
        )

    def test_scan_resources_with_no_regions_scans_all(self):
        """Test scan_resources scans all regions when none specified."""
        # Validates: Requirements 1.4
        scanned_regions = set()

        def mock_scan_service(account_id, region, service):
            if region != "global":
                scanned_regions.add(region)
            return []

        self.scanner._scan_service = mock_scan_service

        self.scanner.scan_resources(
            regions=None,
            services=["vpc"],
        )

        # All available regions should be scanned
        self.assertEqual(scanned_regions, {"us-east-1", "us-west-2", "eu-west-1"})

    def test_scan_resources_filters_invalid_regions(self):
        """Test scan_resources filters out invalid regions."""
        scanned_regions = set()

        def mock_scan_service(account_id, region, service):
            if region != "global":
                scanned_regions.add(region)
            return []

        self.scanner._scan_service = mock_scan_service

        self.scanner.scan_resources(
            regions=["us-east-1", "invalid-region", "us-west-2"],
            services=["vpc"],
        )

        # Invalid region should be filtered out
        self.assertNotIn("invalid-region", scanned_regions)
        self.assertEqual(scanned_regions, {"us-east-1", "us-west-2"})


# =========================================================================
# JSON Export Tests (Task 1.7)
# =========================================================================


class TestJsonExport(unittest.TestCase):
    """Tests for JSON export functionality (Task 1.7)."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
        self.mock_session.client.return_value = self.mock_sts

        self.scanner = CloudShellScanner()

        # Create a temporary directory for test files
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Clean up temporary files."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _create_valid_scan_data(self) -> dict:
        """Create a valid scan data structure for testing."""
        return {
            "metadata": {
                "account_id": "123456789012",
                "scan_timestamp": "2024-01-15T10:30:00Z",
                "regions_scanned": ["us-east-1", "us-west-2"],
                "services_scanned": ["vpc", "ec2"],
                "scanner_version": "1.0.0",
                "total_resources": 5,
                "total_errors": 1,
            },
            "resources": {
                "vpc": [
                    {
                        "account_id": "123456789012",
                        "region": "us-east-1",
                        "service": "vpc",
                        "resource_type": "VPC",
                        "resource_id": "vpc-12345",
                        "name": "main-vpc",
                        "attributes": {"CIDR": "10.0.0.0/16"},
                    }
                ],
                "ec2": [
                    {
                        "account_id": "123456789012",
                        "region": "us-east-1",
                        "service": "ec2",
                        "resource_type": "Instance",
                        "resource_id": "i-12345",
                        "name": "web-server",
                        "attributes": {"InstanceType": "t3.micro"},
                    }
                ],
            },
            "errors": [
                {
                    "service": "rds",
                    "region": "us-west-2",
                    "error": "Access denied",
                    "error_type": "ClientError",
                    "details": {"error_code": "AccessDenied"},
                }
            ],
        }

    def _export_and_reload(self, scan_data, filename="test_output.json"):
        """Export ``scan_data`` into the temp dir and return the parsed JSON file."""
        output_path = os.path.join(self.temp_dir, filename)
        self.scanner.export_json(scan_data, output_path)
        with open(output_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def test_export_json_creates_file(self):
        """Test that export_json creates a JSON file."""
        # Validates: Requirements 1.6
        scan_data = self._create_valid_scan_data()
        output_path = os.path.join(self.temp_dir, "test_output.json")

        self.scanner.export_json(scan_data, output_path)

        self.assertTrue(os.path.exists(output_path))

    def test_export_json_valid_json_format(self):
        """Test that exported file contains valid JSON."""
        # Validates: Requirements 2.4
        loaded_data = self._export_and_reload(self._create_valid_scan_data())

        self.assertIsInstance(loaded_data, dict)

    def test_export_json_contains_metadata(self):
        """Test that exported JSON contains all required metadata fields."""
        # Validates: Requirements 2.1
        loaded_data = self._export_and_reload(self._create_valid_scan_data())

        metadata = loaded_data["metadata"]
        self.assertIn("account_id", metadata)
        self.assertIn("scan_timestamp", metadata)
        self.assertIn("regions_scanned", metadata)
        self.assertIn("services_scanned", metadata)
        self.assertIn("scanner_version", metadata)
        self.assertIn("total_resources", metadata)
        self.assertIn("total_errors", metadata)

    def test_export_json_contains_resources(self):
        """Test that exported JSON contains resources organized by service."""
        # Validates: Requirements 2.2
        loaded_data = self._export_and_reload(self._create_valid_scan_data())

        self.assertIn("resources", loaded_data)
        self.assertIsInstance(loaded_data["resources"], dict)
        self.assertIn("vpc", loaded_data["resources"])
        self.assertIn("ec2", loaded_data["resources"])

    def test_export_json_contains_errors(self):
        """Test that exported JSON contains errors field."""
        # Validates: Requirements 2.3
        loaded_data = self._export_and_reload(self._create_valid_scan_data())

        self.assertIn("errors", loaded_data)
        self.assertIsInstance(loaded_data["errors"], list)
        self.assertEqual(len(loaded_data["errors"]), 1)

    def test_export_json_preserves_data_integrity(self):
        """Test that exported data matches original data."""
        # Validates: Requirements 2.4, 2.5 (round-trip consistency)
        loaded_data = self._export_and_reload(self._create_valid_scan_data())

        # Check metadata values
        self.assertEqual(loaded_data["metadata"]["account_id"], "123456789012")
        self.assertEqual(loaded_data["metadata"]["regions_scanned"], ["us-east-1", "us-west-2"])
        self.assertEqual(loaded_data["metadata"]["services_scanned"], ["vpc", "ec2"])

        # Check resources
        self.assertEqual(len(loaded_data["resources"]["vpc"]), 1)
        self.assertEqual(loaded_data["resources"]["vpc"][0]["resource_id"], "vpc-12345")

    def test_export_json_handles_unicode(self):
        """Test that export handles Unicode characters correctly."""
        scan_data = self._create_valid_scan_data()
        scan_data["resources"]["vpc"][0]["name"] = "测试VPC-日本語"

        loaded_data = self._export_and_reload(scan_data, "test_unicode.json")

        self.assertEqual(loaded_data["resources"]["vpc"][0]["name"], "测试VPC-日本語")

    def test_export_json_raises_on_invalid_path(self):
        """Test that export raises error for invalid file path."""
        scan_data = self._create_valid_scan_data()

        # Use a regular file as the "parent directory". Writing beneath it
        # fails with an OSError even if export_json creates missing parent
        # directories or the suite runs as root, and nothing is created
        # outside the per-test temp dir.
        blocker = os.path.join(self.temp_dir, "not_a_directory")
        with open(blocker, "w", encoding="utf-8"):
            pass
        invalid_path = os.path.join(blocker, "output.json")

        # IOError is an alias of OSError in Python 3.
        with self.assertRaises(OSError):
            self.scanner.export_json(scan_data, invalid_path)


class TestJsonSerializer(unittest.TestCase):
    """Tests for the custom JSON serializer."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
        self.mock_session.client.return_value = self.mock_sts

        self.scanner = CloudShellScanner()

    def test_serializer_handles_datetime(self):
        """Test that serializer converts datetime to ISO 8601 format."""
        dt = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)

        result = self.scanner._json_serializer(dt)

        self.assertEqual(result, "2024-01-15T10:30:00Z")

    def test_serializer_handles_naive_datetime(self):
        """Test that serializer handles naive datetime (no timezone)."""
        dt = datetime(2024, 1, 15, 10, 30, 0)

        result = self.scanner._json_serializer(dt)

        # Should add UTC timezone
        self.assertIn("2024-01-15T10:30:00", result)

    def test_serializer_handles_date(self):
        """Test that serializer converts date to ISO format."""
        d = date(2024, 1, 15)

        result = self.scanner._json_serializer(d)

        self.assertEqual(result, "2024-01-15")

    def test_serializer_handles_bytes(self):
        """Test that serializer converts bytes to string."""
        b = b"test bytes"

        result = self.scanner._json_serializer(b)

        self.assertEqual(result, "test bytes")

    def test_serializer_handles_set(self):
        """Test that serializer converts set to list."""
        s = {"a", "b", "c"}

        result = self.scanner._json_serializer(s)

        self.assertIsInstance(result, list)
        self.assertEqual(set(result), {"a", "b", "c"})

    def test_serializer_handles_frozenset(self):
        """Test that serializer converts frozenset to list."""
        fs = frozenset(["x", "y", "z"])

        result = self.scanner._json_serializer(fs)

        self.assertIsInstance(result, list)
        self.assertEqual(set(result), {"x", "y", "z"})

    def test_serializer_fallback_to_string(self):
        """Test that serializer falls back to string for unknown types."""
        class CustomObject:
            __slots__ = []  # No __dict__ attribute

            def __str__(self):
                return "custom_object_str"

        obj = CustomObject()

        result = self.scanner._json_serializer(obj)

        self.assertEqual(result, "custom_object_str")

    def test_serializer_handles_object_with_dict(self):
        """Test that serializer converts objects with __dict__ to dict."""
        class DataObject:
            def __init__(self):
                self.name = "test"
                self.value = 42

        obj = DataObject()

        result = self.scanner._json_serializer(obj)

        self.assertIsInstance(result, dict)
        self.assertEqual(result["name"], "test")
        self.assertEqual(result["value"], 42)


class TestValidateScanDataStructure(unittest.TestCase):
    """Tests for scan data structure validation."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
        self.mock_session.client.return_value = self.mock_sts

        self.scanner = CloudShellScanner()

    def _create_valid_scan_data(self) -> dict:
        """Create a valid scan data structure."""
        return _minimal_scan_data()

    def test_valid_structure_passes(self):
        """Test that valid structure passes validation."""
        data = self._create_valid_scan_data()
        # Should not raise
        self.scanner._validate_scan_data_structure(data)

    def test_missing_metadata_raises(self):
        """Test that missing metadata field raises ValueError."""
        data = self._create_valid_scan_data()
        del data["metadata"]

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("metadata", str(context.exception))

    def test_missing_resources_raises(self):
        """Test that missing resources field raises ValueError."""
        data = self._create_valid_scan_data()
        del data["resources"]

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("resources", str(context.exception))

    def test_missing_errors_raises(self):
        """Test that missing errors field raises ValueError."""
        data = self._create_valid_scan_data()
        del data["errors"]

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("errors", str(context.exception))

    def test_missing_metadata_field_raises(self):
        """Test that missing metadata sub-field raises ValueError."""
        # Validates: Requirements 2.1
        data = self._create_valid_scan_data()
        del data["metadata"]["account_id"]

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("account_id", str(context.exception))

    def test_invalid_account_id_type_raises(self):
        """Test that non-string account_id raises ValueError."""
        data = self._create_valid_scan_data()
        data["metadata"]["account_id"] = 123456789012  # Should be string

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("account_id", str(context.exception))

    def test_invalid_regions_type_raises(self):
        """Test that non-list regions_scanned raises ValueError."""
        data = self._create_valid_scan_data()
        data["metadata"]["regions_scanned"] = "us-east-1"  # Should be list

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("regions_scanned", str(context.exception))

    def test_invalid_resources_type_raises(self):
        """Test that non-dict resources raises ValueError."""
        data = self._create_valid_scan_data()
        data["resources"] = []  # Should be dict

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("resources", str(context.exception))

    def test_invalid_errors_type_raises(self):
        """Test that non-list errors raises ValueError."""
        data = self._create_valid_scan_data()
        data["errors"] = {}  # Should be list

        with self.assertRaises(ValueError) as context:
            self.scanner._validate_scan_data_structure(data)

        self.assertIn("errors", str(context.exception))


class TestCreateScanData(unittest.TestCase):
    """Tests for the create_scan_data factory method."""

    def test_create_scan_data_structure(self):
        """Test that create_scan_data creates correct structure."""
        # Validates: Requirements 2.1, 2.2, 2.3
        resources = {
            "vpc": [{"resource_id": "vpc-123"}],
            "ec2": [{"resource_id": "i-123"}, {"resource_id": "i-456"}],
        }
        errors = [{"service": "rds", "error": "Access denied"}]

        result = CloudShellScanner.create_scan_data(
            account_id="123456789012",
            regions_scanned=["us-east-1", "us-west-2"],
            services_scanned=["vpc", "ec2", "rds"],
            resources=resources,
            errors=errors,
        )

        # Check structure
        self.assertIn("metadata", result)
        self.assertIn("resources", result)
        self.assertIn("errors", result)

        # Check metadata
        self.assertEqual(result["metadata"]["account_id"], "123456789012")
        self.assertEqual(result["metadata"]["regions_scanned"], ["us-east-1", "us-west-2"])
        self.assertEqual(result["metadata"]["services_scanned"], ["vpc", "ec2", "rds"])
        self.assertEqual(result["metadata"]["total_resources"], 3)
        self.assertEqual(result["metadata"]["total_errors"], 1)

    def test_create_scan_data_with_custom_timestamp(self):
        """Test that create_scan_data accepts custom timestamp."""
        custom_timestamp = "2024-01-15T10:30:00Z"

        result = CloudShellScanner.create_scan_data(
            account_id="123456789012",
            regions_scanned=["us-east-1"],
            services_scanned=["vpc"],
            resources={},
            errors=[],
            scan_timestamp=custom_timestamp,
        )

        self.assertEqual(result["metadata"]["scan_timestamp"], custom_timestamp)

    def test_create_scan_data_auto_timestamp(self):
        """Test that create_scan_data generates timestamp if not provided."""
        result = CloudShellScanner.create_scan_data(
            account_id="123456789012",
            regions_scanned=["us-east-1"],
            services_scanned=["vpc"],
            resources={},
            errors=[],
        )

        # Should have a timestamp in ISO 8601 format
        timestamp = result["metadata"]["scan_timestamp"]
        self.assertIsInstance(timestamp, str)
        self.assertIn("T", timestamp)
        self.assertTrue(timestamp.endswith("Z"))

    def test_create_scan_data_includes_version(self):
        """Test that create_scan_data includes scanner version."""
        result = CloudShellScanner.create_scan_data(
            account_id="123456789012",
            regions_scanned=["us-east-1"],
            services_scanned=["vpc"],
            resources={},
            errors=[],
        )

        self.assertIn("scanner_version", result["metadata"])
        self.assertIsInstance(result["metadata"]["scanner_version"], str)


class TestLoadScanData(unittest.TestCase):
    """Tests for the load_scan_data method."""

    def setUp(self):
        """Set up test fixtures."""
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Clean up temporary files."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _create_valid_scan_data(self) -> dict:
        """Create a valid scan data structure."""
        return _minimal_scan_data()

    def test_load_scan_data_success(self):
        """Test loading valid scan data from file."""
        # Validates: Requirements 2.5 (round-trip consistency)
        data = self._create_valid_scan_data()
        file_path = os.path.join(self.temp_dir, "test_load.json")

        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(data, f)

        loaded = CloudShellScanner.load_scan_data(file_path)

        self.assertEqual(loaded["metadata"]["account_id"], "123456789012")

    def test_load_scan_data_file_not_found(self):
        """Test that loading non-existent file raises FileNotFoundError."""
        # A path inside the per-test temp dir is guaranteed not to exist.
        missing_path = os.path.join(self.temp_dir, "does_not_exist.json")

        with self.assertRaises(FileNotFoundError):
            CloudShellScanner.load_scan_data(missing_path)

    def test_load_scan_data_invalid_json(self):
        """Test that loading invalid JSON raises JSONDecodeError."""
        file_path = os.path.join(self.temp_dir, "invalid.json")

        with open(file_path, "w", encoding="utf-8") as f:
            f.write("not valid json {{{")

        with self.assertRaises(json.JSONDecodeError):
            CloudShellScanner.load_scan_data(file_path)

    def test_load_scan_data_invalid_structure(self):
        """Test that loading JSON with invalid structure raises ValueError."""
        file_path = os.path.join(self.temp_dir, "invalid_structure.json")

        with open(file_path, "w", encoding="utf-8") as f:
            json.dump({"invalid": "structure"}, f)

        with self.assertRaises(ValueError):
            CloudShellScanner.load_scan_data(file_path)


class TestJsonRoundTrip(unittest.TestCase):
    """Tests for JSON round-trip consistency (Property 1)."""

    @patch("cloudshell_scanner.boto3.Session")
    def setUp(self, mock_session):
        """Set up test fixtures."""
        self.mock_session = MagicMock()
        mock_session.return_value = self.mock_session

        self.mock_sts = MagicMock()
        self.mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
        self.mock_session.client.return_value = self.mock_sts

        self.scanner = CloudShellScanner()
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Clean up temporary files."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_round_trip_preserves_data(self):
        """Test that export and load preserves all data."""
        # Validates: Requirements 2.4, 2.5 (Property 1: JSON round-trip consistency)
        original_data = {
            "metadata": {
                "account_id": "123456789012",
                "scan_timestamp": "2024-01-15T10:30:00Z",
                "regions_scanned": ["us-east-1", "us-west-2", "eu-west-1"],
                "services_scanned": ["vpc", "ec2", "rds", "s3"],
                "scanner_version": "1.0.0",
                "total_resources": 10,
                "total_errors": 2,
            },
            "resources": {
                "vpc": [
                    {
                        "account_id": "123456789012",
                        "region": "us-east-1",
                        "service": "vpc",
                        "resource_type": "VPC",
                        "resource_id": "vpc-12345",
                        "name": "main-vpc",
                        "attributes": {
                            "CIDR": "10.0.0.0/16",
                            "IsDefault": False,
                            "Tags": [{"Key": "Name", "Value": "main-vpc"}],
                        },
                    }
                ],
                "ec2": [
                    {
                        "account_id": "123456789012",
                        "region": "us-east-1",
                        "service": "ec2",
                        "resource_type": "Instance",
                        "resource_id": "i-12345",
                        "name": "web-server",
                        "attributes": {
                            "InstanceType": "t3.micro",
                            "State": "running",
                        },
                    }
                ],
            },
            "errors": [
                {
                    "service": "rds",
                    "region": "us-west-2",
                    "error": "Access denied",
                    "error_type": "ClientError",
                    "details": {
                        "error_code": "AccessDenied",
                        "error_message": "User is not authorized",
                    },
                }
            ],
        }

        file_path = os.path.join(self.temp_dir, "round_trip.json")

        # Export
        self.scanner.export_json(original_data, file_path)

        # Load
        loaded_data = CloudShellScanner.load_scan_data(file_path)

        # Verify all data is preserved
        self.assertEqual(loaded_data["metadata"], original_data["metadata"])
        self.assertEqual(loaded_data["resources"], original_data["resources"])
        self.assertEqual(loaded_data["errors"], original_data["errors"])

    def test_round_trip_with_empty_resources(self):
        """Test round-trip with empty resources."""
        original_data = {
            "metadata": {
                "account_id": "123456789012",
                "scan_timestamp": "2024-01-15T10:30:00Z",
                "regions_scanned": [],
                "services_scanned": [],
                "scanner_version": "1.0.0",
                "total_resources": 0,
                "total_errors": 0,
            },
            "resources": {},
            "errors": [],
        }

        file_path = os.path.join(self.temp_dir, "empty_round_trip.json")
        self.scanner.export_json(original_data, file_path)
        loaded_data = CloudShellScanner.load_scan_data(file_path)

        self.assertEqual(loaded_data, original_data)

    def test_round_trip_with_special_characters(self):
        """Test round-trip with special characters in data."""
        original_data = {
            "metadata": {
                "account_id": "123456789012",
                "scan_timestamp": "2024-01-15T10:30:00Z",
                "regions_scanned": ["us-east-1"],
                "services_scanned": ["vpc"],
                "scanner_version": "1.0.0",
                "total_resources": 1,
                "total_errors": 0,
            },
            "resources": {
                "vpc": [
                    {
                        "account_id": "123456789012",
                        "region": "us-east-1",
                        "service": "vpc",
                        "resource_type": "VPC",
                        "resource_id": "vpc-12345",
                        "name": "测试VPC-日本語-émoji-🚀",
                        "attributes": {
                            "Description": "Special chars: <>&\"'",
                        },
                    }
                ],
            },
            "errors": [],
        }

        file_path = os.path.join(self.temp_dir, "special_chars.json")
        self.scanner.export_json(original_data, file_path)
        loaded_data = CloudShellScanner.load_scan_data(file_path)

        self.assertEqual(
            loaded_data["resources"]["vpc"][0]["name"],
            "测试VPC-日本語-émoji-🚀",
        )


# The entry-point guard must stay at the very end of the module:
# unittest.main() only discovers test classes already defined when it runs,
# and it exits the interpreter afterwards. Placed earlier, it silently
# skips every class defined below it.
if __name__ == "__main__":
    unittest.main()