import logging import os import tempfile import time from typing import Tuple, Sequence import blosc import grpc import lmdb import numpy as np from tqdm import tqdm from . import chunks from . import hangar_service_pb2 from . import hangar_service_pb2_grpc from .header_manipulator_client_interceptor import header_adder_interceptor from .. import constants as c from ..context import Environments from ..txnctx import TxnRegister from ..backends import BACKEND_ACCESSOR_MAP, backend_decoder from ..records import commiting from ..records import hashs from ..records.hashmachine import hash_type_code_from_digest, hash_func_from_tcode from ..records import hash_data_db_key_from_raw_key from ..records import queries from ..records import summarize from ..utils import set_blosc_nthreads set_blosc_nthreads() logger = logging.getLogger(__name__) class HangarClient(object): """Client which connects and handles data transfer to the hangar server. Parameters ---------- envs : Environments environment handles to manage all required calls to the local repostory state. address : str IP:PORT where the hangar server can be reached. auth_username : str, optional, kwarg-only credentials to use for authentication. auth_password : str, optional, kwarg-only, by default ''. credentials to use for authentication, by default ''. wait_for_ready : bool, optional, kwarg-only, be default True. If the client should wait before erring for a short period of time while a server is `UNAVAILABLE`, typically due to it just starting up at the time the connection was made wait_for_read_timeout : float, optional, kwarg-only, by default 5. If `wait_for_ready` is True, the time in seconds which the client should wait before raising an error. Must be positive value (greater than 0) """ def __init__(self, envs: Environments, address: str, *, auth_username: str = '', auth_password: str = '', wait_for_ready: bool = True, wait_for_ready_timeout: float = 5): self.env: Environments = envs self.address: str = address self.wait_ready: bool = wait_for_ready self.wait_ready_timeout: float = abs(wait_for_ready_timeout + 0.001) self.channel: grpc.Channel = None self.stub: hangar_service_pb2_grpc.HangarServiceStub = None self.header_adder_int = header_adder_interceptor(auth_username, auth_password) self.cfg: dict = {} self._rFs: BACKEND_ACCESSOR_MAP = {} for backend, accessor in BACKEND_ACCESSOR_MAP.items(): if accessor is not None: self._rFs[backend] = accessor( repo_path=self.env.repo_path, schema_shape=None, schema_dtype=None) self._rFs[backend].open(mode='r') self._setup_client_channel_config() def _setup_client_channel_config(self): """get grpc client configuration from server and setup channel and stub for use. """ tmp_insec_channel = grpc.insecure_channel(self.address) tmp_channel = grpc.intercept_channel(tmp_insec_channel, self.header_adder_int) tmp_stub = hangar_service_pb2_grpc.HangarServiceStub(tmp_channel) t_init, t_tot = time.time(), 0 while t_tot < self.wait_ready_timeout: try: request = hangar_service_pb2.GetClientConfigRequest() response = tmp_stub.GetClientConfig(request) self.cfg['push_max_nbytes'] = int(response.config['push_max_nbytes']) self.cfg['optimization_target'] = response.config['optimization_target'] enable_compression = response.config['enable_compression'] if enable_compression == 'NoCompression': compression_val = grpc.Compression.NoCompression elif enable_compression == 'Deflate': compression_val = grpc.Compression.Deflate elif enable_compression == 'Gzip': compression_val = grpc.Compression.Gzip else: compression_val = grpc.Compression.NoCompression self.cfg['enable_compression'] = compression_val except grpc.RpcError as err: if not (err.code() == grpc.StatusCode.UNAVAILABLE) and (self.wait_ready is True): logger.error(err) raise err else: break time.sleep(0.05) t_tot = time.time() - t_init else: err = ConnectionError(f'Server did not connect after: {self.wait_ready_timeout} sec.') logger.error(err) raise err tmp_channel.close() tmp_insec_channel.close() configured_channel = grpc.insecure_channel( self.address, options=[ ('grpc.optimization_target', self.cfg['optimization_target']), ("grpc.keepalive_time_ms", 1000 * 60 * 1), ("grpc.keepalive_timeout_ms", 1000 * 10), ("grpc.http2_min_sent_ping_interval_without_data_ms", 1000 * 10), ("grpc.http2_max_pings_without_data", 0), ("grpc.keepalive_permit_without_calls", 1), ], compression=self.cfg['enable_compression']) self.channel = grpc.intercept_channel(configured_channel, self.header_adder_int) self.stub = hangar_service_pb2_grpc.HangarServiceStub(self.channel) def close(self): """Close reader file handles and the GRPC channel connection, invalidating this instance. """ for backend_accessor in self._rFs.values(): backend_accessor.close() self.channel.close() def ping_pong(self) -> str: """Ping server to ensure that connection is working Returns ------- str Should be value 'PONG' """ request = hangar_service_pb2.PingRequest() response: hangar_service_pb2.PingReply = self.stub.PING(request) return response.result def push_branch_record(self, name: str, head: str ) -> hangar_service_pb2.PushBranchRecordReply: """Create a branch (if new) or update the server branch HEAD to new commit. Parameters ---------- name : str branch name to be pushed head : str commit hash to update the server head to Returns ------- hangar_service_pb2.PushBranchRecordReply code indicating success, message with human readable info """ rec = hangar_service_pb2.BranchRecord(name=name, commit=head) request = hangar_service_pb2.PushBranchRecordRequest(rec=rec) response = self.stub.PushBranchRecord(request) return response def fetch_branch_record(self, name: str ) -> hangar_service_pb2.FetchBranchRecordReply: """Get the latest head commit the server knows about for a given branch Parameters ---------- name : str name of the branch to query on the server Returns ------- hangar_service_pb2.FetchBranchRecordReply rec containing name and head commit if branch exists, along with standard error proto if it does not exist on the server. """ rec = hangar_service_pb2.BranchRecord(name=name) request = hangar_service_pb2.FetchBranchRecordRequest(rec=rec) response = self.stub.FetchBranchRecord(request) return response def push_commit_record(self, commit: str, parentVal: bytes, specVal: bytes, refVal: bytes ) -> hangar_service_pb2.PushBranchRecordReply: """Push a new commit reference to the server. Parameters ---------- commit : str hash digest of the commit to send parentVal : bytes lmdb ref parentVal of the commit specVal : bytes lmdb ref specVal of the commit refVal : bytes lmdb ref refVal of the commit Returns ------- hangar_service_pb2.PushBranchRecordReply standard error proto """ cIter = chunks.clientCommitChunkedIterator(commit=commit, parentVal=parentVal, specVal=specVal, refVal=refVal) response = self.stub.PushCommit(cIter) return response def fetch_commit_record(self, commit: str) -> Tuple[str, bytes, bytes, bytes]: """get the refs for a commit digest Parameters ---------- commit : str digest of the commit to retrieve the references for Returns ------- Tuple[str, bytes, bytes, bytes] ['commit hash', 'parentVal', 'specVal', 'refVal'] """ request = hangar_service_pb2.FetchCommitRequest(commit=commit) replies = self.stub.FetchCommit(request) for idx, reply in enumerate(replies): if idx == 0: refVal = bytearray(reply.total_byte_size) specVal = reply.record.spec parentVal = reply.record.parent offset = 0 size = len(reply.record.ref) refVal[offset: offset + size] = reply.record.ref offset += size if reply.error.code != 0: logger.error(reply.error) return False return (commit, parentVal, specVal, refVal) def fetch_schema(self, schema_hash: str) -> Tuple[str, bytes]: """get the schema specification for a schema hash Parameters ---------- schema_hash : str schema hash to retrieve from the server Returns ------- Tuple[str, bytes] ['schema hash', 'schemaVal'] """ schema_rec = hangar_service_pb2.SchemaRecord(digest=schema_hash) request = hangar_service_pb2.FetchSchemaRequest(rec=schema_rec) reply = self.stub.FetchSchema(request) if reply.error.code != 0: logger.error(reply.error) return False schemaVal = reply.rec.blob return (schema_hash, schemaVal) def push_schema(self, schema_hash: str, schemaVal: bytes) -> hangar_service_pb2.PushSchemaReply: """push a schema hash record to the remote server Parameters ---------- schema_hash : str hash digest of the schema being sent schemaVal : bytes ref value of the schema representation Returns ------- hangar_service_pb2.PushSchemaReply standard error proto indicating success """ rec = hangar_service_pb2.SchemaRecord(digest=schema_hash, blob=schemaVal) request = hangar_service_pb2.PushSchemaRequest(rec=rec) response = self.stub.PushSchema(request) return response def fetch_data( self, schema_hash: str, digests: Sequence[str] ) -> Sequence[Tuple[str, np.ndarray]]: """Fetch data hash digests for a particular schema. As the total size of the data to be transferred isn't known before this operation occurs, if more tensor data digests are requested then the Client is configured to allow in memory at a time, only a portion of the requested digests will actually be materialized. The received digests are listed as the return value of this function, be sure to check that all requested digests have been received! Parameters ---------- schema_hash : str hash of the schema each of the digests is associated with digests : Sequence[str] iterable of data digests to receive Returns ------- Sequence[Tuple[str, np.ndarray]] iterable containing 2-tuples' of the hash digest and np.ndarray data. Raises ------ RuntimeError if received digest != requested or what was reported to be sent. """ try: raw_digests = c.SEP_LST.join(digests).encode() cIter = chunks.tensorChunkedIterator(buf=raw_digests, uncomp_nbytes=len(raw_digests), pb2_request=hangar_service_pb2.FetchDataRequest) replies = self.stub.FetchData(cIter) for idx, reply in enumerate(replies): if idx == 0: uncomp_nbytes, comp_nbytes = reply.uncomp_nbytes, reply.comp_nbytes dBytes, offset = bytearray(comp_nbytes), 0 size = len(reply.raw_data) if size > 0: dBytes[offset:offset + size] = reply.raw_data offset += size except grpc.RpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: logger.info(rpc_error.details()) else: logger.error(rpc_error.details()) raise rpc_error uncompBytes = blosc.decompress(dBytes) if uncomp_nbytes != len(uncompBytes): raise RuntimeError(f'uncomp_nbytes: {uncomp_nbytes} != received {comp_nbytes}') received_data = [] unpacked_records = chunks.deserialize_record_pack(uncompBytes) for record in unpacked_records: data = chunks.deserialize_record(record) expected_hasher_tcode = hash_type_code_from_digest(data.digest) hash_func = hash_func_from_tcode(expected_hasher_tcode) received_hash = hash_func(data.data) if received_hash != data.digest: logger.error(data.data) raise RuntimeError(f'MANGLED! got: {received_hash} != requested: {data.digest}') received_data.append((received_hash, data.data)) return received_data def push_data(self, schema_hash: str, digests: Sequence[str], pbar: tqdm = None) -> hangar_service_pb2.PushDataReply: """Given a schema and digest list, read the data and send to the server Parameters ---------- schema_hash : str hash of the digest schemas digests : Sequence[str] iterable of digests to be read in and sent to the server pbar : tqdm, optional progress bar instance to be updated as the operation occurs, by default None Returns ------- hangar_service_pb2.PushDataReply standard error proto indicating success Raises ------ KeyError if one of the input digests does not exist on the client rpc_error if the server received corrupt data """ try: specs = [] hashTxn = TxnRegister().begin_reader_txn(self.env.hashenv) for digest in digests: hashKey = hash_data_db_key_from_raw_key(digest) hashVal = hashTxn.get(hashKey, default=False) if not hashVal: raise KeyError(f'No hash record with key: {hashKey}') be_loc = backend_decoder(hashVal) specs.append((digest, be_loc)) finally: TxnRegister().abort_reader_txn(self.env.hashenv) try: totalSize, records = 0, [] for k in self._rFs.keys(): self._rFs[k].__enter__() responses = [] for digest, spec in specs: data = self._rFs[spec.backend].read_data(spec) record = chunks.serialize_record(data, digest, schema_hash) records.append(record) totalSize += len(record) if (totalSize >= self.cfg['push_max_nbytes']) or (len(records) > 2000): # send tensor pack when >= configured max nbytes occupied in memory pbar.update(len(records)) pack = chunks.serialize_record_pack(records) cIter = chunks.tensorChunkedIterator(buf=pack, uncomp_nbytes=len(pack), pb2_request=hangar_service_pb2.PushDataRequest) response = self.stub.PushData.future(cIter) responses.append(response) totalSize = 0 records = [] except grpc.RpcError as rpc_error: logger.error(rpc_error.with_traceback()) raise rpc_error finally: for k in self._rFs.keys(): self._rFs[k].__exit__() if totalSize > 0: # finish sending all remaining tensors if max size has not been hit. pack = chunks.serialize_record_pack(records) cIter = chunks.tensorChunkedIterator(buf=pack, uncomp_nbytes=len(pack), pb2_request=hangar_service_pb2.PushDataRequest) response = self.stub.PushData.future(cIter) responses.append(response) for fut in responses: last = fut.result() return last def fetch_find_missing_commits(self, branch_name): c_commits = commiting.list_all_commits(self.env.refenv) branch_rec = hangar_service_pb2.BranchRecord(name=branch_name) request = hangar_service_pb2.FindMissingCommitsRequest() request.commits.extend(c_commits) request.branch.CopyFrom(branch_rec) reply = self.stub.FetchFindMissingCommits(request) return reply def push_find_missing_commits(self, branch_name): branch_commits = summarize.list_history( refenv=self.env.refenv, branchenv=self.env.branchenv, branch_name=branch_name) branch_rec = hangar_service_pb2.BranchRecord( name=branch_name, commit=branch_commits['head']) request = hangar_service_pb2.FindMissingCommitsRequest() request.commits.extend(branch_commits['order']) request.branch.CopyFrom(branch_rec) reply = self.stub.PushFindMissingCommits(request) return reply def fetch_find_missing_hash_records(self, commit): all_hashs = hashs.HashQuery(self.env.hashenv).list_all_hash_keys_raw() all_hashs_raw = [chunks.serialize_ident(digest, '') for digest in all_hashs] raw_pack = chunks.serialize_record_pack(all_hashs_raw) pb2_func = hangar_service_pb2.FindMissingHashRecordsRequest cIter = chunks.missingHashRequestIterator(commit, raw_pack, pb2_func) responses = self.stub.FetchFindMissingHashRecords(cIter) for idx, response in enumerate(responses): if idx == 0: hBytes, offset = bytearray(response.total_byte_size), 0 size = len(response.hashs) hBytes[offset: offset + size] = response.hashs offset += size uncompBytes = blosc.decompress(hBytes) raw_idents = chunks.deserialize_record_pack(uncompBytes) idents = [chunks.deserialize_ident(raw) for raw in raw_idents] return idents def push_find_missing_hash_records(self, commit, tmpDB: lmdb.Environment = None): if tmpDB is None: with tempfile.TemporaryDirectory() as tempD: tmpDF = os.path.join(tempD, 'test.lmdb') tmpDB = lmdb.open(path=tmpDF, **c.LMDB_SETTINGS) commiting.unpack_commit_ref(self.env.refenv, tmpDB, commit) c_hashs_schemas = queries.RecordQuery(tmpDB).data_hash_to_schema_hash() c_hashes = list(set(c_hashs_schemas.keys())) tmpDB.close() else: c_hashs_schemas = queries.RecordQuery(tmpDB).data_hash_to_schema_hash() c_hashes = list(set(c_hashs_schemas.keys())) c_hashs_raw = [chunks.serialize_ident(digest, '') for digest in c_hashes] raw_pack = chunks.serialize_record_pack(c_hashs_raw) pb2_func = hangar_service_pb2.FindMissingHashRecordsRequest cIter = chunks.missingHashRequestIterator(commit, raw_pack, pb2_func) responses = self.stub.PushFindMissingHashRecords(cIter) for idx, response in enumerate(responses): if idx == 0: hBytes, offset = bytearray(response.total_byte_size), 0 size = len(response.hashs) hBytes[offset: offset + size] = response.hashs offset += size uncompBytes = blosc.decompress(hBytes) s_missing_raw = chunks.deserialize_record_pack(uncompBytes) s_mis_hsh = [chunks.deserialize_ident(raw).digest for raw in s_missing_raw] s_mis_hsh_sch = [(s_hsh, c_hashs_schemas[s_hsh]) for s_hsh in s_mis_hsh] return s_mis_hsh_sch def fetch_find_missing_schemas(self, commit): c_schemaset = set(hashs.HashQuery(self.env.hashenv).list_all_schema_digests()) c_schemas = list(c_schemaset) request = hangar_service_pb2.FindMissingSchemasRequest() request.commit = commit request.schema_digests.extend(c_schemas) response = self.stub.FetchFindMissingSchemas(request) return response def push_find_missing_schemas(self, commit, tmpDB: lmdb.Environment = None): if tmpDB is None: with tempfile.TemporaryDirectory() as tempD: tmpDF = os.path.join(tempD, 'test.lmdb') tmpDB = lmdb.open(path=tmpDF, **c.LMDB_SETTINGS) commiting.unpack_commit_ref(self.env.refenv, tmpDB, commit) c_schemaset = set(queries.RecordQuery(tmpDB).schema_hashes()) c_schemas = list(c_schemaset) tmpDB.close() else: c_schemaset = set(queries.RecordQuery(tmpDB).schema_hashes()) c_schemas = list(c_schemaset) request = hangar_service_pb2.FindMissingSchemasRequest() request.commit = commit request.schema_digests.extend(c_schemas) response = self.stub.PushFindMissingSchemas(request) return response