# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the option classes defined in variant_transform_options."""

import unittest
import argparse
from typing import List  # pylint: disable=unused-import

import mock
from apache_beam.io.gcp.internal.clients import bigquery
from apitools.base.py import exceptions

from gcp_variant_transforms.options import variant_transform_options
from gcp_variant_transforms.testing import temp_dir


def make_args(options, args):
  """Parses `args` with a parser populated by `options.add_arguments`.

  A custom 'bool' argparse type is registered first; it maps a string to True
  only when it equals 'true' case-insensitively.

  Args:
    options: An object exposing `add_arguments(parser)`.
    args: The command line flags to parse, as a list of strings.

  Returns:
    The `argparse.Namespace` holding the parsed flags.

  Raises:
    AssertionError: If any of `args` is not recognized by the parser.
  """
  parser = argparse.ArgumentParser()
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  options.add_arguments(parser)
  namespace, remaining_args = parser.parse_known_args(args)
  # Every flag used by a test must be one the options class declares.
  assert not remaining_args
  return namespace


class VcfReadOptionsTest(unittest.TestCase):
  """Test cases for the VcfReadOptions class."""

  def setUp(self):
    self._options = variant_transform_options.VcfReadOptions()

  def _make_args(self, args):
    # type: (List[str]) -> argparse.Namespace
    return make_args(self._options, args)

  def test_no_inputs(self):
    args = self._make_args([])
    self.assertRaises(ValueError, self._options.validate, args)

  def test_failure_for_conflicting_flags_inputs(self):
    # Both an input pattern and an input file are supplied.
    args = self._make_args(['--input_pattern', '*',
                            '--input_file', 'asd'])
    self.assertRaises(ValueError, self._options.validate, args)

  def test_failure_for_conflicting_flags_headers(self):
    # Both header inference and a representative header file are supplied.
    args = self._make_args(['--input_pattern', '*',
                            '--infer_headers',
                            '--representative_header_file', 'gs://some_file'])
    self.assertRaises(ValueError, self._options.validate, args)

  def test_failure_for_conflicting_flags_no_errors_with_pattern_input(self):
    # Passes as long as validate() does not raise.
    args = self._make_args(['--input_pattern', '*',
                            '--representative_header_file', 'gs://some_file'])
    self._options.validate(args)

  def test_failure_for_conflicting_flags_no_errors_with_file_input(self):
    # The temp file handed to --input_file lists the same VCF path three times.
    lines = ['./gcp_variant_transforms/testing/data/vcf/valid-4.0.vcf\n',
             './gcp_variant_transforms/testing/data/vcf/valid-4.0.vcf\n',
             './gcp_variant_transforms/testing/data/vcf/valid-4.0.vcf\n']
    with temp_dir.TempDir() as tempdir:
      filename = tempdir.create_temp_file(lines=lines)
      args = self._make_args([
          '--input_file',
          filename,
          '--representative_header_file',
          'gs://some_file'])
      # Passes as long as validate() does not raise.
      self._options.validate(args)


class BigQueryWriteOptionsTest(unittest.TestCase):
  """Test cases for the BigQueryWriteOptions class.

  The BigQuery client is always a `mock.Mock`, so no request leaves the test.

  NOTE(review): `assertRaisesRegexp` is the Python 2 spelling; in Python 3 it
  is a deprecated alias of `assertRaisesRegex`. It is kept as is because this
  file looks like Python 2 era code (type comments, `import mock`) -- confirm
  the supported interpreter before renaming.
  """

  def setUp(self):
    self._options = variant_transform_options.BigQueryWriteOptions()

  def _make_args(self, args):
    # type: (List[str]) -> argparse.Namespace
    return make_args(self._options, args)

  def test_valid_table_path(self):
    args = self._make_args(['--append',
                            '--output_table', 'project:dataset.table'])
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project', datasetId='dataset'))
    # Passes as long as validate() does not raise.
    self._options.validate(args, client)

  def test_existing_sample_table(self):
    args = self._make_args(
        ['--append', 'False', '--output_table', 'project:dataset.table',
         '--sharding_config_path',
         'gcp_variant_transforms/testing/data/sharding_configs/'
         'residual_at_end.yaml'])

    client = mock.Mock()
    # Every tables.Get call answers with the sample info table.
    client.tables.Get.return_value = bigquery.Table(
        tableReference=bigquery.TableReference(projectId='project',
                                               datasetId='dataset',
                                               tableId='table__sample_info'))
    # The dot is escaped so that it only matches a literal '.'.
    with self.assertRaisesRegexp(
        ValueError, r'project:dataset\.table__sample_info already exists'):
      self._options.validate(args, client)

  def test_existing_main_table(self):
    def side_effect(request):
      # The sample info table is reported missing (HTTP 404); any other
      # request gets a table back.
      if (request == bigquery.BigqueryTablesGetRequest(
          projectId='project',
          datasetId='dataset',
          tableId='table__sample_info')):
        raise exceptions.HttpError(response={'status': '404'},
                                   url='', content='')
      else:
        # NOTE(review): the returned id is 'chr1' while the assertion below
        # expects 'chr01'; presumably only the existence of a result matters
        # to validate() -- confirm.
        return bigquery.Table(tableReference=bigquery.TableReference(
            projectId='project',
            datasetId='dataset',
            tableId='table__chr1_part1'))

    args = self._make_args(
        ['--append', 'False', '--output_table', 'project:dataset.table',
         '--sharding_config_path',
         'gcp_variant_transforms/testing/data/sharding_configs/'
         'residual_at_end.yaml'])

    client = mock.Mock()
    client.tables.Get.side_effect = side_effect
    # The dot is escaped so that it only matches a literal '.'.
    with self.assertRaisesRegexp(
        ValueError, r'project:dataset\.table__chr01_part1 already exists'):
      self._options.validate(args, client)

  def test_missing_sample_table(self):
    args = self._make_args(
        ['--append', 'True', '--output_table', 'project:dataset.table',
         '--sharding_config_path',
         'gcp_variant_transforms/testing/data/sharding_configs/'
         'residual_at_end.yaml'])

    client = mock.Mock()
    # Every tables.Get call fails with HTTP 404.
    client.tables.Get.side_effect = exceptions.HttpError(
        response={'status': '404'}, url='', content='')
    # The dot is escaped so that it only matches a literal '.'.
    with self.assertRaisesRegexp(
        ValueError, r'project:dataset\.table__sample_info does not exist'):
      self._options.validate(args, client)

  def test_missing_main_table(self):
    def side_effect(request):
      # Only the sample info table is found; any other request fails with
      # HTTP 404.
      if (request == bigquery.BigqueryTablesGetRequest(
          projectId='project',
          datasetId='dataset',
          tableId='table__sample_info')):
        return bigquery.Table(tableReference=bigquery.TableReference(
            projectId='project',
            datasetId='dataset',
            tableId='table__sample_info'))
      else:
        raise exceptions.HttpError(response={'status': '404'},
                                   url='', content='')

    args = self._make_args(
        ['--append', 'True', '--output_table', 'project:dataset.table',
         '--sharding_config_path',
         'gcp_variant_transforms/testing/data/sharding_configs/'
         'residual_at_end.yaml'])

    client = mock.Mock()
    client.tables.Get.side_effect = side_effect
    # The dot is escaped so that it only matches a literal '.'.
    with self.assertRaisesRegexp(
        ValueError, r'project:dataset\.table__chr01_part1 does not exist'):
      self._options.validate(args, client)

  def test_no_project(self):
    args = self._make_args(['--output_table', 'dataset.table'])
    client = mock.Mock()
    self.assertRaises(ValueError, self._options.validate, args, client)

  def test_invalid_table_path(self):
    no_table = self._make_args(['--output_table', 'project:dataset'])
    incorrect_sep1 = self._make_args(['--output_table',
                                      'project.dataset.table'])
    incorrect_sep2 = self._make_args(['--output_table',
                                      'project:dataset:table'])
    client = mock.Mock()
    self.assertRaises(
        ValueError, self._options.validate, no_table, client)
    self.assertRaises(
        ValueError, self._options.validate, incorrect_sep1, client)
    self.assertRaises(
        ValueError, self._options.validate, incorrect_sep2, client)

  def test_dataset_does_not_exists(self):
    args = self._make_args(['--output_table', 'project:dataset.table'])
    client = mock.Mock()
    # The dataset lookup fails with HTTP 404.
    client.datasets.Get.side_effect = exceptions.HttpError(
        response={'status': '404'}, url='', content='')
    self.assertRaises(ValueError, self._options.validate, args, client)


class AnnotationOptionsTest(unittest.TestCase):
  """Test cases for the AnnotationOptions class."""

  def setUp(self):
    self._options = variant_transform_options.AnnotationOptions()

  def _make_args(self, args):
    # type: (List[str]) -> argparse.Namespace
    return make_args(self._options, args)

  def test_validate_okay(self):
    """Tests that no exceptions are raised for valid arguments."""
    args = self._make_args(['--run_annotation_pipeline',
                            '--annotation_output_dir', 'gs://GOOD_DIR',
                            '--vep_image_uri', 'AN_IMAGE',
                            '--vep_cache_path', 'gs://VEP_CACHE'])
    self._options.validate(args)

  def test_invalid_output_dir(self):
    args = self._make_args(['--run_annotation_pipeline',
                            '--annotation_output_dir', 'BAD_DIR',
                            '--vep_image_uri', 'AN_IMAGE',
                            '--vep_cache_path', 'gs://VEP_CACHE'])
    self.assertRaises(ValueError, self._options.validate, args)

  def test_failure_for_no_image(self):
    # NOTE(review): besides omitting --vep_image_uri, this also passes the
    # same non-gs:// output dir as test_invalid_output_dir, so the ValueError
    # may come from that check rather than from the missing image -- confirm
    # against AnnotationOptions.validate before changing the flag values.
    args = self._make_args(['--run_annotation_pipeline',
                            '--annotation_output_dir', 'BAD_DIR',
                            '--vep_cache_path', 'gs://VEP_CACHE'])
    self.assertRaises(ValueError, self._options.validate, args)

  def test_failure_for_invalid_vep_cache(self):
    args = self._make_args(['--run_annotation_pipeline',
                            '--annotation_output_dir', 'gs://GOOD_DIR',
                            '--vep_image_uri', 'AN_IMAGE',
                            '--vep_cache_path', 'VEP_CACHE'])
    self.assertRaises(ValueError, self._options.validate, args)


class PreprocessOptionsTest(unittest.TestCase):
  """Test cases for the PreprocessOptions class."""

  def setUp(self):
    self._options = variant_transform_options.PreprocessOptions()

  def _make_args(self, args):
    # type: (List[str]) -> argparse.Namespace
    return make_args(self._options, args)

  def test_failure_for_conflicting_flags_inputs(self):
    # Both an input pattern and an input file are supplied.
    args = self._make_args(['--input_pattern', '*',
                            '--report_path', 'some_path',
                            '--input_file', 'asd'])
    self.assertRaises(ValueError, self._options.validate, args)

  def test_failure_for_conflicting_flags_no_errors(self):
    # Passes as long as validate() does not raise.
    args = self._make_args(['--input_pattern', '*',
                            '--report_path', 'some_path'])
    self._options.validate(args)

  def test_failure_for_conflicting_flags_no_errors_with_pattern_input(self):
    # NOTE(review): this uses exactly the same flags as
    # test_failure_for_conflicting_flags_no_errors; one of the two was
    # presumably meant to cover --input_file -- confirm.
    args = self._make_args(['--input_pattern', '*',
                            '--report_path', 'some_path'])
    self._options.validate(args)