# Copyright (c) 2017-2018 Uber Technologies, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import division import copy import logging import time from collections import namedtuple from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from enum import Enum import psutil import six import tensorflow.compat.v1 as tf # pylint: disable=import-error from petastorm import make_reader from petastorm.etl.dataset_metadata import get_schema_from_dataset_url from petastorm.reader_impl.pickle_serializer import PickleSerializer from petastorm.reader_impl.pyarrow_serializer import PyArrowSerializer from petastorm.tf_utils import tf_tensors from petastorm.unischema import match_unischema_fields from petastorm.workers_pool.dummy_pool import DummyPool from petastorm.workers_pool.process_pool import ProcessPool from petastorm.workers_pool.thread_pool import ThreadPool logger = logging.getLogger(__name__) BenchmarkResult = namedtuple('BenchmarkResult', ['time_mean', 'samples_per_second', 'memory_info', 'cpu']) class WorkerPoolType(Enum): """Defines a type of parallelism used in the benchmark: multithreading, multiprocessing or none (single-thread)""" THREAD = 'thread' """A thread pool is used by the benchmark""" PROCESS = 'process' """A process pool is used by the benchmark""" NONE = 'dummy' """IO and loading will be done on a single thread. No parallelism.""" def __str__(self): return self.value class ReadMethod(Enum): """Defines whether a Tensorflow or plain Python reading method would be used during the benchmark""" TF = 'tf' """Tensorflow reading method will be used during the benchmark (``tf_tensor`` method)""" PYTHON = 'python' """Pure python reading method will be used during the benchmark (``next(reader)``)""" def __str__(self): return self.value def _time_warmup_and_work(reader, warmup_cycles_count, measure_cycles_count, do_work_func=None): if not do_work_func: do_work_func = lambda: next(reader) # noqa _time_multiple_iterations(warmup_cycles_count, do_work_func, lambda: reader.diagnostics) logger.info('Done warmup') this_process = psutil.Process() this_process.cpu_percent() duration = _time_multiple_iterations(measure_cycles_count, do_work_func, lambda: reader.diagnostics) cpu_percent = this_process.cpu_percent() time_mean = duration / measure_cycles_count result = BenchmarkResult(time_mean=time_mean, samples_per_second=1.0 / time_mean, memory_info=this_process.memory_full_info(), cpu=cpu_percent) logger.info('Done measuring: %s', str(result)) return result def _time_warmup_and_work_tf(reader, warmup_cycles_count, measure_cycles_count, shuffling_queue_size, min_after_dequeue): with tf.Session() as sess: sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) readout_tensors = tf_tensors(reader, shuffling_queue_size, min_after_dequeue) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord, start=True, sess=sess) result = _time_warmup_and_work(reader, warmup_cycles_count, measure_cycles_count, lambda: sess.run(readout_tensors)) coord.request_stop() coord.join(threads) return result def reader_throughput(dataset_url, field_regex=None, warmup_cycles_count=300, measure_cycles_count=1000, pool_type=WorkerPoolType.THREAD, loaders_count=3, profile_threads=False, read_method=ReadMethod.PYTHON, shuffling_queue_size=500, min_after_dequeue=400, reader_extra_args=None, pyarrow_serialize=False, spawn_new_process=True): """Constructs a Reader instance and uses it to performs throughput measurements. The function will spawn a new process if ``spawn_separate_process`` is set. This is needed to make memory footprint measurements accurate. :param dataset_url: A url of the dataset to be used for measurements. :param field_regex: A list of regular expressions. Only fields that match one of the regex patterns will be used during the benchmark. :param warmup_cycles_count: Number of warmup cycles. During warmup cycles no measurements are being recorded. :param measure_cycles_count: Number of measurements cycles. Only time elapsed during measurements cycles are used in throughput calculations. :param pool_type: :class:`WorkerPoolType` enum value. :param loaders_count: Number of threads (same thread is used for IO and decoding). :param profile_threads: Enables profiling threads. Will print result when thread pool is shut down. :param read_method: An enum :class:`ReadMethod` that defines whether a :class:`petastorm.reader.Reader` will be used. :param shuffling_queue_size: Maximum number of elements in the shuffling queue. :param min_after_dequeue: Minimum number of elements in a shuffling queue before entries can be read from it. :param reader_extra_args: Extra arguments that would be passed to Reader constructor. :param pyarrow_serialize: When True, pyarrow.serialize library will be used for serializing decoded payloads. :param spawn_new_process: This function will respawn itself in a new process if the argument is True. Spawning a new process is needed to get an accurate memory footprint. :return: An instance of ``BenchmarkResult`` namedtuple with the results of the benchmark. The namedtuple has the following fields: `time_mean`, `samples_per_second`, `memory_info` and `cpu` """ if not reader_extra_args: reader_extra_args = dict() if spawn_new_process: args = copy.deepcopy(locals()) args['spawn_new_process'] = False executor = ProcessPoolExecutor(1) future = executor.submit(reader_throughput, **args) return future.result() logger.info('Arguments: %s', locals()) if 'schema_fields' not in reader_extra_args: unischema_fields = match_unischema_fields(get_schema_from_dataset_url(dataset_url), field_regex) reader_extra_args['schema_fields'] = unischema_fields logger.info('Fields used in the benchmark: %s', str(reader_extra_args['schema_fields'])) with make_reader(dataset_url, num_epochs=None, reader_pool_type=str(pool_type), workers_count=loaders_count, pyarrow_serialize=pyarrow_serialize, **reader_extra_args) as reader: if read_method == ReadMethod.PYTHON: result = _time_warmup_and_work(reader, warmup_cycles_count, measure_cycles_count) elif read_method == ReadMethod.TF: result = _time_warmup_and_work_tf(reader, warmup_cycles_count, measure_cycles_count, shuffling_queue_size, min_after_dequeue) else: raise RuntimeError('Unexpected reader_type value: %s', str(read_method)) return result def _create_concurrent_executor(pool_type, decoders_count): if pool_type == WorkerPoolType.PROCESS: decoder_pool_executor = ProcessPoolExecutor(decoders_count) elif pool_type == WorkerPoolType.THREAD: decoder_pool_executor = ThreadPoolExecutor(decoders_count) else: raise ValueError('Unexpected pool type value: %s', pool_type) return decoder_pool_executor def _create_worker_pool(pool_type, workers_count, profiling_enabled, pyarrow_serialize): """Different worker pool implementation (in process none or thread-pool, out of process pool)""" if pool_type == WorkerPoolType.THREAD: worker_pool = ThreadPool(workers_count, profiling_enabled=profiling_enabled) elif pool_type == WorkerPoolType.PROCESS: worker_pool = ProcessPool(workers_count, serializer=PyArrowSerializer() if pyarrow_serialize else PickleSerializer()) elif pool_type == WorkerPoolType.NONE: worker_pool = DummyPool() else: raise ValueError('Supported pool types are thread, process or dummy. Got {}.'.format(pool_type)) return worker_pool def _time_multiple_iterations(iterations, work_func, diags_info_func=None, report_period=1.0): start_time = time.time() last_reported_time = start_time last_reported_count = 0 for current_cycle in six.moves.xrange(iterations): work_func() now = time.time() eps = 1e-9 if now - last_reported_time > report_period: message = '{:2.2f} (mean: {:2.2f}) iterations/sec.' \ .format(float(current_cycle - last_reported_count) / (eps + now - last_reported_time), float(current_cycle) / (eps + now - start_time)) last_reported_count = current_cycle last_reported_time = now if diags_info_func: message += ' diags:{}'.format(str(diags_info_func())) logging.debug(message) return time.time() - start_time