# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from lib import data
import numpy as np
import scipy.spatial


def closest_line(query_lines, metric='cosine', n_angles=10000):
    """Find the closest reference line for each line in query_lines.

    Reference lines are rendered with data.draw_line at n_angles evenly
    spaced angles in [0, 2*pi); each query image is matched against all of
    them.

    Args:
        - query_lines: Array of lines to compute closest matches for, shape
          (n_lines, height, width, 1).
        - metric: String to pass to scipy.spatial.distance.cdist to choose
          which distance metric to use.
        - n_angles: Number of reference lines (angles) to compare against.

    Returns:
        - min_dist, angles: Arrays of shape (n_lines,) giving the distance
          to the nearest ``true'' line and the angle of that line.

    Raises:
        - ValueError: If n_angles is smaller than 1.
    """
    if n_angles < 1:
        raise ValueError('n_angles must be at least 1, got %r' % (n_angles,))
    h, w = query_lines.shape[1:-1]
    # Evenly spaced angles covering the circle; the endpoint 2*pi is left
    # out because it would duplicate the line drawn at angle 0.
    angles = np.linspace(0, 2*np.pi - 2*np.pi/n_angles, n_angles)
    all_lines = np.array([data.draw_line(angle, h, w) for angle in angles])
    # Produce vectorized versions of both for use with scipy.spatial
    flat_query = query_lines.reshape(query_lines.shape[0], -1)
    flat_all = all_lines.reshape(all_lines.shape[0], -1)
    # Compute pairwise distance matrix of query lines with all valid lines
    distances = scipy.spatial.distance.cdist(flat_query, flat_all, metric)
    min_dist_idx = np.argmin(distances, axis=-1)
    min_dist = distances[np.arange(distances.shape[0]), min_dist_idx]
    # Fancy indexing picks the matching angle for every query in one step
    return min_dist, angles[min_dist_idx]


def smoothness_score(angles):
    """Compute the smoothness score of line interpolations from their angles.

    For each interpolation, every step's absolute angle change is divided by
    the absolute angle change between the first and last line. The score is
    the largest such normalized step minus 1/(n_lines_per_interpolation - 1),
    so it is 0 when all steps are equal in size.

    Args:
        - angles: Array of shape (n_interpolations,
          n_lines_per_interpolation) giving the angle of each line in each
          interpolation. A 1-D array is treated as a single interpolation.

    Returns:
        - smoothness_scores: Array of shape (n_interpolations,) giving the
          smoothness score of each interpolation. The score is NaN when it
          cannot be computed: the first and last angles differ by less than
          1e-4, or the interpolation has fewer than two lines.
    """
    angles = np.atleast_2d(angles)
    # With fewer than two lines there are no steps to measure
    if angles.shape[1] < 2:
        return np.full(angles.shape[0], np.nan)
    # Remove discontinuities larger than np.pi
    angles = np.unwrap(angles, axis=-1)
    diffs = np.abs(np.diff(angles, axis=-1))
    # Compute the angle difference from the first and last point
    total_diff = np.abs(angles[:, :1] - angles[:, -1:])
    # When total_diff is zero, there's no way to compute this score
    zero_diff = (total_diff < 1e-4).flatten()
    # Rows with a zero total_diff divide by zero here; they are overwritten
    # with NaN below, so the floating point warnings are only noise.
    with np.errstate(divide='ignore', invalid='ignore'):
        normalized_diffs = diffs/total_diff
    deviation = np.max(normalized_diffs, axis=-1) - 1./(angles.shape[1] - 1)
    # Set score to NaN when we aren't able to compute it
    deviation[zero_diff] = np.nan
    return deviation


def line_eval(interpolated_lines):
    """Compute mean nearest-line distance and mean smoothness score.

    Given a group of line interpolations, every image is matched to its
    closest reference line with closest_line, and the matched angles are
    scored per interpolation with smoothness_score. This version of this
    metric is meant for vertical lines only.

    Args:
        - interpolated_lines: Collection of line interpolation images, shape
          (n_interpolations, n_lines_per_interpolation, height, width, 1).

    Returns:
        - mean_distance: Average distance to closest ``real'' line, as
          np.float32.
        - mean_smoothness: Average interpolation smoothness over the
          interpolations whose score could be computed, as np.float32; NaN
          if no score could be computed.
    """
    original_shape = interpolated_lines.shape
    # Merge the interpolation and line axes so every image is one query
    min_dist, angles = closest_line(
        interpolated_lines.reshape((-1,) + original_shape[2:]))
    mean_distance = np.mean(min_dist)
    smoothness_scores = smoothness_score(
        angles.reshape(original_shape[0], original_shape[1]))
    valid_scores = smoothness_scores[~np.isnan(smoothness_scores)]
    if valid_scores.size:
        # Average over the computable scores only, so the divisor is the
        # number of valid scores rather than the number of interpolations.
        mean_smoothness = np.mean(valid_scores)
    else:
        # If all scores were NaN, set the mean score to NaN
        mean_smoothness = np.nan
    return np.float32(mean_distance), np.float32(mean_smoothness)