# -*- coding: utf-8 -*- # Upside Travel, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import json import unittest import boto3 import botocore.session from botocore.stub import Stubber from common import AV_SCAN_START_METADATA from common import AV_SIGNATURE_METADATA from common import AV_SIGNATURE_OK from common import AV_STATUS_METADATA from common import AV_TIMESTAMP_METADATA from common import get_timestamp from scan import delete_s3_object from scan import event_object from scan import get_local_path from scan import set_av_metadata from scan import set_av_tags from scan import sns_start_scan from scan import sns_scan_results from scan import verify_s3_object_version class TestScan(unittest.TestCase): def setUp(self): # Common data self.s3_bucket_name = "test_bucket" self.s3_key_name = "test_key" # Clients and Resources self.s3 = boto3.resource("s3") self.s3_client = botocore.session.get_session().create_client("s3") self.sns_client = botocore.session.get_session().create_client( "sns", region_name="us-west-2" ) def test_sns_event_object(self): event = { "Records": [ { "s3": { "bucket": {"name": self.s3_bucket_name}, "object": {"key": self.s3_key_name}, } } ] } sns_event = {"Records": [{"Sns": {"Message": json.dumps(event)}}]} s3_obj = event_object(sns_event, event_source="sns") expected_s3_object = self.s3.Object(self.s3_bucket_name, self.s3_key_name) self.assertEquals(s3_obj, expected_s3_object) def test_s3_event_object(self): event = { "Records": [ { "s3": { "bucket": {"name": self.s3_bucket_name}, "object": {"key": self.s3_key_name}, } } ] } s3_obj = event_object(event) expected_s3_object = self.s3.Object(self.s3_bucket_name, self.s3_key_name) self.assertEquals(s3_obj, expected_s3_object) def test_s3_event_object_missing_bucket(self): event = {"Records": [{"s3": {"object": {"key": self.s3_key_name}}}]} with self.assertRaises(Exception) as cm: event_object(event) self.assertEquals(cm.exception.message, "No bucket found in event!") def test_s3_event_object_missing_key(self): event = {"Records": [{"s3": {"bucket": {"name": self.s3_bucket_name}}}]} with self.assertRaises(Exception) as cm: event_object(event) self.assertEquals(cm.exception.message, "No key found in event!") def test_s3_event_object_bucket_key_missing(self): event = {"Records": [{"s3": {"bucket": {}, "object": {}}}]} with self.assertRaises(Exception) as cm: event_object(event) self.assertEquals( cm.exception.message, "Unable to retrieve object from event.\n{}".format(event), ) def test_s3_event_object_no_records(self): event = {"Records": []} with self.assertRaises(Exception) as cm: event_object(event) self.assertEquals(cm.exception.message, "No records found in event!") def test_verify_s3_object_version(self): s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) # Set up responses get_bucket_versioning_response = {"Status": "Enabled"} get_bucket_versioning_expected_params = {"Bucket": self.s3_bucket_name} s3_stubber_resource = Stubber(self.s3.meta.client) s3_stubber_resource.add_response( "get_bucket_versioning", get_bucket_versioning_response, get_bucket_versioning_expected_params, ) list_object_versions_response = { "Versions": [ { "ETag": "string", "Size": 123, "StorageClass": "STANDARD", "Key": "string", "VersionId": "string", "IsLatest": True, "LastModified": datetime.datetime(2015, 1, 1), "Owner": {"DisplayName": "string", "ID": "string"}, } ] } list_object_versions_expected_params = { "Bucket": self.s3_bucket_name, "Prefix": self.s3_key_name, } s3_stubber_resource.add_response( "list_object_versions", list_object_versions_response, list_object_versions_expected_params, ) try: with s3_stubber_resource: verify_s3_object_version(self.s3, s3_obj) except Exception as e: self.fail("verify_s3_object_version() raised Exception unexpectedly!") raise e def test_verify_s3_object_versioning_not_enabled(self): s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) # Set up responses get_bucket_versioning_response = {"Status": "Disabled"} get_bucket_versioning_expected_params = {"Bucket": self.s3_bucket_name} s3_stubber_resource = Stubber(self.s3.meta.client) s3_stubber_resource.add_response( "get_bucket_versioning", get_bucket_versioning_response, get_bucket_versioning_expected_params, ) with self.assertRaises(Exception) as cm: with s3_stubber_resource: verify_s3_object_version(self.s3, s3_obj) self.assertEquals( cm.exception.message, "Object versioning is not enabled in bucket {}".format( self.s3_bucket_name ), ) def test_verify_s3_object_version_multiple_versions(self): s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) # Set up responses get_bucket_versioning_response = {"Status": "Enabled"} get_bucket_versioning_expected_params = {"Bucket": self.s3_bucket_name} s3_stubber_resource = Stubber(self.s3.meta.client) s3_stubber_resource.add_response( "get_bucket_versioning", get_bucket_versioning_response, get_bucket_versioning_expected_params, ) list_object_versions_response = { "Versions": [ { "ETag": "string", "Size": 123, "StorageClass": "STANDARD", "Key": "string", "VersionId": "string", "IsLatest": True, "LastModified": datetime.datetime(2015, 1, 1), "Owner": {"DisplayName": "string", "ID": "string"}, }, { "ETag": "string", "Size": 123, "StorageClass": "STANDARD", "Key": "string", "VersionId": "string", "IsLatest": True, "LastModified": datetime.datetime(2015, 1, 1), "Owner": {"DisplayName": "string", "ID": "string"}, }, ] } list_object_versions_expected_params = { "Bucket": self.s3_bucket_name, "Prefix": self.s3_key_name, } s3_stubber_resource.add_response( "list_object_versions", list_object_versions_response, list_object_versions_expected_params, ) with self.assertRaises(Exception) as cm: with s3_stubber_resource: verify_s3_object_version(self.s3, s3_obj) self.assertEquals( cm.exception.message, "Detected multiple object versions in {}.{}, aborting processing".format( self.s3_bucket_name, self.s3_key_name ), ) def test_sns_start_scan(self): sns_stubber = Stubber(self.sns_client) s3_stubber_resource = Stubber(self.s3.meta.client) sns_arn = "some_arn" version_id = "version-id" timestamp = get_timestamp() message = { "bucket": self.s3_bucket_name, "key": self.s3_key_name, "version": version_id, AV_SCAN_START_METADATA: True, AV_TIMESTAMP_METADATA: timestamp, } publish_response = {"MessageId": "message_id"} publish_expected_params = { "TargetArn": sns_arn, "Message": json.dumps({"default": json.dumps(message)}), "MessageStructure": "json", } sns_stubber.add_response("publish", publish_response, publish_expected_params) head_object_response = {"VersionId": version_id} head_object_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, } s3_stubber_resource.add_response( "head_object", head_object_response, head_object_expected_params ) with sns_stubber, s3_stubber_resource: s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) sns_start_scan(self.sns_client, s3_obj, sns_arn, timestamp) def test_get_local_path(self): local_prefix = "/tmp" s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) file_path = get_local_path(s3_obj, local_prefix) expected_file_path = "/tmp/test_bucket/test_key" self.assertEquals(file_path, expected_file_path) def test_set_av_metadata(self): scan_result = "CLEAN" scan_signature = AV_SIGNATURE_OK timestamp = get_timestamp() s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) s3_stubber_resource = Stubber(self.s3.meta.client) # First head call is done to get content type and meta data head_object_response = {"ContentType": "content", "Metadata": {}} head_object_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, } s3_stubber_resource.add_response( "head_object", head_object_response, head_object_expected_params ) # Next two calls are done when copy() is called head_object_response_2 = { "ContentType": "content", "Metadata": {}, "ContentLength": 200, } head_object_expected_params_2 = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, } s3_stubber_resource.add_response( "head_object", head_object_response_2, head_object_expected_params_2 ) copy_object_response = {"VersionId": "version_id"} copy_object_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, "ContentType": "content", "CopySource": {"Bucket": self.s3_bucket_name, "Key": self.s3_key_name}, "Metadata": { AV_SIGNATURE_METADATA: scan_signature, AV_STATUS_METADATA: scan_result, AV_TIMESTAMP_METADATA: timestamp, }, "MetadataDirective": "REPLACE", } s3_stubber_resource.add_response( "copy_object", copy_object_response, copy_object_expected_params ) with s3_stubber_resource: set_av_metadata(s3_obj, scan_result, scan_signature, timestamp) def test_set_av_tags(self): scan_result = "CLEAN" scan_signature = AV_SIGNATURE_OK timestamp = get_timestamp() tag_set = { "TagSet": [ {"Key": AV_SIGNATURE_METADATA, "Value": scan_signature}, {"Key": AV_STATUS_METADATA, "Value": scan_result}, {"Key": AV_TIMESTAMP_METADATA, "Value": timestamp}, ] } s3_stubber = Stubber(self.s3_client) get_object_tagging_response = tag_set get_object_tagging_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, } s3_stubber.add_response( "get_object_tagging", get_object_tagging_response, get_object_tagging_expected_params, ) put_object_tagging_response = {} put_object_tagging_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, "Tagging": tag_set, } s3_stubber.add_response( "put_object_tagging", put_object_tagging_response, put_object_tagging_expected_params, ) with s3_stubber: s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) set_av_tags(self.s3_client, s3_obj, scan_result, scan_signature, timestamp) def test_sns_scan_results(self): sns_stubber = Stubber(self.sns_client) s3_stubber_resource = Stubber(self.s3.meta.client) sns_arn = "some_arn" version_id = "version-id" scan_result = "CLEAN" scan_signature = AV_SIGNATURE_OK timestamp = get_timestamp() message = { "bucket": self.s3_bucket_name, "key": self.s3_key_name, "version": version_id, AV_SIGNATURE_METADATA: scan_signature, AV_STATUS_METADATA: scan_result, AV_TIMESTAMP_METADATA: timestamp, } publish_response = {"MessageId": "message_id"} publish_expected_params = { "TargetArn": sns_arn, "Message": json.dumps({"default": json.dumps(message)}), "MessageAttributes": { "av-status": {"DataType": "String", "StringValue": scan_result}, "av-signature": {"DataType": "String", "StringValue": scan_signature}, }, "MessageStructure": "json", } sns_stubber.add_response("publish", publish_response, publish_expected_params) head_object_response = {"VersionId": version_id} head_object_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, } s3_stubber_resource.add_response( "head_object", head_object_response, head_object_expected_params ) with sns_stubber, s3_stubber_resource: s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) sns_scan_results( self.sns_client, s3_obj, sns_arn, scan_result, scan_signature, timestamp ) def test_delete_s3_object(self): s3_stubber = Stubber(self.s3.meta.client) delete_object_response = {} delete_object_expected_params = { "Bucket": self.s3_bucket_name, "Key": self.s3_key_name, } s3_stubber.add_response( "delete_object", delete_object_response, delete_object_expected_params ) with s3_stubber: s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) delete_s3_object(s3_obj) def test_delete_s3_object_exception(self): s3_stubber = Stubber(self.s3.meta.client) with self.assertRaises(Exception) as cm: with s3_stubber: s3_obj = self.s3.Object(self.s3_bucket_name, self.s3_key_name) delete_s3_object(s3_obj) self.assertEquals( cm.exception.message, "Failed to delete infected file: {}.{}".format( self.s3_bucket_name, self.s3_key_name ), )