# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracts non-empty patches of extracted stafflines.

Extracts vertical slices of the image where glyphs are expected
(see `staffline_extractor.py`), and takes horizontal windows of the slice which
will be clustered. Some patches will have a glyph roughly in their center, and
the corresponding cluster centroids will be labeled as such.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import apache_beam as beam
from apache_beam import metrics
from moonlight.staves import staffline_extractor
from moonlight.util import more_iter_tools
import numpy as np
from six.moves import filter
import tensorflow as tf


def _filter_patch(patch, min_num_dark_pixels=10):
  """Tells whether a named patch contains enough dark pixels to keep.

  Args:
    patch: A `(name, pixels)` pair. Only `pixels` is inspected; the name is
      ignored.
    min_num_dark_pixels: Minimum count of pixels with a value strictly below
      0.5 for the patch to be kept.

  Returns:
    A true value if at least `min_num_dark_pixels` pixels are below 0.5.
  """
  _, pixels = patch
  num_dark_pixels = (np.asarray(pixels) < 0.5).sum()
  return num_dark_pixels >= min_num_dark_pixels


class StafflinePatchesDoFn(beam.DoFn):
  """Runs the staffline patches graph.

  For each input PNG path, extracts `(name, patch)` pairs with a
  `StafflinePatchExtractor`, drops patches rejected by `_filter_patch`,
  optionally sub-samples them, and emits one `tf.train.Example` per remaining
  patch.

  Counters reported:
    total_pages: Every page passed to `process`.
    failed_pages: Pages skipped because extraction raised.
    successful_pages: Pages whose patches were all emitted.
    empty_pages: Pages that produced no patches after filtering.
    total_patches: Patches that passed the filter, before sub-sampling.
    emitted_patches: Patches actually emitted, after sub-sampling.
  """

  def __init__(self, patch_height, patch_width, num_stafflines, timeout_ms,
               max_patches_per_page):
    """Stores the extraction settings and registers the metric counters.

    Args:
      patch_height: Forwarded to `StafflinePatchExtractor` as `patch_height`.
      patch_width: Forwarded to `StafflinePatchExtractor` as `patch_width`.
      num_stafflines: Stored on the instance.
      timeout_ms: Used as `timeout_in_ms` of the `tf.RunOptions` given to the
        extractor.
      max_patches_per_page: If positive, at most this many patches are sampled
        per page. Otherwise all patches are kept.
    """
    self.patch_height = patch_height
    self.patch_width = patch_width
    # NOTE(review): num_stafflines is never read in this class, and it is not
    # passed to the extractor in start_bundle. Confirm whether it should be.
    self.num_stafflines = num_stafflines
    self.timeout_ms = timeout_ms
    self.max_patches_per_page = max_patches_per_page
    self.total_pages_counter = metrics.Metrics.counter(self.__class__,
                                                       'total_pages')
    self.failed_pages_counter = metrics.Metrics.counter(self.__class__,
                                                        'failed_pages')
    self.successful_pages_counter = metrics.Metrics.counter(
        self.__class__, 'successful_pages')
    self.empty_pages_counter = metrics.Metrics.counter(self.__class__,
                                                       'empty_pages')
    self.total_patches_counter = metrics.Metrics.counter(
        self.__class__, 'total_patches')
    self.emitted_patches_counter = metrics.Metrics.counter(
        self.__class__, 'emitted_patches')

  def start_bundle(self):
    """Builds the extractor and opens a session on its graph."""
    self.extractor = staffline_extractor.StafflinePatchExtractor(
        patch_height=self.patch_height,
        patch_width=self.patch_width,
        run_options=tf.RunOptions(timeout_in_ms=self.timeout_ms))
    self.session = tf.Session(graph=self.extractor.graph)

  def process(self, png_path):
    """Extracts, filters and serializes the patches of one page.

    Args:
      png_path: Path of the page image, passed to the extractor's
        `page_patch_iterator`.

    Yields:
      A `tf.train.Example` per kept patch with the features `name` (UTF-8
      bytes), `features` (the flattened patch values), `height` and `width`.
      Nothing is yielded if extraction fails. The failure is logged and
      counted in `failed_pages`.
    """
    self.total_pages_counter.inc()

    # A one-element list lets the nested generator update the tally without
    # `nonlocal`, which is unavailable on Python 2.
    num_candidates = [0]

    def _tally(candidates):
      for candidate in candidates:
        num_candidates[0] += 1
        yield candidate

    try:
      # The patches are consumed here, inside the default-session scope and
      # the try block. If the iterator is lazy, extraction only runs while it
      # is being consumed, so this is where it needs the session and where
      # failures must be caught to skip the page.
      with self.session.as_default():
        patches_iter = self.extractor.page_patch_iterator(png_path)
        patches_iter = _tally(filter(_filter_patch, patches_iter))

        if 0 < self.max_patches_per_page:
          # Subsample patches.
          patches = more_iter_tools.iter_sample(patches_iter,
                                                self.max_patches_per_page)
        else:
          patches = list(patches_iter)
    # pylint: disable=broad-except
    except Exception:
      logging.exception('Skipping failed music score (%s)', png_path)
      self.failed_pages_counter.inc()
      return

    if not patches:
      self.empty_pages_counter.inc()
    # Count the candidates seen before sub-sampling, so this counter can differ
    # from emitted_patches.
    self.total_patches_counter.inc(num_candidates[0])

    # Serialize each patch as an Example.
    for patch_name, patch in patches:
      example = tf.train.Example()
      example.features.feature['name'].bytes_list.value.append(
          patch_name.encode('utf-8'))
      example.features.feature['features'].float_list.value.extend(
          patch.ravel())
      example.features.feature['height'].int64_list.value.append(patch.shape[0])
      example.features.feature['width'].int64_list.value.append(patch.shape[1])
      yield example

    self.successful_pages_counter.inc()
    # Patches are sub-sampled by this point.
    self.emitted_patches_counter.inc(len(patches))

  def finish_bundle(self):
    """Closes the session and drops the per-bundle extractor and session."""
    self.session.close()
    del self.extractor
    del self.session