Python six.ensure_str() Examples

The following are 30 code examples of six.ensure_str(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module six, or try the search function.
Example #1
Source File: squad_utils.py    From albert with Apache License 2.0 6 votes vote down vote up
def evaluate_v1(dataset, predictions):
  """Computes SQuAD v1.1 exact-match and F1 scores for a set of predictions.

  Questions with no entry in `predictions` are reported on stderr and count
  as score 0.

  Args:
    dataset: List of articles; each article has "paragraphs", each paragraph
      has "qas", and each qa has an "id" and a list of "answers" with "text".
    predictions: Mapping from question id to the predicted answer.

  Returns:
    Dict with "exact_match" and "f1" as percentages of the question count.
    Both are 0.0 when the dataset contains no questions.
  """
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          message = ("Unanswered question " + six.ensure_str(qa["id"]) +
                     "  will receive score 0.")
          print(message, file=sys.stderr)
          continue
        ground_truths = [x["text"] for x in qa["answers"]]
        prediction = predictions[qa["id"]]
        exact_match += metric_max_over_ground_truths(exact_match_score,
                                                     prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

  # Without any questions the averages below would divide by zero.
  if total == 0:
    return {"exact_match": 0.0, "f1": 0.0}

  exact_match = 100.0 * exact_match / total
  f1 = 100.0 * f1 / total

  return {"exact_match": exact_match, "f1": f1}

####### above are from official SQuAD v1.1 evaluation scripts
####### following are from official SQuAD v2.0 evaluation scripts 
Example #2
Source File: base_model.py    From loaner with Apache License 2.0 6 votes vote down vote up
def to_document(self):
    """Builds the search.Document used to index this model.

    The document id is always supplied: it is the model key's urlsafe form
    coerced to a native str. The fields come from
    self._get_document_fields().

    Returns:
      search.Document for the current model, ready to be indexed.

    Raises:
      DocumentCreationError: if computing the id, gathering the fields, or
          constructing the document raises TypeError or ValueError.
    """
    try:
      document_id = six.ensure_str(self.key.urlsafe())
      document_fields = self._get_document_fields()
      return search.Document(doc_id=document_id, fields=document_fields)
    except (TypeError, ValueError) as error:
      raise DocumentCreationError(error)
Example #3
Source File: predictor_runner_base.py    From lingvo with Apache License 2.0 6 votes vote down vote up
def _PredictContinuously(self):
    """Waits for new checkpoints and runs predictor continuously.

    Polls self._checkpoint forever. Whenever the latest checkpoint's step is
    at least self._prediction_step_interval past the last processed step,
    predictions for it are written under <self._output_dir>/step_<step>.

    Raises:
      ValueError: if a checkpoint is due for prediction but self._output_dir
        is not set.
    """
    prev_step = -1000000
    while True:
      # TODO(jonathanasdf): how to determine when training finished?
      path = tf.train.latest_checkpoint(self._checkpoint)
      if path is None:
        # latest_checkpoint returns None until a checkpoint exists; keep
        # polling instead of crashing in six.ensure_str(None).
        time.sleep(_RETRY_SLEEP_SECONDS)
        continue
      # The step is read from an 8-digit "ckpt-NNNNNNNN" part of the path.
      step_str = re.search(r'ckpt-(\d{8})', six.ensure_str(path)).group(1)
      step = int(step_str)
      if step - prev_step >= self._prediction_step_interval:
        if not self._output_dir:
          raise ValueError(
              'output_dir must be specified for _PredictContinuously.')
        output_dir = os.path.join(self._output_dir, 'step_' + step_str)
        tf.io.gfile.makedirs(output_dir)
        self._PredictOneCheckpoint(path, output_dir)
        prev_step = step
        tf.logging.info('Waiting for next checkpoint...')
      time.sleep(_RETRY_SLEEP_SECONDS)
Example #4
Source File: testrunner.py    From ftw with Apache License 2.0 6 votes vote down vote up
def test_response(self, response_object, regex):
        """
        Checks that the response contains a match for the regex specified in
        the output stage.

        Raises:
            errors.TestError: if called before a response was received.
            AssertionError: if the regex does not match the response.
        """
        if response_object is None:
            raise errors.TestError(
                'Searching before response received',
                {
                    'regex': regex,
                    'response_object': response_object,
                    'function': 'testrunner.TestRunner.test_response'
                })
        response_text = ensure_str(response_object.response)
        # Name the pattern in the failure so a failing test is diagnosable.
        assert regex.search(response_text), (
            'regex {!r} not found in response'.format(regex))
Example #5
Source File: msvc_wrapper_for_nvcc.py    From lingvo with Apache License 2.0 6 votes vote down vote up
def GetNvccOptions(argv):
  """Collect the -nvcc_options values from argv.

  Every occurrence of -nvcc_options (each taking zero or more values) is
  gathered in order, passed through _update_options, and each resulting
  option is prefixed with '--'.

  Args:
    argv: A list of strings, possibly the argv passed to main().

  Returns:
    A (nvcc_options, leftover) tuple:
    1. The list of '--'-prefixed options that can be passed directly to nvcc
       ([] when -nvcc_options does not appear).
    2. The leftover arguments not consumed by the parser.
  """

  parser = ArgumentParser()
  parser.add_argument('-nvcc_options', nargs='*', action='append')

  args, leftover = parser.parse_known_args(argv)

  if args.nvcc_options:
    # Flatten the per-occurrence lists in linear time; sum(lists, []) copies
    # the accumulated list on every step.
    flat_options = [opt for group in args.nvcc_options for opt in group]
    options = _update_options(flat_options)
    return (['--' + six.ensure_str(a) for a in options], leftover)
  return ([], leftover)
Example #6
Source File: utils.py    From git-pw with MIT License 6 votes vote down vote up
def _tabulate(output, headers, fmt):
    fmt = fmt or git_config('pw.format') or 'table'

    if fmt == 'table':
        return tabulate(output, headers, tablefmt='psql')
    elif fmt == 'simple':
        return tabulate(output, headers, tablefmt='simple')
    elif fmt == 'csv':
        result = six.StringIO()
        writer = csv.writer(
            result, quoting=csv.QUOTE_ALL, lineterminator=os.linesep)
        writer.writerow([ensure_str(h) for h in headers])
        for item in output:
            writer.writerow([ensure_str(i) for i in item])
        return result.getvalue()

    print('pw.format must be one of: table, simple, csv')
    sys.exit(1) 
Example #7
Source File: tokenization_test.py    From albert with Apache License 2.0 6 votes vote down vote up
def test_full_tokenizer(self):
    # FullTokenizer built from a small vocab file should split words into
    # the expected wordpieces and map them to their vocab line indices.
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    # delete=False keeps the file on disk after the with-block closes it, so
    # FullTokenizer can open it by name; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      if six.PY2:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
      else:
        contents = "".join([six.ensure_str(x) + "\n" for x in vocab_tokens])
        vocab_writer.write(six.ensure_binary(contents, "utf-8"))

      vocab_file = vocab_writer.name

    try:
      tokenizer = tokenization.FullTokenizer(vocab_file)
    finally:
      # Remove the vocab file even if FullTokenizer raises, so failing runs
      # do not leak temp files.
      os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
Example #8
Source File: __init__.py    From Xpedite with Apache License 2.0 6 votes vote down vote up
def readAtleast(transport, length, timeout):
  """
  Awaits receipt of at least length bytes from underlying transport, till timeout

  Received chunks are kept as bytes and converted to a native string only
  once all of them have arrived. As a result, a multi-byte UTF-8 character
  split across two reads decodes correctly, and the loop counts bytes (not
  decoded characters) against length.

  :param transport: Handle to a stream based transport
  :param length: Number of bytes to read
  :param timeout: Max amount of time to await incoming data
  :returns: The received data as a native string
  :raises Exception: if the transport returns no data (socket closed)

  """
  import six
  LOGGER.debug('Awaiting data %d bytes', length)
  chunks = []
  received = 0
  while received < length:
    block = transport.receive(length - received, timeout)
    if not block:
      raise Exception('socket closed - failed to read datagram')
    block = six.ensure_binary(block)
    chunks.append(block)
    received += len(block)

  data = six.ensure_str(b''.join(chunks))
  logData = data if length < 400 else '{} ...'.format(data[0:45])
  LOGGER.debug('Received data |%s|', logData)
  return data
Example #9
Source File: uarchSpecLoader.py    From Xpedite with Apache License 2.0 6 votes vote down vote up
def downloadFile(url, path):
  """
  Downloads micro architecture specifications from internet

  The response body is written to disk unchanged (binary mode), so the saved
  file does not depend on the platform's default text encoding.

  :param url: url of the website hosting the specifications
  :param path: Path of the file the download is written to
  :returns: True on success; False if an IOError other than an HTTP error
            occurred (for example the file could not be opened)
  :raises Exception: wrapping the HTTPError if the url could not be fetched

  """
  from contextlib import closing
  try:
    request = urllib.request.Request(url)
    # closing() guarantees the HTTP response is released even if read() fails.
    with closing(urllib.request.urlopen(request, context=CONFIG.sslContext)) as connection:
      data = connection.read()
    with open(path, 'wb') as fileHandle:
      fileHandle.write(data)
    return True
  except urllib.error.HTTPError as ex:
    LOGGER.exception('failed to retrieve file "%s" from url - %s', os.path.basename(path), url)
    raise Exception(ex)
  except IOError:
    LOGGER.exception('failed to open file - %s', path)
    return False
Example #10
Source File: client.py    From dagster with Apache License 2.0 6 votes vote down vote up
def retrieve_pod_logs(self, pod_name, namespace):
        '''Fetch the raw logs of pod `pod_name` in `namespace` from Kubernetes.

        Args:
            pod_name (str): Name of the pod whose logs are fetched.
            namespace (str): Namespace the pod belongs to.

        Returns:
            str: The pod's logs, unprocessed, as a native string.
        '''
        check.str_param(pod_name, 'pod_name')
        check.str_param(namespace, 'namespace')

        # _preload_content=False stops the Kubernetes client from deserializing
        # the body. Otherwise JSON-looking logs get parsed into a dict and
        # re-stringified with single quotes, which is no longer valid JSON.
        # https://github.com/kubernetes-client/python/issues/811
        response = self.core_api.read_namespaced_pod_log(
            name=pod_name, namespace=namespace, _preload_content=False
        )
        raw_logs = response.data
        return six.ensure_str(raw_logs)
Example #11
Source File: vrgripper_env_wtl_models.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def model_eval_fn(
      self,
      features,
      labels,
      inference_outputs,
      train_loss,
      train_outputs,
      mode,
      config = None,
      params = None):
    """Log the streaming mean of any train outputs. See also base class.

    Returns:
      A dict mapping 'mean_<key>' to tf.metrics.mean(value) for every entry of
      train_outputs, or None when train_outputs is None.
    """
    if train_outputs is None:
      return None
    eval_outputs = {}
    for output_name, output_tensor in train_outputs.items():
      streaming_mean = tf.metrics.mean(output_tensor)
      eval_outputs['mean_' + six.ensure_str(output_name)] = streaming_mean
    return eval_outputs
Example #12
Source File: vrgripper_env_wtl_models.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def model_eval_fn(
      self,
      features,
      labels,
      inference_outputs,
      train_loss,
      train_outputs,
      mode,
      config = None,
      params = None):
    """Log the streaming mean of any train outputs. See also base class.

    Returns:
      A dict mapping 'mean_<key>' to tf.metrics.mean(value) for every entry of
      train_outputs, or None when train_outputs is None.
    """
    if train_outputs is None:
      return None
    means = {}
    for metric_key, metric_value in train_outputs.items():
      mean_op = tf.metrics.mean(metric_value)
      means['mean_' + six.ensure_str(metric_key)] = mean_op
    return means
Example #13
Source File: visualization.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def plot_labels(labels, max_label=1, predictions=None, name=''):
  """Plots integer labels and optionally predictions as images.

  By default takes the first 3 in the batch. Labels are drawn in the first
  color channel. When predictions are given, they are drawn in the second
  channel as an extra row below the labels.

  Args:
    labels: Batch x 1 size tensor of labels
    max_label:  An integer indicating the largest possible label
    predictions: Batch x max_label size tensor of predictions (range 0-1.0)
    name: string to name tensorflow summary
  """
  if max_label > 1:
    labels = tf.one_hot(
        labels, max_label, on_value=1.0, off_value=0.0, dtype=tf.float32)
  labels_image = tf.reshape(labels[:3], (1, 3, max_label, 1))
  empty_image = tf.zeros((1, 3, max_label, 1))
  image = tf.concat([labels_image, empty_image, empty_image], axis=-1)
  if predictions is not None:
    # The prediction image must be max_label wide to concatenate with
    # empty_image (and then with the label row); a fixed width of 4 only
    # works when max_label == 4.
    pred_image = tf.reshape(predictions[:3], (1, 3, max_label, 1))
    image2 = tf.concat([empty_image, pred_image, empty_image], axis=-1)
    image = tf.concat([image, image2], axis=1)
  tf.summary.image('labels_' + six.ensure_str(name), image, max_outputs=1)
Example #14
Source File: visualization.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def add_heatmap_summary(feature_query, feature_map, name):
  """Plots the dot product of feature_query with feature_map as heatmaps.

  Two image summaries are written: the raw heatmap under `name`, and a
  per-example softmax over all spatial positions under `name + 'softmax'`.

  Args:
    feature_query: Batch x embedding size tensor of goal embeddings
    feature_map: Batch x h x w x embedding size of pregrasp scene embeddings
    name: string to name tensorflow summaries
  Returns:
     Batch x h x w x 1 heatmap
  """
  batch, dim = feature_query.shape
  reshaped_query = tf.reshape(feature_query, (int(batch), 1, 1, int(dim)))
  # `keepdims` replaces the deprecated `keep_dims` argument, which TF2 removed.
  heatmaps = tf.reduce_sum(
      tf.multiply(feature_map, reshaped_query), axis=3, keepdims=True)
  tf.summary.image(name, heatmaps)
  shape = tf.shape(heatmaps)
  # Flatten each example's spatial positions, softmax, then restore the shape.
  softmaxheatmaps = tf.nn.softmax(tf.reshape(heatmaps, (int(batch), -1)))
  tf.summary.image(
      six.ensure_str(name) + 'softmax', tf.reshape(softmaxheatmaps, shape))
  return heatmaps
Example #15
Source File: http.py    From jellyfin-kodi with GNU General Public License v3.0 6 votes vote down vote up
def _authorization(self, data):
        """Set the MediaBrowser authorization header on data['headers'].

        The header names the client, device, device id and version taken
        from config, with fallbacks when unset. If both auth.token and
        auth.user_id are configured, the header also carries the user id and
        the token is sent as X-MediaBrowser-Token.

        Returns the same ``data`` object, updated in place.
        """
        settings = self.config.data
        parts = [
            "Client=%s" % settings.get('app.name', "Jellyfin for Kodi"),
            "Device=%s" % settings.get('app.device_name', 'Unknown Device'),
            "DeviceId=%s" % settings.get('app.device_id', 'Unknown Device id'),
            "Version=%s" % settings.get('app.version', '0.0.0'),
        ]
        auth = "MediaBrowser " + ", ".join(parts)

        headers = data['headers']
        headers.update({'x-emby-authorization': ensure_str(auth, 'utf-8')})

        has_credentials = settings.get('auth.token') and settings.get('auth.user_id')
        if has_credentials:
            auth += ', UserId=%s' % settings.get('auth.user_id')
            headers.update({
                'x-emby-authorization': ensure_str(auth, 'utf-8'),
                'X-MediaBrowser-Token': settings.get('auth.token')})

        return data
Example #16
Source File: api.py    From jellyfin-kodi with GNU General Public License v3.0 6 votes vote down vote up
def get_default_headers(self):
        """Return the default HTTP headers for API requests.

        Values come from self.config.data; a missing key raises KeyError.
        User-Agent falls back to "<app.name>/<app.version>" when
        http.user_agent is empty.
        """
        settings = self.config.data
        auth = "MediaBrowser " + ", ".join([
            "Client=%s" % settings['app.name'],
            "Device=%s" % settings['app.device_name'],
            "DeviceId=%s" % settings['app.device_id'],
            "Version=%s" % settings['app.version'],
        ])
        app_id = "%s/%s" % (settings['app.name'], settings['app.version'])

        return {
            "Accept": "application/json",
            "Content-type": "application/x-www-form-urlencoded; charset=UTF-8",
            "X-Application": app_id,
            "Accept-Charset": "UTF-8,*",
            "Accept-encoding": "gzip",
            "User-Agent": settings['http.user_agent'] or app_id,
            "x-emby-authorization": ensure_str(auth, 'utf-8')
        }
Example #17
Source File: constants.py    From dsrf with Apache License 2.0 6 votes vote down vote up
def get_xsd_files():
  """Builds a map of all the available XSD files.

  Every *.xsd file under get_xsd_directory() is keyed by its location
  relative to that directory. The file's parent directory is used as the
  profile version, and the path above it as the profile name.

  Returns:
    A collections.defaultdict(dict) of the form
    {profile_name: {profile_version: path_to_xsd_file}}.
  """
  xsd_directory = get_xsd_directory()
  schemas = collections.defaultdict(dict)
  for root, unused_dirs, files in os.walk(xsd_directory):
    xsd_names = [f for f in files if six.ensure_str(f).endswith('.xsd')]
    for xsd_name in xsd_names:
      xsd_path = os.path.join(root, xsd_name)
      profile_dir = os.path.dirname(os.path.relpath(xsd_path, xsd_directory))
      profile_name, profile_version = os.path.split(profile_dir)
      schemas[profile_name][profile_version] = xsd_path
  return schemas


# This is a dict of dicts in the form
# {'ProfileName': {
#     '1.1': '/path/to/1.1/schema.xsd'
#     '1.2': '/path/to/1.2/schema.xsd'}} 
Example #18
Source File: compat_test.py    From scalyr-agent-2 with Apache License 2.0 5 votes vote down vote up
def test_environ_pop(self):
        # Popping a variable stored as a native str returns it as text_type.
        stored = six.ensure_str("Test four string")
        os_environ_unicode[EnvironUnicode.TEST_VAR] = stored
        popped = os_environ_unicode.pop(EnvironUnicode.TEST_VAR)
        expected = six.text_type("Test four string")
        self.assertEqual(popped, expected)
Example #19
Source File: conftest.py    From dagster with Apache License 2.0 5 votes vote down vote up
def events_jar():
    """Yield the path to a cached copy of the events assembly jar.

    On the first run, the jar is built with `sbt events/assembly` under
    <repo>/scala_modules and copied into a temp directory. Later runs reuse
    the copy.
    """
    repo_root = six.ensure_str(
        subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip()
    )
    scala_modules = os.path.join(repo_root, 'scala_modules')

    cache_dir = os.path.join(
        get_system_temp_directory(), 'dagster_examples_tests', 'event_pipeline_demo_tests'
    )
    mkdir_p(cache_dir)
    jar_path = os.path.join(cache_dir, 'events.jar')

    if not os.path.exists(jar_path):
        subprocess.check_call(['sbt', 'events/assembly'], cwd=scala_modules)
        built_jar = os.path.join(
            scala_modules, 'events/target/scala-2.11/events-assembly-0.1.0-SNAPSHOT.jar'
        )
        subprocess.check_call(['cp', built_jar, jar_path])
    else:
        print('events jar already exists, skipping')

    yield jar_path
Example #20
Source File: k8s_test.py    From scalyr-agent-2 with Apache License 2.0 5 votes vote down vote up
def create_object_from_dict(d):
    """
    Build an instance of a fresh, empty class whose attributes mirror ``d``.
    Each key becomes an attribute name, set to the corresponding value.
    """
    # 2->TODO 'type' function accepts only str, not unicode, python3 has the opposite situation.
    anonymous_cls = type(six.ensure_str(""), (), {})
    instance = anonymous_cls()
    for attr_name, attr_value in six.iteritems(d):
        setattr(instance, attr_name, attr_value)
    return instance
Example #21
Source File: connection.py    From scalyr-agent-2 with Apache License 2.0 5 votes vote down vote up
def _get(self, request_path):
        # Forget any previous response, then issue a GET for request_path on
        # the underlying connection using the standard headers. The method and
        # path are coerced to native str first.
        self.__http_response = None
        method = six.ensure_str("GET")
        path = six.ensure_str(request_path)
        self.__connection.request(method, path, headers=self._standard_headers)
Example #22
Source File: gke_launch.py    From lingvo with Apache License 2.0 5 votes vote down vote up
def build_docker_image(image, base_image, code_directory, extra_envs):
  """Build a docker image and push it to the location specified by image.

  The generated Dockerfile is written to a temp file and fed to
  `docker build -f-` on stdin. Docker exit statuses are not checked, so the
  push is attempted even if the build failed.

  Args:
    image: String name of tag to use, e.g., 'gcr.io/foo/bar:version'
    base_image: String name of base lingvo image to build from.
    code_directory: Location of directory whose contents will be copied into the
      image.
    extra_envs: A comma-separated list of key=value environment variables to be
      built into the docker. Empty entries (and an empty or None value) add
      nothing.

  Raises:
    OSError: if the docker executable cannot be started.
  """
  import subprocess

  preamble = [
      "FROM %s AS lingvo" % base_image,
  ]
  envs = [
      "ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/home/kubernetes/bin/nvidia/lib64",
      "ENV PATH=${PATH}:/home/kubernetes/bin/nvidia/bin",
  ]
  for env_pair in six.ensure_str(extra_envs or "").split(","):
    env_pair = env_pair.strip()
    # "".split(",") yields [""]; skip it rather than emit a bare "ENV" line,
    # which is not a valid Dockerfile instruction.
    if env_pair:
      envs.append("ENV %s" % env_pair)

  copy_code = ["WORKDIR /tmp/lingvo", "COPY . ."]

  gpu_docker_file = preamble + envs + copy_code
  # mkstemp returns an open descriptor; write through it so it gets closed.
  fd, tmp_dockerfile = tempfile.mkstemp(suffix=".dockerfile")
  with os.fdopen(fd, "w") as f:
    f.write("\n".join(gpu_docker_file))
    print("Writing Dockerfile to", tmp_dockerfile)

  # Argument lists (no shell) keep image names and paths from being
  # reinterpreted by the shell.
  with open(tmp_dockerfile) as dockerfile:
    subprocess.call(
        ["docker", "build", "--tag", image, "--no-cache", "-f-",
         code_directory],
        stdin=dockerfile)
  subprocess.call(["docker", "push", image])
Example #23
Source File: compat_test.py    From scalyr-agent-2 with Apache License 2.0 5 votes vote down vote up
def test_environ_set(self):
        # A variable stored as a native str reads back as text_type.
        stored = six.ensure_str("Test two string")
        os_environ_unicode[EnvironUnicode.TEST_VAR] = stored
        fetched = os_environ_unicode.get(EnvironUnicode.TEST_VAR)
        self.assertEqual(fetched, six.text_type("Test two string"))
Example #24
Source File: git_utils.py    From dagster with Apache License 2.0 5 votes vote down vote up
def git_check_status():
    '''Raise if `git status --porcelain` reports any changes in the repo.'''
    porcelain = subprocess.check_output(['git', 'status', '--porcelain'])
    changes = six.ensure_str(porcelain)
    if changes == '':
        return
    raise Exception(
        'Bailing: Cannot publish with changes present in git repo:\n{changes}'.format(
            changes=changes
        )
    )
Example #25
Source File: git_utils.py    From dagster with Apache License 2.0 5 votes vote down vote up
def git_user():
    '''Return the configured git user.name, stripped, as a native str.'''
    # six.ensure_str already decodes bytes as UTF-8, so no explicit .decode()
    # is needed; this matches the sibling git_repo_root helper.
    return six.ensure_str(
        subprocess.check_output(['git', 'config', '--get', 'user.name'])
    ).strip()
Example #26
Source File: git_utils.py    From dagster with Apache License 2.0 5 votes vote down vote up
def git_repo_root():
    '''Return the top-level directory of the current git repo as a native str.'''
    toplevel = subprocess.check_output(['git', 'rev-parse', '--show-toplevel'])
    return six.ensure_str(toplevel.strip())
Example #27
Source File: check_library_docs.py    From dagster with Apache License 2.0 5 votes vote down vote up
def git_repo_root():
    '''Return the top-level directory of the current git repo as a native str.'''
    raw_root = subprocess.check_output(['git', 'rev-parse', '--show-toplevel'])
    trimmed = raw_root.strip()
    return six.ensure_str(trimmed)
Example #28
Source File: test_resources.py    From dagster with Apache License 2.0 5 votes vote down vote up
def generate_ssh_key():
    '''Create a new 2048-bit RSA private key and return it PEM-encoded.

    The key is serialized unencrypted in TraditionalOpenSSL format and
    returned as a native str.
    '''
    private_key = rsa.generate_private_key(
        backend=default_backend(), public_exponent=65537, key_size=2048
    )
    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return six.ensure_str(pem_bytes)
Example #29
Source File: connection.py    From scalyr-agent-2 with Apache License 2.0 5 votes vote down vote up
def _post(self, request_path, body):
        # Forget any previous response, then POST `body` to request_path on
        # the underlying connection using the standard headers. The method and
        # path are coerced to native str first.
        self.__http_response = None
        method = six.ensure_str("POST")
        path = six.ensure_str(request_path)
        self.__connection.request(
            method, path, body=body, headers=self._standard_headers
        )
Example #30
Source File: test_resources.py    From dagster with Apache License 2.0 5 votes vote down vote up
def test_ssh_connection_with_key_string(ssh_mock):
    # An SSHResource configured with an in-memory PEM key string should
    # connect using that key as `pkey` and no key file.
    pem_key = generate_ssh_key()

    resource = SSHResource(
        remote_host='remote_host',
        remote_port=12345,
        username='username',
        password=None,
        timeout=10,
        key_string=six.ensure_str(pem_key),
        keepalive_interval=30,
        compress=True,
        no_host_key_check=False,
        allow_host_key_change=False,
        logger=logging.root.getChild('test_resources'),
    )

    with resource.get_connection():
        connect = ssh_mock.return_value.connect
        connect.assert_called_once_with(
            hostname='remote_host',
            username='username',
            key_filename=None,
            pkey=key_from_str(pem_key),
            timeout=10,
            compress=True,
            port=12345,
            sock=None,
        )