Python numpy.newaxis Examples

The following code examples show how to use numpy.newaxis. They are extracted from open-source projects; you can go to the original project or source file by following the links above each example.

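For orientation, numpy.newaxis is not a callable but an alias for None; used inside an index expression it inserts a new axis of length 1. A minimal, self-contained sketch:

import numpy as np

a = np.arange(3)             # shape (3,)
row = a[np.newaxis, :]       # shape (1, 3) -- new leading axis
col = a[:, np.newaxis]       # shape (3, 1) -- new trailing axis
print(row.shape, col.shape)  # (1, 3) (3, 1)
# np.newaxis is literally None, so a[:, None] is equivalent to a[:, np.newaxis].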

Example 1
Project: DDPAE-video-prediction   Author: jthsieh   File: video_transforms.py    License: MIT License
def resize(video, size, interpolation):
  if interpolation == 'bilinear':
    inter = cv2.INTER_LINEAR
  elif interpolation == 'nearest':
    inter = cv2.INTER_NEAREST
  else:
    raise NotImplementedError

  shape = video.shape[:-3]
  video = video.reshape((-1, *video.shape[-3:]))
  resized_video = np.zeros((video.shape[0], size[1], size[0], video.shape[-1]))
  for i in range(video.shape[0]):
    img = cv2.resize(video[i], size, interpolation=inter)
    if len(img.shape) == 2:
      img = img[:, :, np.newaxis]
    resized_video[i] = img
  return resized_video.reshape((*shape, size[1], size[0], video.shape[-1])) 
Example 2
Project: EDeN   Author: fabriziocosta   File: __init__.py    License: MIT License
def plot_confusion_matrix(y_true, y_pred, size=None, normalize=False):
    """plot_confusion_matrix."""
    cm = confusion_matrix(y_true, y_pred)
    fmt = "%d"
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = "%.2f"
    xticklabels = list(sorted(set(y_pred)))
    yticklabels = list(sorted(set(y_true)))
    if size is not None:
        plt.figure(figsize=(size, size))
    heatmap(cm, xlabel='Predicted label', ylabel='True label',
            xticklabels=xticklabels, yticklabels=yticklabels,
            cmap=plt.cm.Blues, fmt=fmt)
    if normalize:
        plt.title("Confusion matrix (norm.)")
    else:
        plt.title("Confusion matrix")
    plt.gca().invert_yaxis() 
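The cm.sum(axis=1)[:, np.newaxis] step above is the usual row-normalization idiom: the (N, 1) column of row sums broadcasts against the (N, N) matrix. A stand-alone sketch with made-up values:

import numpy as np

cm = np.array([[2., 1.],
               [1., 3.]])
row_sums = cm.sum(axis=1)                  # shape (2,)
normalized = cm / row_sums[:, np.newaxis]  # (2, 1) column broadcasts over the columns of cm
print(normalized)                          # each row now sums to 1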
Example 3
Project: mmdetection   Author: open-mmlab   File: test_masks.py    License: Apache License 2.0
def test_bitmap_mask_resize():
    # resize with empty bitmap masks
    raw_masks = dummy_raw_bitmap_masks((0, 28, 28))
    bitmap_masks = BitmapMasks(raw_masks, 28, 28)
    resized_masks = bitmap_masks.resize((56, 72))
    assert len(resized_masks) == 0
    assert resized_masks.height == 56
    assert resized_masks.width == 72

    # resize with bitmap masks containing 1 instance
    raw_masks = np.diag(np.ones(4, dtype=np.uint8))[np.newaxis, ...]
    bitmap_masks = BitmapMasks(raw_masks, 4, 4)
    resized_masks = bitmap_masks.resize((8, 8))
    assert len(resized_masks) == 1
    assert resized_masks.height == 8
    assert resized_masks.width == 8
    truth = np.array([[[1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0],
                       [0, 0, 1, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0, 0],
                       [0, 0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 0, 1, 1, 0, 0],
                       [0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 1, 1]]])
    assert (resized_masks.masks == truth).all() 
Example 4
Project: cat-bbs   Author: aleju   File: common.py    License: MIT License
def draw_heatmap(img, heatmap, alpha=0.5):
    """Draw a heatmap overlay over an image."""
    assert len(heatmap.shape) == 2 or \
        (len(heatmap.shape) == 3 and heatmap.shape[2] == 1)
    assert img.dtype in [np.uint8, np.int32, np.int64]
    assert heatmap.dtype in [np.float32, np.float64]

    if img.shape[0:2] != heatmap.shape[0:2]:
        heatmap_rs = np.clip(heatmap * 255, 0, 255).astype(np.uint8)
        heatmap_rs = ia.imresize_single_image(
            heatmap_rs[..., np.newaxis],
            img.shape[0:2],
            interpolation="nearest"
        )
        heatmap = np.squeeze(heatmap_rs) / 255.0

    cmap = plt.get_cmap('jet')
    heatmap_cmapped = cmap(heatmap)
    heatmap_cmapped = np.delete(heatmap_cmapped, 3, 2)
    heatmap_cmapped = heatmap_cmapped * 255
    mix = (1-alpha) * img + alpha * heatmap_cmapped
    mix = np.clip(mix, 0, 255).astype(np.uint8)
    return mix 
Example 5
def _project_im_rois(im_rois, scales):
    """Project image RoIs into the image pyramid built by _get_image_blob.
    Arguments:
        im_rois (ndarray): R x 4 matrix of RoIs in original image coordinates
        scales (list): scale factors as returned by _get_image_blob
    Returns:
        rois (ndarray): R x 4 matrix of projected RoI coordinates
        levels (list): image pyramid levels used by each projected RoI
    """
    im_rois = im_rois.astype(np.float, copy=False)

    if len(scales) > 1:
        widths = im_rois[:, 2] - im_rois[:, 0] + 1
        heights = im_rois[:, 3] - im_rois[:, 1] + 1
        areas = widths * heights
        scaled_areas = areas[:, np.newaxis] * (scales[np.newaxis, :] ** 2)
        diff_areas = np.abs(scaled_areas - 224 * 224)
        levels = diff_areas.argmin(axis=1)[:, np.newaxis]
    else:
        levels = np.zeros((im_rois.shape[0], 1), dtype=np.int)

    rois = im_rois * scales[levels]

    return rois, levels 
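The scaled_areas = areas[:, np.newaxis] * (scales[np.newaxis, :] ** 2) line above is the pairwise-combination broadcasting pattern: an (R, 1) column times a (1, S) row yields an (R, S) matrix of every area/scale combination. A small sketch with made-up numbers:

import numpy as np

areas = np.array([100., 400., 900.])   # R = 3 RoI areas
scales = np.array([0.5, 1.0, 2.0])     # S = 3 pyramid scale factors
scaled = areas[:, np.newaxis] * scales[np.newaxis, :] ** 2   # shape (3, 3)
# For each RoI, pick the scale whose scaled area is closest to 224 * 224.
levels = np.abs(scaled - 224 * 224).argmin(axis=1)
print(scaled.shape, levels)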
Example 6
Project: xrft   Author: xgcm   File: xrft.py    License: MIT License
def _radial_wvnum(k, l, N, nfactor):
    """ Creates a radial wavenumber based on two horizontal wavenumbers
    along with the appropriate index map
    """

    # compute target wavenumbers
    k = k.values
    l = l.values
    K = np.sqrt(k[np.newaxis,:]**2 + l[:,np.newaxis]**2)
    nbins = int(N/nfactor)
    if k.max() > l.max():
        ki = np.linspace(0., l.max(), nbins)
    else:
        ki = np.linspace(0., k.max(), nbins)

    # compute bin index
    kidx = np.digitize(np.ravel(K), ki)
    # compute number of points for each wavenumber
    area = np.bincount(kidx)
    # compute the average radial wavenumber for each bin
    kr = (np.bincount(kidx, weights=K.ravel())
          / np.ma.masked_where(area==0, area))

    return ki, kr[1:-1] 
Example 7
Project: xrft   Author: xgcm   File: test_xrft.py    License: MIT License
def test_dft_2d(self):
        """Test the discrete Fourier transform on 2D data"""
        N = 16
        da = xr.DataArray(np.random.rand(N,N), dims=['x','y'],
                        coords={'x':range(N),'y':range(N)}
                         )
        ft = xrft.dft(da, shift=False)
        npt.assert_almost_equal(ft.values, np.fft.fftn(da.values))

        ft = xrft.dft(da, shift=False, window=True, detrend='constant')
        dim = da.dims
        window = np.hanning(N) * np.hanning(N)[:, np.newaxis]
        da_prime = (da - da.mean(dim=dim)).values
        npt.assert_almost_equal(ft.values, np.fft.fftn(da_prime*window))

        da = xr.DataArray(np.random.rand(N,N), dims=['x','y'],
                         coords={'x':range(N,0,-1),'y':range(N,0,-1)}
                         )
        assert (xrft.power_spectrum(da, shift=False,
                                   density=True) >= 0.).all() 
Example 8
Project: xrft   Author: xgcm   File: test_xrft.py    License: MIT License
def test_cross_phase_2d(self, dask):
        Ny, Nx = (32, 16)
        x = np.linspace(0, 1, num=Nx, endpoint=False)
        y = np.ones(Ny)
        f = 6
        phase_offset = np.pi/2
        signal1 = np.cos(2*np.pi*f*x)  # frequency = 1/(2*pi)
        signal2 = np.cos(2*np.pi*f*x - phase_offset)
        da1 = xr.DataArray(data=signal1*y[:,np.newaxis], name='a',
                          dims=['y','x'], coords={'y':y, 'x':x})
        da2 = xr.DataArray(data=signal2*y[:,np.newaxis], name='b',
                          dims=['y','x'], coords={'y':y, 'x':x})
        with pytest.raises(ValueError):
            xrft.cross_phase(da1, da2, dim=['y','x'])

        if dask:
            da1 = da1.chunk({'x': 16})
            da2 = da2.chunk({'x': 16})
        cp = xrft.cross_phase(da1, da2, dim=['x'])
        actual_phase_offset = cp.sel(freq_x=f).values
        npt.assert_almost_equal(actual_phase_offset, phase_offset) 
Example 9
Project: DDPAE-video-prediction   Author: jthsieh   File: video_transforms.py    License: MIT License
def __call__(self, video):
    """
    Args:
        video (numpy array): Input video, shape (... x H x W x C), dtype uint8.
    Returns:
        numpy array: Color-jittered video with the same shape as the input.
    """
    transforms = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
    reshaped_video = video.reshape((-1, *video.shape[-3:]))
    n_channels = video.shape[-1]
    for i in range(reshaped_video.shape[0]):
      img = reshaped_video[i]
      if n_channels == 1:
        img = img.squeeze(axis=2)
      img = Image.fromarray(img)
      for t in transforms:
        img = t(img)
      img = np.array(img)
      if n_channels == 1:
        img = img[..., np.newaxis]
      reshaped_video[i] = img
    video = reshaped_video.reshape(video.shape)
    return video 
Example 10
Project: DDPAE-video-prediction   Author: jthsieh   File: moving_mnist.py    License: MIT License
def generate_moving_mnist(self, num_digits=2):
    '''
    Get random trajectories for the digits and generate a video.
    '''
    data = np.zeros((self.n_frames_total, self.image_size_, self.image_size_), dtype=np.float32)
    for n in range(num_digits):
      # Trajectory
      start_y, start_x = self.get_random_trajectory(self.n_frames_total)
      ind = random.randint(0, self.mnist.shape[0] - 1)
      digit_image = self.mnist[ind]
      for i in range(self.n_frames_total):
        top    = start_y[i]
        left   = start_x[i]
        bottom = top + self.digit_size_
        right  = left + self.digit_size_
        # Draw digit
        data[i, top:bottom, left:right] = np.maximum(data[i, top:bottom, left:right], digit_image)

    data = data[..., np.newaxis]
    return data 
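data[..., np.newaxis] appends a trailing channel axis while leaving all existing axes untouched; the ellipsis stands for every leading dimension. A quick sketch:

import numpy as np

frames = np.zeros((20, 64, 64), dtype=np.float32)  # (T, H, W)
frames = frames[..., np.newaxis]                   # (T, H, W, 1) -- channel axis appended
print(frames.shape)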
Example 11
Project: gated-graph-transformer-network   Author: hexahedria   File: convert_story.py    License: MIT License
def convert(story):
    # import pdb; pdb.set_trace()
    sentence_arr, graphs, query_arr, answer_arr = story
    node_id_w = graphs[2].shape[2]
    edge_type_w = graphs[3].shape[3]

    all_node_strengths = [np.zeros([1])]
    all_node_ids = [np.zeros([1,node_id_w])]
    for num_new_nodes, new_node_strengths, new_node_ids, _ in zip(*graphs):
        last_strengths = all_node_strengths[-1]
        last_ids = all_node_ids[-1]

        cur_strengths = np.concatenate([last_strengths, new_node_strengths], 0)
        cur_ids = np.concatenate([last_ids, new_node_ids], 0)

        all_node_strengths.append(cur_strengths)
        all_node_ids.append(cur_ids)

    all_edges = graphs[3]
    full_n_nodes = all_edges.shape[1]
    all_node_strengths = np.stack([np.pad(x, ((0, full_n_nodes-x.shape[0])), 'constant') for x in all_node_strengths[1:]])
    all_node_ids = np.stack([np.pad(x, ((0, full_n_nodes-x.shape[0]), (0, 0)), 'constant') for x in all_node_ids[1:]])
    all_node_states = np.zeros([len(all_node_strengths), full_n_nodes,0])

    return tuple(x[np.newaxis,...] for x in (all_node_strengths, all_node_ids, all_node_states, all_edges)) 
Example 12
Project: models   Author: kipoi   File: model.py    License: MIT License
def predict_on_batch(self, inputs):
            if inputs.shape == (2,):
                inputs = inputs[np.newaxis, :]
            # Encode
            max_len = len(max(inputs, key=len))
            one_hot_ref =  self.encode(inputs[:,0])
            one_hot_alt = self.encode(inputs[:,1])
            # Construct dummy library indicator
            indicator = np.zeros((inputs.shape[0],2))
            indicator[:,1] = 1
            # Compute fold change for all three frames
            fc_changes = []
            for shift in range(3):
                if shift > 0:
                    shifter = np.zeros((one_hot_ref.shape[0],1,4))
                    one_hot_ref = np.concatenate([one_hot_ref, shifter], axis=1)
                    one_hot_alt = np.concatenate([one_hot_alt, shifter], axis=1)
                pred_ref = self.model.predict_on_batch([one_hot_ref, indicator]).reshape(-1)
                pred_variant = self.model.predict_on_batch([one_hot_alt, indicator]).reshape(-1)
                fc_changes.append(np.log2(pred_variant/pred_ref))
            # Return
            return {"mrl_fold_change":fc_changes[0], 
                    "shift_1":fc_changes[1],
                    "shift_2":fc_changes[2]} 
Example 13
Project: DOTA_models   Author: ringringyi   File: swiftshader_renderer.py    License: Apache License 2.0
def render(self, take_screenshot=False, output_type=0):
    # self.render_timer.tic()
    self._actual_render()
    # self.render_timer.toc(log_at=1000, log_str='render timer', type='time')

    np_rgb_img = None
    np_d_img = None
    c = 1000.
    if take_screenshot:
      if self.modality == 'rgb':
        screenshot_rgba = np.zeros((self.height, self.width, 4), dtype=np.uint8)
        glReadPixels(0, 0, self.width, self.height, GL_RGBA, GL_UNSIGNED_BYTE, screenshot_rgba)
        np_rgb_img = screenshot_rgba[::-1,:,:3];

      if self.modality == 'depth': 
        screenshot_d = np.zeros((self.height, self.width, 4), dtype=np.uint8)
        glReadPixels(0, 0, self.width, self.height, GL_RGBA, GL_UNSIGNED_BYTE, screenshot_d)
        np_d_img = screenshot_d[::-1,:,:3];
        np_d_img = np_d_img[:,:,2]*(255.*255./c) + np_d_img[:,:,1]*(255./c) + np_d_img[:,:,0]*(1./c)
        np_d_img = np_d_img.astype(np.float32)
        np_d_img[np_d_img == 0] = np.NaN
        np_d_img = np_d_img[:,:,np.newaxis]

    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
    return np_rgb_img, np_d_img 
Example 14
Project: Caffe-Python-Data-Layer   Author: liuxianming   File: SampleIO.py    License: BSD 2-Clause "Simplified" License
def substract_mean(img, image_mean):
    """Substract image mean from data sample

    image_mean is a numpy array,
    either 1 * 3 or of the same size as input image
    """
    if image_mean.ndim == 1:
        image_mean = image_mean[:, np.newaxis, np.newaxis]
    img -= image_mean
    return img 
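Here image_mean[:, np.newaxis, np.newaxis] turns a (3,) per-channel mean into shape (3, 1, 1) so it broadcasts across the spatial dimensions of a channel-first image. A minimal sketch (the CHW layout is an assumption inferred from the indexing above):

import numpy as np

img = np.random.rand(3, 32, 32)                # assumed C x H x W image
image_mean = np.array([0.4, 0.5, 0.6])         # per-channel mean, shape (3,)
img -= image_mean[:, np.newaxis, np.newaxis]   # (3, 1, 1) broadcasts over H and W
print(img.shape)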
Example 15
Project: svviz   Author: svviz   File: kde.py    License: MIT License
def evaluate(self, points):
        points = atleast_2d(points)

        d, m = points.shape
        if d != self.d:
            if d == 1 and m == self.d:
                # points was passed in as a row vector
                points = reshape(points, (self.d, 1))
                m = 1
            else:
                msg = "points have dimension %s, dataset has dimension %s" % (d,
                    self.d)
                raise ValueError(msg)

        result = zeros((m,), dtype=np.float)

        if m >= self.n:
            # there are more points than data, so loop over data
            for i in range(self.n):
                diff = self.dataset[:, i, newaxis] - points
                tdiff = dot(self.inv_cov, diff)
                energy = sum(diff*tdiff,axis=0) / 2.0
                result = result + exp(-energy)
        else:
            # loop over points
            for i in range(m):
                diff = self.dataset - points[:, i, newaxis]
                tdiff = dot(self.inv_cov, diff)
                energy = sum(diff * tdiff, axis=0) / 2.0
                result[i] = sum(exp(-energy), axis=0)

        result = result / self._norm_factor

        return result 
Example 16
Project: fenics-topopt   Author: zfergus   File: filter.py    License: MIT License
def filter_variables(self, x, xPhys, ft):
        if ft == 0:
            xPhys[:] = x
        elif ft == 1:
            xPhys[:] = np.asarray(self.H * x[np.newaxis].T / self.Hs)[:, 0] 
Example 17
Project: fenics-topopt   Author: zfergus   File: filter.py    License: MIT License
def filter_compliance_sensitivities(self, xPhys, dc, ft):
        if ft == 0:
            dc[:] = (np.asarray((self.H * (xPhys * dc))[np.newaxis].T /
                self.Hs)[:, 0] / np.maximum(0.001, xPhys))
        elif ft == 1:
            dc[:] = np.asarray(self.H * (dc[np.newaxis].T / self.Hs))[:, 0] 
Example 18
Project: fenics-topopt   Author: zfergus   File: filter.py    License: MIT License
def filter_volume_sensitivities(self, _xPhys, dv, ft):
        if ft == 0:
            pass
        elif ft == 1:
            dv[:] = np.asarray(self.H * (dv[np.newaxis].T / self.Hs))[:, 0] 
Example 19
Project: fenics-topopt   Author: zfergus   File: problem.py    License: MIT License
def compute_displacements(self, xPhys):
        # Setup and solve FE problem
        sK = ((self.KE.flatten()[np.newaxis]).T * (
            self.Emin + (xPhys)**self.penal *
            (self.Emax - self.Emin))).flatten(order='F')
        K = scipy.sparse.coo_matrix((sK, (self.iK, self.jK)),
            shape=(self.ndof, self.ndof)).tocsc()
        # Remove constrained dofs from matrix and convert to coo
        K = deleterowcol(K, self.fixed, self.fixed).tocoo()
        # Solve system
        K1 = cvxopt.spmatrix(K.data, K.row.astype(np.int), K.col.astype(np.int))
        B = cvxopt.matrix(self.f[self.free, :])
        cvxopt.cholmod.linsolve(K1, B)
        self.u[self.free, :] = np.array(B)[:, :] 
Example 20
def _add_gt_image(self):
    # add back mean
    image = self._image_gt_summaries['image'] + cfg.PIXEL_MEANS
    image = imresize(image[0], self._im_info[:2] / self._im_info[2])
    # BGR to RGB (opencv uses BGR)
    self._gt_image = image[np.newaxis, :,:,::-1].copy(order='C') 
Example 21
def _mkanchors(ws, hs, x_ctr, y_ctr):
  """
  Given a vector of widths (ws) and heights (hs) around a center
  (x_ctr, y_ctr), output a set of anchors (windows).
  """

  ws = ws[:, np.newaxis]
  hs = hs[:, np.newaxis]
  anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
                       y_ctr - 0.5 * (hs - 1),
                       x_ctr + 0.5 * (ws - 1),
                       y_ctr + 0.5 * (hs - 1)))
  return anchors 
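Without the [:, np.newaxis] step, np.hstack on 1-D arrays would simply concatenate them into one long vector; promoting ws and hs to (N, 1) columns makes hstack place them side by side, so the result is an (N, 4) anchor matrix. A brief sketch:

import numpy as np

ws = np.array([8., 16., 32.])
hs = np.array([8., 16., 32.])
print(np.hstack((ws, hs)).shape)                                # (6,)  -- 1-D arrays just concatenate
print(np.hstack((ws[:, np.newaxis], hs[:, np.newaxis])).shape)  # (3, 2) -- columns stack side by side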
Example 22
def demo(net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(torch.from_numpy(dets), NMS_THRESH)
        dets = dets[keep.numpy(), :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH) 
Example 23
Project: DDPAE-video-prediction   Author: jthsieh   File: moving_mnist.py    License: MIT License
def load_fixed_set(root, is_train):
  # Load the fixed dataset
  filename = 'mnist_test_seq.npy'
  path = os.path.join(root, filename)
  dataset = np.load(path)
  dataset = dataset[..., np.newaxis]
  return dataset 
Example 24
Project: DDPAE-video-prediction   Author: jthsieh   File: DDPAE_utils.py    License: MIT License
def draw_components(images, pose):
  '''
  Draw bounding box for the given pose.
  images: size (N x C x H x W), range [0, 1]
  pose: N x 3
  '''
  images = (images.cpu().numpy() * 255).astype(np.uint8) # [0, 255]
  pose = pose.cpu().numpy()
  N, C, H, W = images.shape
  for i in range(N):
    if C == 1:
      img = images[i][0]
    else:
      img = images[i].transpose((1, 2, 0))
    img = Image.fromarray(img)
    draw = ImageDraw.Draw(img)
    (x, y), w, h = bounding_box(pose[i], H)
    draw.rectangle([x, y, x + w, y + h], outline=128)
    new_img = np.array(img)
    new_img[0, ...] = 255 # Add line
    new_img[-1, ...] = 255 # Add line
    if C == 1:
      new_img = new_img[np.newaxis, :, :]
    else:
      new_img = new_img.transpose((2, 0, 1))
    images[i] = new_img

  # Back to torch tensor
  images = torch.FloatTensor(images / 255)
  return images 
Example 25
Project: disentangling_conditional_gans   Author: zalandoresearch   File: dataset_tool.py    License: MIT License
def create_from_images(tfrecord_dir, image_dir, label_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')
        
    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    try:
        with open(label_dir, 'rb') as file:
            labels = pickle.load(file)
    except:
        error('Label file was not found')
    
    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        reordered_names = []
        for idx in range(order.size):
            image_filename = image_filenames[order[idx]]
            img = np.asarray(PIL.Image.open(image_filename))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose(2, 0, 1) # HWC => CHW
            tfr.add_image(img)
            reordered_names.append(os.path.basename(image_filename))
        reordered_labels = []
        for key in reordered_names:
            reordered_labels += [labels[key]]
        reordered_labels = np.stack(reordered_labels, 0)
        tfr.add_labels(reordered_labels)

#---------------------------------------------------------------------------- 
Example 26
Project: disentangling_conditional_gans   Author: zalandoresearch   File: misc.py    License: MIT License
def draw_text_label(img, text, x, y, alignx=0.5, aligny=0.5, color=255, opacity=1.0, glow_opacity=1.0, **kwargs):
    color = np.array(color).flatten().astype(np.float32)
    assert img.ndim == 3 and img.shape[2] == color.size or color.size == 1
    alpha, glow = setup_text_label(text, **kwargs)
    xx, yy = int(np.rint(x - alpha.shape[1] * alignx)), int(np.rint(y - alpha.shape[0] * aligny))
    xb, yb = max(-xx, 0), max(-yy, 0)
    xe, ye = min(alpha.shape[1], img.shape[1] - xx), min(alpha.shape[0], img.shape[0] - yy)
    img = np.array(img)
    slice = img[yy+yb : yy+ye, xx+xb : xx+xe, :]
    slice[:] = slice * (1.0 - (1.0 - (1.0 - alpha[yb:ye, xb:xe]) * (1.0 - glow[yb:ye, xb:xe] * glow_opacity)) * opacity)[:, :, np.newaxis]
    slice[:] = slice + alpha[yb:ye, xb:xe, np.newaxis] * (color * opacity)[np.newaxis, np.newaxis, :]
    return img