# Copyright 2019 D-Wave Systems Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import os import time import json import unittest import tempfile import collections from unittest import mock from concurrent.futures import ThreadPoolExecutor, wait from requests.exceptions import HTTPError from dwave.cloud.utils import tictoc from dwave.cloud.client import Client from dwave.cloud.exceptions import ProblemUploadError from dwave.cloud.upload import ( Gettable, GettableFile, GettableMemory, FileView, ChunkedData) from tests import config class TestGettableABC(unittest.TestCase): def test_invalid(self): class InvalidGettable(Gettable): pass with self.assertRaises(TypeError): InvalidGettable() def test_valid(self): class ValidGettable(Gettable): def __len__(self): return NotImplementedError def __getitem__(self, key): return NotImplementedError def getinto(self, key): return NotImplementedError try: ValidGettable() except: self.fail("unexpected interface of Gettable") class TestGettables(unittest.TestCase): data = b'0123456789' def verify_getitem(self, gettable, data): n = len(data) # python 2 fix: indexing of bytes returns a slice (not int) data = bytearray(data) # integer indexing self.assertEqual(gettable[0], data[0]) self.assertEqual(gettable[n-1], data[n-1]) # negative integer indexing self.assertEqual(gettable[-1], data[-1]) self.assertEqual(gettable[-n], data[-n]) # out of bounds integer indexing with self.assertRaises(IndexError): gettable[n] # non-integer key with self.assertRaises(TypeError): gettable['a'] # empty slices self.assertEqual(gettable[1:0], b'') # slicing self.assertEqual(gettable[:], data[:]) self.assertEqual(gettable[0:n//2], data[0:n//2]) self.assertEqual(gettable[-n//2:], data[-n//2:]) def verify_getinto(self, gettable, data): n = len(data) # python 2 fix: indexing of bytes returns a slice (not int) data = bytearray(data) # integer indexing b = bytearray(n) self.assertEqual(gettable.getinto(0, b), 1) self.assertEqual(b[0], data[0]) self.assertEqual(gettable.getinto(n-1, b), 1) self.assertEqual(b[0], data[n-1]) # negative integer indexing self.assertEqual(gettable.getinto(-1, b), 1) self.assertEqual(b[0], data[-1]) self.assertEqual(gettable.getinto(-n, b), 1) self.assertEqual(b[0], data[-n]) # out of bounds integer indexing => nop self.assertEqual(gettable.getinto(n, b), 0) # non-integer key with self.assertRaises(TypeError): gettable.getinto('a', b) # empty slices self.assertEqual(gettable.getinto(slice(1, 0), b), 0) # slicing b = bytearray(n) self.assertEqual(gettable.getinto(slice(None), b), n) self.assertEqual(b, data) b = bytearray(n) self.assertEqual(gettable.getinto(slice(0, n//2), b), n//2) self.assertEqual(b[0:n//2], data[0:n//2]) self.assertEqual(b[n//2:], bytearray(n//2)) b = bytearray(n) self.assertEqual(gettable.getinto(slice(-n//2, None), b), n//2) self.assertEqual(b[:n//2], data[-n//2:]) self.assertEqual(b[n//2:], bytearray(n//2)) # slicing into a buffer too small m = 3 b = bytearray(m) self.assertEqual(gettable.getinto(slice(None), b), m) self.assertEqual(b, data[:m]) def test_gettable_file_from_memory_bytes(self): data = self.data fp = io.BytesIO(data) gf = GettableFile(fp) self.assertEqual(len(gf), len(data)) self.verify_getitem(gf, data) self.verify_getinto(gf, data) def test_gettable_file_from_memory_string(self): data = self.data.decode() fp = io.StringIO(data) with self.assertRaises(TypeError): GettableFile(fp) def test_gettable_file_from_file_like(self): data = self.data # create file-like temporary object (on POSIX this is a tmp file) with tempfile.TemporaryFile() as fp: fp.write(data) fp.seek(0) gf = GettableFile(fp, strict=False) self.assertEqual(len(gf), len(data)) self.verify_getitem(gf, data) self.verify_getinto(gf, data) def test_gettable_file_from_disk_file(self): data = self.data # create temporary file fd, path = tempfile.mkstemp() os.write(fd, data) os.close(fd) # test GettableFile from file on disk (read access) with io.open(path, 'rb') as fp: gf = GettableFile(fp) self.assertEqual(len(gf), len(data)) self.verify_getitem(gf, data) self.verify_getinto(gf, data) # works also for read+write access with io.open(path, 'r+b') as fp: gf = GettableFile(fp) self.assertEqual(len(gf), len(data)) self.verify_getitem(gf, data) self.verify_getinto(gf, data) # fail without read access with io.open(path, 'wb') as fp: with self.assertRaises(TypeError): GettableFile(fp) # remove temp file os.unlink(path) def test_gettable_file_critical_section_respected(self): # setup a shared file view data = self.data fp = io.BytesIO(data) gf = GettableFile(fp) # file slices slice_a = slice(0, 7) slice_b = slice(3, 5) # add a noticeable sleep inside the critical section (on `file.seek`), # resulting in minimal runtime equal to (N runs * sleep in crit sect) sleep = 0.25 def blocking_seek(start): time.sleep(sleep) return io.BytesIO.seek(gf._fp, start) gf._fp.seek = blocking_seek # define the worker def worker(slice_): return gf[slice_] # run the worker a few times in parallel executor = ThreadPoolExecutor(max_workers=3) slices = [slice_a, slice_b, slice_a] futures = [executor.submit(worker, s) for s in slices] with tictoc() as timer: wait(futures) # verify results results = [f.result() for f in futures] expected = [data[s] for s in slices] self.assertEqual(results, expected) # verify runtime is consistent with a blocking critical section self.assertGreaterEqual(timer.dt, 0.9 * len(results) * sleep) def test_gettable_memory_from_bytes_like(self): data_objects = [ bytes(self.data), bytearray(self.data), memoryview(self.data) ] for data in data_objects: gm = GettableMemory(data) self.assertEqual(len(gm), len(data)) self.verify_getitem(gm, data) self.verify_getinto(gm, data) class TestFileView(unittest.TestCase): # python 2 fix: indexing of bytes returns a slice (not int) data = bytearray(b'0123456789') def test_file_interface(self): data = self.data size = len(data) fp = io.BytesIO(data) gf = GettableFile(fp) fv = FileView(gf) # partial read self.assertEqual(fv.read(1), data[0:1]) # read all, also check continuity self.assertEqual(fv.read(), data[1:]) # seek and tell self.assertEqual(len(fv), size) self.assertEqual(fv.seek(2), 2) self.assertEqual(fv.tell(), 2) self.assertEqual(fv.seek(2, io.SEEK_CUR), 4) self.assertEqual(fv.tell(), 4) self.assertEqual(fv.seek(0, io.SEEK_END), size) self.assertEqual(fv.tell(), size) # IOBase derived methods fv.seek(0) self.assertEqual(fv.readlines(), [data]) def test_view_interface(self): data = self.data size = len(data) fp = io.BytesIO(data) gf = GettableFile(fp) fv = FileView(gf) # view, slice index subfv = fv[1:-1] self.assertEqual(subfv.read(), data[1:-1]) self.assertEqual(len(subfv), size - 2) # view, integer index self.assertEqual(fv[2], data[2]) # view, out of bounds index with self.assertRaises(IndexError): fv[size] # view are independent self.assertEqual(fv[:2].read(), data[:2]) self.assertEqual(fv[-2:].read(), data[-2:]) class TestChunkedData(unittest.TestCase): data = b'0123456789' def verify_chunking(self, cd, chunks_expected): self.assertEqual(len(cd), len(chunks_expected)) self.assertEqual(cd.num_chunks, len(chunks_expected)) chunks_iter = [c.read() for c in cd] chunks_explicit = [] for idx in range(len(cd)): chunks_explicit.append(cd.chunk(idx).read()) self.assertListEqual(chunks_iter, chunks_expected) self.assertListEqual(chunks_explicit, chunks_iter) def test_chunks_from_bytes(self): cd = ChunkedData(self.data, chunk_size=3) chunks_expected = [b'012', b'345', b'678', b'9'] self.verify_chunking(cd, chunks_expected) def test_chunks_from_bytearray(self): cd = ChunkedData(bytearray(self.data), chunk_size=3) chunks_expected = [b'012', b'345', b'678', b'9'] self.verify_chunking(cd, chunks_expected) def test_chunks_from_str(self): cd = ChunkedData(self.data.decode('ascii'), chunk_size=3) chunks_expected = [b'012', b'345', b'678', b'9'] self.verify_chunking(cd, chunks_expected) def test_chunks_from_memory_file(self): data = io.BytesIO(self.data) cd = ChunkedData(data, chunk_size=3) chunks_expected = [b'012', b'345', b'678', b'9'] self.verify_chunking(cd, chunks_expected) def test_chunk_size_edges(self): with self.assertRaises(ValueError): cd = ChunkedData(self.data, chunk_size=0) cd = ChunkedData(self.data, chunk_size=1) chunks_expected = [self.data[i:i+1] for i in range(len(self.data))] self.verify_chunking(cd, chunks_expected) cd = ChunkedData(self.data, chunk_size=len(self.data)) chunks_expected = [self.data] self.verify_chunking(cd, chunks_expected) @unittest.skipUnless(config, "No live server configuration available.") class TestMultipartUpload(unittest.TestCase): def test_smoke_test(self): data = b'123' with Client(**config) as client: future = client.upload_problem_encoded(data) try: problem_id = future.result() except Exception as e: self.fail(e) def choose_reply(key, replies, statuses=None): """Choose the right response based on a hashable `key` and make a mock response. """ if statuses is None: statuses = collections.defaultdict(lambda: iter([200])) if key in replies: response = mock.Mock(['text', 'json', 'raise_for_status']) response.status_code = next(statuses[key]) response.text = replies[key] response.json.side_effect = lambda: json.loads(replies[key]) def raise_for_status(): if not 200 <= response.status_code < 300: raise HTTPError(response.status_code) response.raise_for_status = raise_for_status return response else: raise NotImplementedError(key) @mock.patch('time.sleep', lambda *args: None) class TestMockedMultipartUpload(unittest.TestCase): @mock.patch.multiple(Client, _UPLOAD_PART_SIZE_BYTES=1) def test_single_problem_end_to_end(self): """Verify a fresh problem multipart upload works end to end.""" upload_data = b'123' upload_problem_id = '84ef154c-28f9-46ed-9f22-aec0583499f2' parts = list(range(len(upload_data))) part_data = [upload_data[i:i+1] for i in parts] _md5 = Client._digest _hex = Client._checksum_hex _b64 = Client._checksum_b64 part_digest = [_md5(part_data[i]) for i in parts] combine_checksum = _hex(_md5(b''.join(part_digest))) # we need a "global session", because mocked responses are stateful def global_mock_session(): session = mock.MagicMock() session.__enter__ = lambda *args: session def get(path, seq=iter(range(2))): all_parts = [{"part_number": i+1, "checksum": _hex(part_digest[i])} for i in parts] return choose_reply((path, next(seq)), { # initial upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 0): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": []}), # final upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 1): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": all_parts}), }) def post(path, **kwargs): json_ = kwargs.pop('json') body = json.dumps(sorted(json_.items())) return choose_reply((path, body), { # initiate upload ('bqm/multipart', json.dumps([('size', len(upload_data))])): json.dumps({'id': upload_problem_id}), # combine parts ('bqm/multipart/{}/combine'.format(upload_problem_id), json.dumps([('checksum', combine_checksum)])): json.dumps({}), }) def put(path, data, headers): body = data.read() headers = json.dumps(sorted(headers.items())) replies = { ( 'bqm/multipart/{}/part/{}'.format(upload_problem_id, i+1), part_data[i], json.dumps(sorted([ ('Content-MD5', _b64(part_digest[i])), ('Content-Type', 'application/octet-stream') ])) ): json.dumps({}) for i in parts } return choose_reply((path, body, headers), replies) session.get = get session.put = put session.post = post return session session = global_mock_session() with mock.patch.object(Client, 'create_session', lambda self: session): with Client('endpoint', 'token') as client: future = client.upload_problem_encoded(upload_data) try: returned_problem_id = future.result() except Exception as e: self.fail(e) self.assertEqual(returned_problem_id, upload_problem_id) @mock.patch.multiple(Client, _UPLOAD_PART_SIZE_BYTES=1) def test_partial_upload(self): """Verify only missing parts are uploaded.""" upload_data = b'123' upload_problem_id = '84ef154c-28f9-46ed-9f22-aec0583499f2' parts = list(range(len(upload_data))) part_data = [upload_data[i:i+1] for i in parts] _md5 = Client._digest _hex = Client._checksum_hex _b64 = Client._checksum_b64 part_digest = [_md5(part_data[i]) for i in parts] combine_checksum = _hex(_md5(b''.join(part_digest))) # we need a "global session", because mocked responses are stateful def global_mock_session(): session = mock.MagicMock() session.__enter__ = lambda *args: session def get(path, seq=iter(range(2))): all_parts = [{"part_number": i+1, "checksum": _hex(part_digest[i])} for i in parts] return choose_reply((path, next(seq)), { # initial upload status: all parts uploaded except the first one ('bqm/multipart/{}/status'.format(upload_problem_id), 0): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": all_parts[1:]}), # final upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 1): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": all_parts}), }) def post(path, **kwargs): json_ = kwargs.pop('json') body = json.dumps(sorted(json_.items())) return choose_reply((path, body), { # initiate upload ('bqm/multipart', json.dumps([('size', len(upload_data))])): json.dumps({'id': upload_problem_id}), # combine parts ('bqm/multipart/{}/combine'.format(upload_problem_id), json.dumps([('checksum', combine_checksum)])): json.dumps({}), }) def put(path, data, headers): body = data.read() headers = json.dumps(sorted(headers.items())) replies = { # only the first part! ( 'bqm/multipart/{}/part/{}'.format(upload_problem_id, i+1), part_data[i], json.dumps(sorted([ ('Content-MD5', _b64(part_digest[i])), ('Content-Type', 'application/octet-stream') ])) ): json.dumps({}) for i in parts[:1] } return choose_reply((path, body, headers), replies) session.get = get session.put = put session.post = post return session session = global_mock_session() with mock.patch.object(Client, 'create_session', lambda self: session): with Client('endpoint', 'token') as client: future = client.upload_problem_encoded(upload_data) try: returned_problem_id = future.result() except Exception as e: self.fail(e) self.assertEqual(returned_problem_id, upload_problem_id) def test_part_upload_retried(self): """Verify upload successful even if part upload fails a few times.""" # using the default part size here (5MB), so we have only one part upload_data = b'123' upload_problem_id = '84ef154c-28f9-46ed-9f22-aec0583499f2' parts = [0] part_data = [upload_data] _md5 = Client._digest _hex = Client._checksum_hex _b64 = Client._checksum_b64 part_digest = [_md5(part_data[i]) for i in parts] combine_checksum = _hex(_md5(b''.join(part_digest))) # we need a "global session", because mocked responses are stateful def global_mock_session(n_failures): session = mock.MagicMock() session.__enter__ = lambda *args: session def get(path, seq=iter(range(2))): all_parts = [{"part_number": i+1, "checksum": _hex(part_digest[i])} for i in parts] return choose_reply((path, next(seq)), { # initial upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 0): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": []}), # final upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 1): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": all_parts}), }) def post(path, **kwargs): json_ = kwargs.pop('json') body = json.dumps(sorted(json_.items())) return choose_reply((path, body), { # initiate upload ('bqm/multipart', json.dumps([('size', len(upload_data))])): json.dumps({'id': upload_problem_id}), # combine parts ('bqm/multipart/{}/combine'.format(upload_problem_id), json.dumps([('checksum', combine_checksum)])): json.dumps({}), }) def put(path, data, headers, seq=iter(range(Client._UPLOAD_PART_RETRIES+1))): body = data.read() data.seek(0) headers = json.dumps(sorted(headers.items())) keys = [ ( 'bqm/multipart/{}/part/{}'.format(upload_problem_id, i+1), part_data[i], json.dumps(sorted([ ('Content-MD5', _b64(part_digest[i])), ('Content-Type', 'application/octet-stream') ])) ) for i in parts ] attempt = next(seq) if attempt < n_failures: return choose_reply((path, body, headers), replies={key: '{}' for key in keys}, statuses={key: iter([500]) for key in keys}) else: return choose_reply((path, body, headers), replies={key: '{}' for key in keys}) session.get = get session.put = put session.post = post return session # part upload fails exactly _UPLOAD_PART_RETRIES times; # problem upload must recover session = global_mock_session(n_failures=Client._UPLOAD_PART_RETRIES) with mock.patch.object(Client, 'create_session', lambda self: session): with Client('endpoint', 'token') as client: future = client.upload_problem_encoded(upload_data) try: returned_problem_id = future.result() except Exception as e: self.fail(e) self.assertEqual(returned_problem_id, upload_problem_id) # part upload fails exactly _UPLOAD_PART_RETRIES + 1 times; # problem upload will also fail session = global_mock_session(n_failures=Client._UPLOAD_PART_RETRIES + 1) with mock.patch.object(Client, 'create_session', lambda self: session): with Client('endpoint', 'token') as client: with self.assertRaises(ProblemUploadError): client.upload_problem_encoded(upload_data).result() @mock.patch.multiple(Client, _UPLOAD_PART_SIZE_BYTES=1) def test_problem_reupload_end_to_end(self): """Verify problem multipart upload continued.""" upload_data = b'123' upload_problem_id = '84ef154c-28f9-46ed-9f22-aec0583499f2' parts = list(range(len(upload_data))) part_data = [upload_data[i:i+1] for i in parts] _md5 = Client._digest _hex = Client._checksum_hex _b64 = Client._checksum_b64 part_digest = [_md5(part_data[i]) for i in parts] combine_checksum = _hex(_md5(b''.join(part_digest))) # we need a "global session", because mocked responses are stateful def global_mock_session(): session = mock.MagicMock() session.__enter__ = lambda *args: session def get(path, seq=iter(range(2))): all_parts = [{"part_number": i+1, "checksum": _hex(part_digest[i])} for i in parts] return choose_reply((path, next(seq)), { # initial upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 0): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": all_parts[:2]}), # final upload status ('bqm/multipart/{}/status'.format(upload_problem_id), 1): json.dumps({"status": "UPLOAD_IN_PROGRESS", "parts": all_parts}), }) def post(path, **kwargs): json_ = kwargs.pop('json') body = json.dumps(sorted(json_.items())) return choose_reply((path, body), { # combine parts ('bqm/multipart/{}/combine'.format(upload_problem_id), json.dumps([('checksum', combine_checksum)])): json.dumps({}), }) def put(path, data, headers): body = data.read() headers = json.dumps(sorted(headers.items())) replies = { ( 'bqm/multipart/{}/part/{}'.format(upload_problem_id, i+1), part_data[i], json.dumps(sorted([ ('Content-MD5', _b64(part_digest[i])), ('Content-Type', 'application/octet-stream') ])) ): json.dumps({}) for i in parts[2:] } return choose_reply((path, body, headers), replies) session.get = get session.put = put session.post = post return session session = global_mock_session() with mock.patch.object(Client, 'create_session', lambda self: session): with Client('endpoint', 'token') as client: future = client.upload_problem_encoded( upload_data, problem_id=upload_problem_id) try: returned_problem_id = future.result() except Exception as e: self.fail(e) self.assertEqual(returned_problem_id, upload_problem_id)