import numpy as np import time import datetime def check_value(inds, val): # Check to see if an array is a single element equaling a particular value # Good for pre-processing inputs in a function if(np.array(inds).size == 1): if(inds == val): return True return False def flatten_nd_array(pts_nd, axis=1): # Flatten an nd array into a 2d array with a certain axis # INPUTS # pts_nd N0xN1x...xNd array # axis integer # OUTPUTS # pts_flt prod(N \ N_axis) x N_axis array NDIM = pts_nd.ndim SHP = np.array(pts_nd.shape) nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis))) # non axis indices NPTS = np.prod(SHP[nax]) axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0) pts_flt = pts_nd.transpose((axorder)) pts_flt = pts_flt.reshape(NPTS, SHP[axis]) return pts_flt def unflatten_2d_array(pts_flt, pts_nd, axis=1, squeeze=False): # Unflatten a 2d array with a certain axis # INPUTS # pts_flt prod(N \ N_axis) x M array # pts_nd N0xN1x...xNd array # axis integer # squeeze bool if true, M=1, squeeze it out # OUTPUTS # pts_out N0xN1x...xNd array NDIM = pts_nd.ndim SHP = np.array(pts_nd.shape) nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis))) # non axis indices if(squeeze): axorder = nax axorder_rev = np.argsort(axorder) M = pts_flt.shape[1] NEW_SHP = SHP[nax].tolist() pts_out = pts_flt.reshape(NEW_SHP) pts_out = pts_out.transpose(axorder_rev) else: axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0) axorder_rev = np.argsort(axorder) M = pts_flt.shape[1] NEW_SHP = SHP[nax].tolist() NEW_SHP.append(M) pts_out = pts_flt.reshape(NEW_SHP) pts_out = pts_out.transpose(axorder_rev) return pts_out def na(): return np.newaxis class Timer(): def __init__(self): self.cur_t = time.time() def tic(self): self.cur_t = time.time() def toc(self): return time.time() - self.cur_t def tocStr(self, t=-1): if(t == -1): return str(datetime.timedelta(seconds=np.round(time.time() - self.cur_t, 3)))[:-4] else: return str(datetime.timedelta(seconds=np.round(t, 3)))[:-4]