#!/usr/bin/env python import ipaddress import io import os import textwrap import unittest import unittest.mock import dns.exception import dns.name import dns.resolver from pyfakefs import fake_filesystem_unittest from fierce import fierce # Simply getting a dns.resolver.Answer with a specific result was # more difficult than I'd like, let's just go with this less than # ideal approach for now class MockAnswer(object): def __init__(self, response): self.response = response def to_text(self): return self.response class TestFierce(unittest.TestCase): def test_concatenate_subdomains_empty(self): domain = dns.name.from_text("example.com.") subdomains = [] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("example.com.") assert expected == result def test_concatenate_subdomains_single_subdomain(self): domain = dns.name.from_text("example.com.") subdomains = ["sd1"] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("sd1.example.com.") assert expected == result def test_concatenate_subdomains_multiple_subdomains(self): domain = dns.name.from_text("example.com.") subdomains = ["sd1", "sd2"] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("sd1.sd2.example.com.") assert expected == result def test_concatenate_subdomains_makes_root(self): # Domain is missing '.' at the end domain = dns.name.from_text("example.com") subdomains = [] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("example.com.") assert expected == result def test_concatenate_subdomains_single_sub_subdomain(self): domain = dns.name.from_text("example.com.") subdomains = ["sd1.sd2"] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("sd1.sd2.example.com.") assert expected == result def test_concatenate_subdomains_multiple_sub_subdomain(self): domain = dns.name.from_text("example.com.") subdomains = ["sd1.sd2", "sd3.sd4"] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("sd1.sd2.sd3.sd4.example.com.") assert expected == result def test_concatenate_subdomains_fqdn_subdomain(self): domain = dns.name.from_text("example.") subdomains = ["sd1.sd2."] result = fierce.concatenate_subdomains(domain, subdomains) expected = dns.name.from_text("sd1.sd2.example.") assert expected == result def test_default_expander(self): ip = ipaddress.IPv4Address('192.168.1.1') result = fierce.default_expander(ip) expected = [ ipaddress.IPv4Address('192.168.1.1'), ] assert expected == result def test_traverse_expander_basic(self): ip = ipaddress.IPv4Address('192.168.1.1') expand = 1 result = fierce.traverse_expander(ip, expand) expected = [ ipaddress.IPv4Address('192.168.1.0'), ipaddress.IPv4Address('192.168.1.1'), ipaddress.IPv4Address('192.168.1.2'), ] assert expected == result def test_traverse_expander_no_cross_lower_boundary(self): ip = ipaddress.IPv4Address('192.168.1.1') expand = 2 result = fierce.traverse_expander(ip, expand) expected = [ ipaddress.IPv4Address('192.168.1.0'), ipaddress.IPv4Address('192.168.1.1'), ipaddress.IPv4Address('192.168.1.2'), ipaddress.IPv4Address('192.168.1.3'), ] assert expected == result def test_traverse_expander_no_cross_upper_boundary(self): ip = ipaddress.IPv4Address('192.168.1.254') expand = 2 result = fierce.traverse_expander(ip, expand) expected = [ ipaddress.IPv4Address('192.168.1.252'), ipaddress.IPv4Address('192.168.1.253'), ipaddress.IPv4Address('192.168.1.254'), ipaddress.IPv4Address('192.168.1.255'), ] assert expected == result # Upper and lower bound tests are to avoid reintroducing out of # bounds error from IPv4Address. (no_cross_*_boundary tests won't # necessarily cover this; GitHub issue #29) def test_traverse_expander_lower_bound_regression(self): ip = ipaddress.IPv4Address('0.0.0.1') expand = 2 result = fierce.traverse_expander(ip, expand) expected = [ ipaddress.IPv4Address('0.0.0.0'), ipaddress.IPv4Address('0.0.0.1'), ipaddress.IPv4Address('0.0.0.2'), ipaddress.IPv4Address('0.0.0.3') ] assert expected == result def test_traverse_expander_upper_bound_regression(self): ip = ipaddress.IPv4Address('255.255.255.254') expand = 2 result = fierce.traverse_expander(ip, expand) expected = [ ipaddress.IPv4Address('255.255.255.252'), ipaddress.IPv4Address('255.255.255.253'), ipaddress.IPv4Address('255.255.255.254'), ipaddress.IPv4Address('255.255.255.255') ] assert expected == result def test_wide_expander_basic(self): ip = ipaddress.IPv4Address('192.168.1.50') result = fierce.wide_expander(ip) expected = [ ipaddress.IPv4Address('192.168.1.{}'.format(i)) for i in range(256) ] assert expected == result def test_wide_expander_lower_boundary(self): ip = ipaddress.IPv4Address('192.168.1.0') result = fierce.wide_expander(ip) expected = [ ipaddress.IPv4Address('192.168.1.{}'.format(i)) for i in range(256) ] assert expected == result def test_wide_expander_upper_boundary(self): ip = ipaddress.IPv4Address('192.168.1.255') result = fierce.wide_expander(ip) expected = [ ipaddress.IPv4Address('192.168.1.{}'.format(i)) for i in range(256) ] assert expected == result def test_range_expander(self): ip = '192.168.1.0/31' result = fierce.range_expander(ip) expected = [ ipaddress.IPv4Address('192.168.1.0'), ipaddress.IPv4Address('192.168.1.1'), ] assert expected == result def test_recursive_query_basic_failure(self): resolver = dns.resolver.Resolver() domain = dns.name.from_text('example.com.') record_type = 'NS' with unittest.mock.patch.object(fierce, 'query', return_value=None) as mock_method: result = fierce.recursive_query(resolver, domain, record_type=record_type) expected = [ unittest.mock.call(resolver, 'example.com.', record_type, tcp=False), unittest.mock.call(resolver, 'com.', record_type, tcp=False), unittest.mock.call(resolver, '', record_type, tcp=False), ] mock_method.assert_has_calls(expected) assert result is None def test_recursive_query_long_domain_failure(self): resolver = dns.resolver.Resolver() domain = dns.name.from_text('sd1.sd2.example.com.') record_type = 'NS' with unittest.mock.patch.object(fierce, 'query', return_value=None) as mock_method: result = fierce.recursive_query(resolver, domain, record_type=record_type) expected = [ unittest.mock.call(resolver, 'sd1.sd2.example.com.', record_type, tcp=False), unittest.mock.call(resolver, 'sd2.example.com.', record_type, tcp=False), unittest.mock.call(resolver, 'example.com.', record_type, tcp=False), unittest.mock.call(resolver, 'com.', record_type, tcp=False), unittest.mock.call(resolver, '', record_type, tcp=False), ] mock_method.assert_has_calls(expected) assert result is None def test_recursive_query_basic_success(self): resolver = dns.resolver.Resolver() domain = dns.name.from_text('example.com.') record_type = 'NS' good_response = unittest.mock.MagicMock() side_effect = [ None, good_response, None, ] with unittest.mock.patch.object(fierce, 'query', side_effect=side_effect) as mock_method: result = fierce.recursive_query(resolver, domain, record_type=record_type) expected = [ unittest.mock.call(resolver, 'example.com.', record_type, tcp=False), unittest.mock.call(resolver, 'com.', record_type, tcp=False), ] mock_method.assert_has_calls(expected) assert result == good_response def test_query_nxdomain(self): resolver = dns.resolver.Resolver() domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(resolver, 'query', side_effect=dns.resolver.NXDOMAIN()): result = fierce.query(resolver, domain) assert result is None def test_query_no_nameservers(self): resolver = dns.resolver.Resolver() domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(resolver, 'query', side_effect=dns.resolver.NoNameservers()): result = fierce.query(resolver, domain) assert result is None def test_query_timeout(self): resolver = dns.resolver.Resolver() domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(resolver, 'query', side_effect=dns.exception.Timeout()): result = fierce.query(resolver, domain) assert result is None def test_zone_transfer_connection_error(self): address = 'test' domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(fierce.dns.zone, 'from_xfr', side_effect=ConnectionError()): result = fierce.zone_transfer(address, domain) assert result is None def test_zone_transfer_eof_error(self): address = 'test' domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(fierce.dns.zone, 'from_xfr', side_effect=EOFError()): result = fierce.zone_transfer(address, domain) assert result is None def test_zone_transfer_timeout_error(self): address = 'test' domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(fierce.dns.zone, 'from_xfr', side_effect=TimeoutError()): result = fierce.zone_transfer(address, domain) assert result is None def test_zone_transfer_form_error(self): address = 'test' domain = dns.name.from_text('example.com.') with unittest.mock.patch.object(fierce.dns.zone, 'from_xfr', side_effect=dns.exception.FormError()): result = fierce.zone_transfer(address, domain) assert result is None def test_find_nearby_empty(self): resolver = 'unused' ips = [] result = fierce.find_nearby(resolver, ips) expected = {} assert expected == result def test_find_nearby_basic(self): resolver = 'unused' ips = [ ipaddress.IPv4Address('192.168.1.0'), ipaddress.IPv4Address('192.168.1.1'), ] side_effect = [ [MockAnswer('sd1.example.com.')], [MockAnswer('sd2.example.com.')], ] with unittest.mock.patch.object(fierce, 'reverse_query', side_effect=side_effect): result = fierce.find_nearby(resolver, ips) expected = { '192.168.1.0': 'sd1.example.com.', '192.168.1.1': 'sd2.example.com.', } assert expected == result def test_find_nearby_filter_func(self): resolver = 'unused' ips = [ ipaddress.IPv4Address('192.168.1.0'), ipaddress.IPv4Address('192.168.1.1'), ] side_effect = [ [MockAnswer('sd1.example.com.')], [MockAnswer('sd2.example.com.')], ] def filter_func(reverse_result): return reverse_result == 'sd1.example.com.' with unittest.mock.patch.object(fierce, 'reverse_query', side_effect=side_effect): result = fierce.find_nearby(resolver, ips, filter_func=filter_func) expected = { '192.168.1.0': 'sd1.example.com.', } assert expected == result def test_print_subdomain_result_basic(self): url = 'example.com' ip = '192.168.1.0' with io.StringIO() as stream: fierce.print_subdomain_result(url, ip, stream=stream) result = stream.getvalue() expected = 'Found: example.com (192.168.1.0)\n' assert expected == result def test_print_subdomain_result_nearby(self): url = 'example.com' ip = '192.168.1.0' nearby = {'192.168.1.1': 'nearby.com'} with io.StringIO() as stream: fierce.print_subdomain_result(url, ip, nearby=nearby, stream=stream) result = stream.getvalue() expected = textwrap.dedent(''' Found: example.com (192.168.1.0) Nearby: {'192.168.1.1': 'nearby.com'} ''').lstrip() assert expected == result def test_print_subdomain_result_http_header(self): url = 'example.com' ip = '192.168.1.0' http_connection_headers = {'HTTP HEADER': 'value'} with io.StringIO() as stream: fierce.print_subdomain_result( url, ip, http_connection_headers=http_connection_headers, stream=stream ) result = stream.getvalue() expected = textwrap.dedent(''' Found: example.com (192.168.1.0) HTTP connected: {'HTTP HEADER': 'value'} ''').lstrip() assert expected == result def test_print_subdomain_result_both(self): url = 'example.com' ip = '192.168.1.0' http_connection_headers = {'HTTP HEADER': 'value'} nearby = {'192.168.1.1': 'nearby.com'} with io.StringIO() as stream: fierce.print_subdomain_result( url, ip, http_connection_headers=http_connection_headers, nearby=nearby, stream=stream ) result = stream.getvalue() expected = textwrap.dedent(''' Found: example.com (192.168.1.0) HTTP connected: {'HTTP HEADER': 'value'} Nearby: {'192.168.1.1': 'nearby.com'} ''').lstrip() assert expected == result def test_unvisited_closure_empty(self): unvisited = fierce.unvisited_closure() ips = set() result = unvisited(ips) expected = set() assert expected == result def test_unvisited_closure_empty_intersection(self): unvisited = fierce.unvisited_closure() unvisited(set([1, 2, 3])) result = unvisited(set([4, 5, 6])) expected = set([4, 5, 6]) assert expected == result def test_unvisited_closure_overlapping_intersection(self): unvisited = fierce.unvisited_closure() unvisited(set([1, 2, 3])) result = unvisited(set([2, 3, 4])) expected = set([4]) assert expected == result def test_search_filter_empty(self): domains = [] address = 'test.example.com' result = fierce.search_filter(domains, address) assert not result def test_search_filter_true(self): domains = ['example.com'] address = 'test.example.com' result = fierce.search_filter(domains, address) assert result def test_search_filter_false(self): domains = ['not.com'] address = 'test.example.com' result = fierce.search_filter(domains, address) assert not result class TestArgumentParsing(fake_filesystem_unittest.TestCase): def test_parse_args_basic(self): domain = 'example.com' args = fierce.parse_args([ '--domain', domain, ]) result = args.domain expected = domain assert expected == result def test_parse_args_included_list_file(self): filename = '5000.txt' args = fierce.parse_args([ '--domain', 'example.com', '--subdomain-file', filename, ]) result = args.subdomain_file expected = os.path.join( os.path.dirname(os.path.dirname(__file__)), 'fierce', 'lists', filename, ) exists = os.path.exists(result) assert expected == result assert exists def test_parse_args_missing_list_file(self): filename = 'missing.txt' args = fierce.parse_args([ '--domain', 'example.com', '--subdomain-file', filename, ]) result = args.subdomain_file expected = os.path.join( os.path.dirname(os.path.dirname(__file__)), 'fierce', 'lists', filename, ) exists = os.path.exists(result) assert expected == result assert not exists def test_parse_args_custom_list_file(self): self.setUpPyfakefs() filename = os.path.join('test', 'custom.txt') self.fs.create_file( filename, contents='subdomain' ) args = fierce.parse_args([ '--domain', 'example.com', '--subdomain-file', filename, ]) result = args.subdomain_file expected = filename exists = os.path.exists(result) assert expected == result assert exists if __name__ == "__main__": unittest.main()