# Copyright 2018 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for vcf_parser module.""" from __future__ import absolute_import import logging import unittest from itertools import permutations from gcp_variant_transforms.beam_io import vcfio from gcp_variant_transforms.beam_io.vcfio import Variant from gcp_variant_transforms.beam_io.vcfio import VariantCall from gcp_variant_transforms.testing.testdata_util import hash_name class VariantTest(unittest.TestCase): def _assert_variants_equal(self, actual, expected): self.assertEqual( sorted(expected), sorted(actual)) def test_sort_variants(self): sorted_variants = [ Variant(reference_name='a', start=20, end=22), Variant(reference_name='a', start=20, end=22, quality=20), Variant(reference_name='b', start=20, end=22), Variant(reference_name='b', start=21, end=22), Variant(reference_name='b', start=21, end=23)] for permutation in permutations(sorted_variants): self.assertEqual(sorted(permutation), sorted_variants) def test_variant_equality(self): base_variant = Variant(reference_name='a', start=20, end=22, reference_bases='a', alternate_bases=['g', 't'], names=['variant'], quality=9, filters=['q10'], info={'key': 'value'}, calls=[VariantCall(genotype=[0, 0])]) equal_variant = Variant(reference_name='a', start=20, end=22, reference_bases='a', alternate_bases=['g', 't'], names=['variant'], quality=9, filters=['q10'], info={'key': 'value'}, calls=[VariantCall(genotype=[0, 0])]) different_calls = Variant(reference_name='a', start=20, end=22, reference_bases='a', alternate_bases=['g', 't'], names=['variant'], quality=9, filters=['q10'], info={'key': 'value'}, calls=[VariantCall(genotype=[1, 0])]) missing_field = Variant(reference_name='a', start=20, end=22, reference_bases='a', alternate_bases=['g', 't'], names=['variant'], quality=9, filters=['q10'], info={'key': 'value'}) self.assertEqual(base_variant, equal_variant) self.assertNotEqual(base_variant, different_calls) self.assertNotEqual(base_variant, missing_field) class VariantCallTest(unittest.TestCase): def _default_variant_call(self): return vcfio.VariantCall( sample_id=hash_name('Sample1'), genotype=[1, 0], phaseset=vcfio.DEFAULT_PHASESET_VALUE, info={'GQ': 48}) def test_variant_call_order(self): variant_call_1 = self._default_variant_call() variant_call_2 = self._default_variant_call() self.assertEqual(variant_call_1, variant_call_2) variant_call_1.phaseset = 0 variant_call_2.phaseset = 1 self.assertGreater(variant_call_2, variant_call_1) if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()