Python matplotlib.cm.viridis() Examples

The following are 24 code examples of matplotlib.cm.viridis(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.cm, or try the search function.
Example #1
Source File: main.py    From form2fit with MIT License 6 votes vote down vote up
def _get_next_data(self):
        """Pull the next batch from the data loader and cache its parts on self.

        Increments ``self._pair_idx``, stores the stacked images in
        ``self.imgs`` and splits them along dim 1 into the source half
        (``self.xs``, the first ``self._num_channels`` channels) and the target
        half (``self.xt``). Displayable numpy copies are kept in
        ``self._xs_np`` / ``self._xt_np``. Finally, records the rotation index
        of the first label row as ``self.best_rot_idx`` and keeps the
        source/target pixel indices of the rows that are matches
        (column 5 == 1) with that rotation.
        """
        self._pair_idx += 1
        imgs, labels, centers = next(self._dloader)
        self.imgs = imgs
        self.center = centers[0]
        label = labels[0]

        nch = self._num_channels
        self.xs = imgs[:, :nch, :, :]
        self.xt = imgs[:, nch:, :, :]

        if nch == 4:
            # First three channels are de-normalised with per-channel colour stats.
            def _to_display(t):
                return ml.tensor2ndarray(t[:, :3], [self._color_mean * 3, self._color_std * 3])
        else:
            # Single displayed channel, rendered through the viridis colormap (RGB part only).
            def _to_display(t):
                gray = ml.tensor2ndarray(t[:, :1], [self._color_mean, self._color_std], False)
                return np.uint8(cm.viridis(gray) * 255)[..., :3]
        self._xs_np = _to_display(self.xs)
        self._xt_np = _to_display(self.xt)

        # Label columns: 0-1 source pixel, 2-3 target pixel, 4 rotation index, 5 match flag.
        src_cols, tgt_cols = label[:, 0:2], label[:, 2:4]
        rotations, matches = label[:, 4], label[:, 5]
        self.best_rot_idx = rotations[0].item()
        keep = (matches == 1) & (rotations == self.best_rot_idx)
        self.source_pixel_idxs = src_cols[keep].numpy()
        self.target_pixel_idxs = tgt_cols[keep].numpy()
Example #2
Source File: create_static_expl_plots.py    From safe-exploration with MIT License 6 votes vote down vote up
def plot_sample_set(x_train,z_all,env):
    """ plot the sample set

    Draws the environment's safety bounds, then the initial training states
    (first ``env.n_s`` columns of ``x_train``) and the explored states (first
    ``env.n_s`` columns of ``z_all``). Explored states are coloured along the
    viridis spectrum in the order they were gathered; initial states use the
    first colour of that spectrum.

    Returns:
        (fig, ax) from ``env.plot_safety_bounds``, after all states are drawn.
    """
    s_train = x_train[:,:env.n_s]
    s_expl = z_all[:,:env.n_s]
    n_it = np.shape(s_expl)[0]
    fig, ax = env.plot_safety_bounds(color = "r")

    # Sample the colormap with floats in [0, 1]. Integer input would index the
    # colormap's lookup table directly, so 0..n_it-1 gave near-identical dark
    # colours. max(n_it, 1) keeps c_spectrum[0] valid when nothing was explored.
    c_spectrum = viridis(np.linspace(0., 1., max(n_it, 1)))

    # plot initial dataset
    for state in s_train:
        ax = env.plot_state(ax, state, color = c_spectrum[0])

    # plot the data gathered
    for state, colour in zip(s_expl, c_spectrum):
        ax = env.plot_state(ax, state, color = colour)

    return fig, ax
Example #3
Source File: smoothlife.py    From SmoothLife with GNU General Public License v3.0 6 votes vote down vote up
def show_animation():
    """Animate a 512x512 SmoothLife simulation in a matplotlib window.

    The field is drawn with the viridis colormap and advanced one step
    every 60 ms by a blitting FuncAnimation.
    """
    height = width = 1 << 9
    # For a full-HD canvas use height, width = 1080, 1920 instead.
    life = SmoothLife(height, width)
    life.add_speckles()
    life.step()

    fig = plt.figure()
    # Other colormaps that look nice here: plasma, gray, binary, seismic, gnuplot
    image = plt.imshow(life.field, animated=True,
                       cmap=plt.get_cmap("viridis"), aspect="equal")

    def animate(*_):
        image.set_array(life.step())
        return image,

    # Hold the animation in a local so it stays referenced while plt.show() runs.
    anim = animation.FuncAnimation(fig, animate, interval=60, blit=True)
    plt.show()
Example #4
Source File: create_dynamic_expl_plots.py    From safe-exploration with MIT License 6 votes vote down vote up
def plot_sample_set(x_train,z_all,env):
    """ plot the sample set

    Draws the environment's safety bounds, then the initial training states
    (first ``env.n_s`` columns of ``x_train``) and the explored states (first
    ``env.n_s`` columns of ``z_all``). Explored states are coloured along the
    viridis spectrum in the order they were gathered; initial states use the
    first colour of that spectrum.

    Returns:
        (fig, ax) from ``env.plot_safety_bounds``, after all states are drawn.
    """
    s_train = x_train[:,:env.n_s]
    s_expl = z_all[:,:env.n_s]
    n_it = np.shape(s_expl)[0]
    fig, ax = env.plot_safety_bounds(color = "r")

    # Sample the colormap with floats in [0, 1]. Integer input would index the
    # colormap's lookup table directly, so 0..n_it-1 gave near-identical dark
    # colours. max(n_it, 1) keeps c_spectrum[0] valid when nothing was explored.
    c_spectrum = viridis(np.linspace(0., 1., max(n_it, 1)))

    # plot initial dataset
    for state in s_train:
        ax = env.plot_state(ax, state, color = c_spectrum[0])

    # plot the data gathered
    for state, colour in zip(s_expl, c_spectrum):
        ax = env.plot_state(ax, state, color = colour)

    return fig, ax
Example #5
Source File: tgasSelect.py    From gaia_tools with MIT License 6 votes vote down vote up
def plot_mean_quantity_tgas(self,tag,func=None,**kwargs):
        """
        NAME:
           plot_mean_quantity_tgas
        PURPOSE:
           Plot the mean of a quantity in the TGAS catalog on the sky
        INPUT:
           tag - tag in the TGAS data to plot
           func= if set, a function to apply to the quantity
           +healpy.mollview plotting kwargs ('unit' defaults to tag,
            'title' defaults to an empty string)
        OUTPUT:
           plot to output device
        HISTORY:
           2017-01-17 - Written - Bovy (UofT/CCA)
        """
        import copy
        mq= self._compute_mean_quantity_tgas(tag,func=func)
        # Customise a private copy: calling set_under on cm.viridis itself
        # would change the shared colormap object for every later plot.
        cmap= copy.deepcopy(cm.viridis)
        # Values below the colour range are drawn white.
        cmap.set_under('w')
        kwargs['unit']= kwargs.get('unit',tag)
        kwargs['title']= kwargs.get('title',"")
        healpy.mollview(mq,nest=True,cmap=cmap,**kwargs)
        return None
Example #6
Source File: create_exploration_plots_paper.py    From safe-exploration with MIT License 6 votes vote down vote up
def create_color_bar(n_iterations,bar_label = "Iteration"):
    """Build a standalone vertical viridis colour bar for 1..n_iterations.

    Args:
        n_iterations: value mapped to the top of the colour bar (bottom is 1).
        bar_label: text placed along the colour bar.

    Returns:
        (fig, ax) - the new 2 x 4.5 inch figure and the axes holding the bar.
    """
    figure = plt.figure(figsize=(2, 4.5))
    bar_ax = figure.add_axes([0.05, 0.05, 0.2, 0.9])

    # ColorbarBase draws a colour bar directly into an axes, with no plotted
    # data behind it; the norm sets the value range shown on the bar.
    colorbar = mpl.colorbar.ColorbarBase(
        bar_ax,
        cmap=mpl.cm.viridis,
        norm=mpl.colors.Normalize(vmin=1, vmax=n_iterations),
        orientation='vertical',
    )
    colorbar.set_label(bar_label)

    return figure, bar_ax
Example #7
Source File: plane.py    From autoregressive-energy-machines with MIT License 6 votes vote down vote up
def test():
    """Show a 512x512-bin 2-D histogram of 1e6 GaussianGridDataset samples.

    The histogram covers [-4, 4] x [-4, 4] and is drawn with viridis, with
    axis ticks hidden.
    """
    points = GaussianGridDataset(int(1e6)).data

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # Use range=[[0, 1], [0, 1]] instead to histogram the unit square.
    ax.hist2d(points[:, 0], points[:, 1],
              range=[[-4, 4], [-4, 4]], bins=512, cmap=cm.viridis)
    for hide_ticks in (ax.set_xticks, ax.set_yticks):
        hide_ticks([])

    plt.show()
    # To save instead of showing:
    # plt.savefig(os.path.join(utils.get_output_root(), 'plane-test.png'), rasterized=True)
Example #8
Source File: tgasSelect.py    From gaia_tools with MIT License 5 votes vote down vote up
def plot_tgas(self,jmin=None,jmax=None,
                  jkmin=None,jkmax=None,
                  cut=False,
                  **kwargs):
        """
        NAME:
           plot_tgas
        PURPOSE:
           Plot star counts in TGAS
        INPUT:
           If the following are not set, fullsky will be plotted:
              jmin, jmax= minimum and maximum Jt
              jkmin, jkmax= minimum and maximum J-Ks
           cut= (False) if True, cut to the 'good' sky
           +healpy.mollview plotting kwargs ('unit' defaults to a
            log10-number-counts label, 'title' to an empty string)
        OUTPUT:
           plot to output device
        HISTORY:
           2017-01-17 - Written - Bovy (UofT/CCA)
        """
        import copy
        # Select stars
        if jmin is None or jmax is None \
                or jkmin is None or jkmax is None:
            pt= self._nstar_tgas_skyonly
        else:
            pindx= (self._full_jt > jmin)*(self._full_jt < jmax)\
                *(self._full_jk > jkmin)*(self._full_jk < jkmax)
            # Bin the selected stars into HEALPix pixels derived from source_id.
            pt, e= numpy.histogram((self._full_tgas['source_id']/2**(35.\
                      +2*(12.-numpy.log2(_BASE_NSIDE)))).astype('int')[pindx],
                                   range=[-0.5,_BASE_NPIX-0.5],
                                   bins=_BASE_NPIX)
        pt= numpy.log10(pt)
        if cut: pt[self._exclude_mask_skyonly]= healpy.UNSEEN
        # Customise a private copy: calling set_under on cm.viridis itself
        # would change the shared colormap object for every later plot.
        cmap= copy.deepcopy(cm.viridis)
        cmap.set_under('w')
        # A caller-supplied 'unit' takes precedence over the default label.
        kwargs['unit']= kwargs.get('unit',r'$\log_{10}\mathrm{number\ counts}$')
        kwargs['title']= kwargs.get('title',"")
        healpy.mollview(pt,nest=True,cmap=cmap,**kwargs)
        return None
Example #9
Source File: colorbar.py    From marvin with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _string_to_cmap(cm_name):
    """Return colormap given name.

    Parameters:
        cm_name (str, ListedColormap or LinearSegmentedColormap):
            Name of colormap, or a colormap object which is returned
            unchanged. Names containing ``'linearlab'`` load the linearlab
            colormap (its reversed variant if the name contains ``'_r'``);
            if linearlab cannot be loaded, ``viridis`` (``viridis_r`` for
            reversed names) is returned instead.

    Returns:
        `matplotlib.cm <http://matplotlib.org/api/cm_api.html>`_ (colormap)
        object

    Raises:
        MarvinError: if ``cm_name`` is neither a string nor a supported
            colormap object.
        ValueError: if ``cm_name`` is not a known colormap name.
    """
    if isinstance(cm_name, (ListedColormap, LinearSegmentedColormap)):
        return cm_name
    if not isinstance(cm_name, str):
        raise MarvinError('{} is not a valid cmap'.format(cm_name))

    if 'linearlab' in cm_name:
        reverse = '_r' in cm_name
        try:
            cmap, cmap_r = linearlab()
        except IOError:
            # Keep the requested orientation when falling back to viridis.
            return cm.viridis_r if reverse else cm.viridis
        return cmap_r if reverse else cmap

    if hasattr(cm, 'get_cmap'):
        return cm.get_cmap(cm_name)
    # Recent matplotlib releases removed cm.get_cmap; use the colormap registry.
    import matplotlib
    try:
        return matplotlib.colormaps[cm_name]
    except KeyError:
        # Match cm.get_cmap, which raised ValueError for unknown names.
        raise ValueError('{} is not a valid colormap name'.format(cm_name))
Example #10
Source File: main.py    From form2fit with MIT License 5 votes vote down vote up
def _draw_rotations(self, init=False, heatmap=True):
        """Redraw the 5 x 4 grid of rotation tiles.

        Tile ``offset = col * 4 + row`` (0..19) shows one of:
          * a blank image (``init=True``);
          * the heatmap for that rotation, histogram-equalised and coloured
            with viridis (``heatmap=True``);
          * the source image rotated by ``-(360 / 20) * offset`` degrees.
        When not initialising, the tile equal to ``self._uv[-1]`` gets a 2x2
        red mark at ``self._uv[0], self._uv[1]`` and a red border, and the
        tile equal to ``self.best_rot_idx`` gets a green border.
        """
        def _hist_eq(img):
            """Histogram-equalise ``img`` using its cumulative distribution."""
            from skimage import exposure

            img_cdf, bin_centers = exposure.cumulative_distribution(img)
            return np.interp(img, bin_centers, img_cdf)

        for col in range(5):
            for row in range(4):
                offset = col * 4 + row
                if init:
                    img = self._zeros.copy()
                else:
                    if heatmap:
                        img = self.heatmaps[offset].copy()
                        peak = img.max()
                        # An all-zero heatmap would give 0/0 = NaN here and
                        # propagate NaNs into the equalisation; leave it as-is.
                        if peak != 0:
                            img = img / peak
                        img = _hist_eq(img)
                        img = np.uint8(cm.viridis(img) * 255)[..., :3]
                        # Slicing off alpha leaves a non-contiguous view; QImage
                        # below reads img.data directly, so make a packed copy.
                        img = img.copy()
                    else:
                        img = misc.rotate_img(self._xs_np, -(360 / 20) * offset, center=(self.center[1], self.center[0]))
                        img = img.copy()
                    if offset == self._uv[-1]:
                        img[
                            self._uv[0] - 1 : self._uv[0] + 1,
                            self._uv[1] - 1 : self._uv[1] + 1,
                        ] = [255, 0, 0]
                        self._add_border_clr(img, [255, 0, 0])
                    if offset == self.best_rot_idx:
                        self._add_border_clr(img, [0, 255, 0])
                self._img = QImage(
                    img.data, self._w, self._h, self._c * self._w, QImage.Format_RGB888
                )
                pixmap = QPixmap.fromImage(self._img)
                self._grid_widgets[offset].setPixmap(pixmap)
                self._grid_widgets[offset].setScaledContents(True)
Example #11
Source File: create_exploration_plots_paper.py    From safe-exploration with MIT License 5 votes vote down vote up
def plot_sample_set(z_all,env,y_label = False, x_train = None):
    """ plot the sample set

    Draws the environment's safety bounds and the explored states (first
    ``env.n_s`` columns of ``z_all``), coloured along the viridis spectrum in
    the order they were gathered. If ``x_train`` is given, its states are
    drawn as well, using the first colour of the spectrum. The figure is
    resized to 3.6 x 4.5 inches.

    Args:
        z_all: explored samples, one per row.
        env: environment providing ``n_s``, ``plot_safety_bounds`` and
            ``plot_state``.
        y_label: if true, label the y axis.
        x_train: optional initial training samples, one per row.

    Returns:
        (fig, ax) from ``env.plot_safety_bounds``, after all states are drawn.
    """
    s_expl = z_all[:,:env.n_s]
    n_it = np.shape(s_expl)[0]
    fig, ax = env.plot_safety_bounds(color = "r")

    # Sample the colormap with floats in [0, 1]. Integer input would index the
    # colormap's lookup table directly, so 0..n_it-1 gave near-identical dark
    # colours. max(n_it, 1) keeps c_spectrum[0] valid when nothing was explored.
    c_spectrum = viridis(np.linspace(0., 1., max(n_it, 1)))

    # plot initial dataset
    if x_train is not None:
        s_train = x_train[:,:env.n_s]
        for state in s_train:
            ax = env.plot_state(ax, state, color = c_spectrum[0])

    # plot the data gathered
    for state, colour in zip(s_expl, c_spectrum):
        ax = env.plot_state(ax, state, color = colour)

    # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
    ax.set_xlabel(r"Angular velocity $\dot{\theta}$")
    if y_label:
        ax.set_ylabel(r"Angle $\theta$")
    fig.set_size_inches(3.6,4.5)
    return fig, ax
Example #12
Source File: tgasSelect.py    From gaia_tools with MIT License 5 votes vote down vote up
def plot_2mass(self,jmin=None,jmax=None,
                   jkmin=None,jkmax=None,
                   cut=False,
                   **kwargs):
        """
        NAME:
           plot_2mass
        PURPOSE:
           Plot star counts in 2MASS
        INPUT:
           If the following are not set, fullsky will be plotted:
              jmin, jmax= minimum and maximum Jt
              jkmin, jkmax= minimum and maximum J-Ks
           cut= (False) if True, cut to the 'good' sky
           +healpy.mollview plotting kwargs ('unit' defaults to a
            log10-number-counts label, 'title' to an empty string)
        OUTPUT:
           plot to output device
        HISTORY:
           2017-01-17 - Written - Bovy (UofT/CCA)
        """
        import copy
        # Select stars
        if jmin is None or jmax is None \
                or jkmin is None or jkmax is None:
            pt= _2mc_skyonly[1]
        else:
            pindx= (_2mc[0] > jmin)*(_2mc[0] < jmax)\
                *(_2mc[1] > jkmin)*(_2mc[1] < jkmax)
            pt, e= numpy.histogram(_2mc[2][pindx],
                                   range=[-0.5,_BASE_NPIX-0.5],
                                   bins=_BASE_NPIX)
        pt= numpy.log10(pt)
        if cut: pt[self._exclude_mask_skyonly]= healpy.UNSEEN
        # Customise a private copy: calling set_under on cm.viridis itself
        # would change the shared colormap object for every later plot.
        cmap= copy.deepcopy(cm.viridis)
        cmap.set_under('w')
        # A caller-supplied 'unit' takes precedence over the default label.
        kwargs['unit']= kwargs.get('unit',r'$\log_{10}\mathrm{number\ counts}$')
        kwargs['title']= kwargs.get('title',"")
        healpy.mollview(pt,nest=True,cmap=cmap,**kwargs)
        return None
Example #13
Source File: smoothlife.py    From SmoothLife with GNU General Public License v3.0 5 votes vote down vote up
def save_animation():
    """Write 100 frames of a 256x256 SmoothLife run to ``smoothlife.mp4`` at 10 fps.

    Each frame is the field coloured with viridis, converted to 8-bit and
    written with skvideo's FFmpegWriter. The writer is closed even if a
    frame fails.
    """
    w = 1 << 8
    h = 1 << 8
    # w = 1920
    # h = 1080
    sl = SmoothLife(h, w)
    sl.add_speckles()

    # matplotlib's animation saving adds a border around frames, so encode
    # the frames directly instead.

    from skvideo.io import FFmpegWriter
    from matplotlib import cm

    fps = 10
    frames = 100
    # Separate name so the writer does not shadow the grid width ``w``.
    writer = FFmpegWriter("smoothlife.mp4", inputdict={"-r": str(fps)})
    try:
        for _ in range(frames):
            # Scale the colormap's [0, 1] float output to 8-bit.
            frame = (cm.viridis(sl.field) * 255).astype("uint8")
            writer.writeFrame(frame)
            sl.step()
    finally:
        writer.close()

    # webm output does not work directly here; convert manually with:
    # ffmpeg -i smoothlife.mp4 -c:v libvpx -b:v 2M smoothlife.webm
Example #14
Source File: display.py    From skylibs with GNU Lesser General Public License v3.0 5 votes vote down vote up
def plotSubFigure(X, Y, Z, subfig, type_):
    """Draw surface Z(X, Y) into panel ``subfig`` of a 1x3 grid of 3-D axes.

    ``type_`` selects shading and viewpoint:
      "colormap"    - viridis-coloured surface, viewed from above;
      "top"         - grey surface, viewed from above at azimuth 130;
      "side"        - grey surface, viewed edge-on (elevation 0) at azimuth 40;
      anything else - grey surface, viewed from above.
    All three axes get the same half-range (60% of half the largest data
    range) around their midpoints; grid and axis decorations are hidden.
    """
    fig = plt.gcf()
    ax = fig.add_subplot(1, 3, subfig, projection='3d')
    surface_kw = dict(rstride=1, cstride=1, shade=True, linewidth=0,
                      antialiased=False)
    if type_ == "colormap":
        ax.plot_surface(X, Y, Z, cmap=cm.viridis, **surface_kw)
    else:
        ax.plot_surface(X, Y, Z, color=[0.7, 0.7, 0.7], **surface_kw)

    try:
        ax.set_aspect("equal")
    except NotImplementedError:
        # Some matplotlib releases refuse 'equal' aspect on 3-D axes; the
        # equal-range limits below still keep the axes comparable.
        pass

    max_range = np.array([X.max()-X.min(), Y.max()-Y.min(), Z.max()-Z.min()]).max() / 2.0 * 0.6
    mid_x = (X.max()+X.min()) * 0.5
    mid_y = (Y.max()+Y.min()) * 0.5
    mid_z = (Z.max()+Z.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    az, el = 90, 90
    if type_ == "top":
        az = 130
    elif type_ == "side":
        az, el = 40, 0

    # view_init's positional order is (elev, azim); pass by keyword so the
    # azimuth/elevation chosen above are applied as named.
    ax.view_init(elev=el, azim=az)
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    plt.grid(False)
    plt.axis('off')
Example #15
Source File: imshow.py    From ASPP-2018-numpy with MIT License 5 votes vote down vote up
def imshow(Z, vmin=None, vmax=None, cmap=viridis, show_cmap=True):
    ''' Show a 2D numpy array using terminal colors.

    Redefines terminal palette entries 16..255 with 240 samples of ``cmap``
    (via escape sequences), then prints each cell as two spaces whose
    background is the palette entry for its value. Rows are printed
    last-to-first, so ``Z[0]`` appears at the bottom.

    Parameters
    ----------
    Z : array_like
        Data to display; scalars and 1-D input are promoted to 2-D.
        Arrays with more than two dimensions are rejected with a message.
    vmin, vmax : float, optional
        Values mapped to the first / last palette colour. Default to the
        minimum / maximum of ``Z``; only ``None`` triggers the default, so
        an explicit 0 is honoured.
    cmap : callable
        Colormap called with a float in [0, 1) and returning (r, g, b, a)
        with components in [0, 1].
    show_cmap : bool
        Append a colour swatch and its value at the end of each row.
    '''

    Z = np.atleast_2d(Z)

    if len(Z.shape) != 2:
        print("Cannot display non 2D array")
        return

    if vmin is None:
        vmin = Z.min()
    if vmax is None:
        vmax = Z.max()
    # Constant data (vmin == vmax) would otherwise divide by zero.
    span = (vmax - vmin) or 1.0

    # Build initialization string that sets up terminal colors 16..255
    palette = []
    for i in range(240):
        r, g, b, a = cmap(i / 240)
        palette.append("\x1b]4;%d;rgb:%02x/%02x/%02x\x1b\\"
                       % (16 + i, int(r * 255), int(g * 255), int(b * 255)))
    init = ''.join(palette)

    # Build array data string
    nrows = Z.shape[0]
    # Same denominator for the swatch colour and its label, so the bottom
    # row shows the first palette colour next to vmin.
    last_row = float(max(nrows - 1, 1))
    rows = []
    for i in range(nrows):
        cells = []
        for value in Z[nrows - i - 1]:
            c = 16 + int(((value - vmin) / span) * 239)
            cells.append("\x1b[48;5;%dm  " % min(max(c, 16), 255))
        if show_cmap:
            frac = i / last_row
            u = vmax - frac * (vmax - vmin)
            cells.append("\x1b[0m  ")
            cells.append("\x1b[48;5;%dm  " % (16 + (1 - frac) * 239))
            cells.append("\x1b[0m %+.2f" % u)
        else:
            # Reset the background so the last cell's colour does not bleed.
            cells.append("\x1b[0m")
        rows.append(''.join(cells) + "\n")
    data = ''.join(rows)

    sys.stdout.write(init+'\n')
    sys.stdout.write(data+'\n')
Example #16
Source File: mesh.py    From mikeio with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def plot(self, cmap=None, z=None, label=None):
        """
        Plot mesh elements

        Each element is drawn as a polygon through its nodes, coloured by
        ``z``, with a colorbar; the axes limits span the node coordinates.

        Parameters
        ----------
        cmap: matplotlib colormap, optional
            default viridis
        z: np.array, optional
            value for each element to plot, default bathymetry
            (third column of the element coordinates)
        label: str, optional
            colorbar label; defaults to "Bathymetry (m)" when z is not given
        """
        if cmap is None:
            cmap = cm.viridis

        nc = self.get_node_coords()
        ec = self.get_element_coords()

        if z is None:
            z = ec[:, 2]
            if label is None:
                label = "Bathymetry (m)"

        patches = []
        for j in range(ec.shape[0]):
            nodes = self._mesh.ElementTable[j]
            # Node numbers in ElementTable are treated as 1-based (hence -1).
            node_idx = [nodes[i] - 1 for i in range(nodes.Length)]
            # Pass ``closed`` by keyword; positional use is deprecated.
            patches.append(Polygon(nc[node_idx, 0:2], closed=True))

        fig, ax = plt.subplots()
        p = PatchCollection(patches, cmap=cmap, edgecolor="black")

        p.set_array(z)
        ax.add_collection(p)
        fig.colorbar(p, ax=ax, label=label)
        ax.set_xlim(nc[:, 0].min(), nc[:, 0].max())
        ax.set_ylim(nc[:, 1].min(), nc[:, 1].max())
Example #17
Source File: _vista.py    From gempy with GNU Lesser General Public License v3.0 4 votes vote down vote up
def plot_structured_grid_interactive(
            self,
            name: str,
            render_topography: bool = False,
            **kwargs,
    ):
        """Plot interactive 3-D geomodel with three cross sections in subplot.

        Subplot 0 holds three plane widgets (normals x, y and z). Moving a
        widget redraws the corresponding slice of the block in subplot 1, 2
        or 3; those three subplots are fixed to the yz, xz and xy views.

        Args:
            name (str): Can be either one of the following
                'lith' - Lithology id block.
                'scalar' - Scalar field block.
                'values' - Values matrix block.
            render_topography: Render topography. Defaults to False.
            **kwargs: Forwarded to ``plot_structured_grid``.

        Returns:
            None. The plot is built on ``self.p``.
        """
        mesh = self.plot_structured_grid(name=name, render_topography=render_topography, **kwargs)[0]

        # define colormaps: discrete lithology colours, otherwise a
        # continuous map (also for 'values', which must not leave cmap unset)
        if name == "lith":
            cmap = mcolors.ListedColormap(list(self._get_color_lot(faults=False)))
        else:
            cmap = cm.viridis

        def _slice_callback(subplot, slice_name):
            """Return a plane-widget callback that redraws one slice in ``subplot``."""
            def callback(normal, origin):
                self.p.subplot(subplot)
                self.p.add_mesh(mesh.slice(normal=normal, origin=origin), name=slice_name, cmap=cmap)
            return callback

        # cross section widgets, all on the overview subplot
        for normal, subplot, slice_name in (("x", 1, "xslc"), ("y", 2, "yslc"), ("z", 3, "zslc")):
            self.p.subplot(0)
            self.p.add_plane_widget(_slice_callback(subplot, slice_name), normal=normal)

        # Lock other three views in place
        for subplot, view in ((1, "view_yz"), (2, "view_xz"), (3, "view_xy")):
            self.p.subplot(subplot)
            getattr(self.p, view)()
            self.p.disable()
Example #18
Source File: image.py    From twitter-stock-recommendation with MIT License 4 votes vote down vote up
def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None,
           origin=None, dpi=100):
    """
    Save an array as an image file.

    The output formats available depend on the backend being used.

    Parameters
    ----------
    fname : str, path-like or file-like
        Path string to a filename, or a Python file-like object.
        If *format* is *None* and *fname* is a string, the output
        format is deduced from the extension of the filename.
    arr : array-like
        An MxN (luminance), MxNx3 (RGB) or MxNx4 (RGBA) array.
    vmin, vmax : scalar, optional
        *vmin* and *vmax* set the color scaling for the image by fixing the
        values that map to the colormap color limits. If either *vmin*
        or *vmax* is None, that limit is determined from the *arr*
        min/max value.
    cmap : matplotlib.colors.Colormap, optional
        For example, ``cm.viridis``.  If ``None``, defaults to the
        ``image.cmap`` rcParam.
    format : str, optional
        One of the file extensions supported by the active backend.  Most
        backends support png, pdf, ps, eps and svg.
    origin : {'upper', 'lower'}, optional
        Indicates whether the ``(0, 0)`` index of the array is in the
        upper left or lower left corner of the axes.  Defaults to the
        ``image.origin`` rcParam.
    dpi : int
        The DPI to store in the metadata of the file.  This does not affect the
        resolution of the output image.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure
    # Normalise os.PathLike objects to plain strings (getattr guards Pythons
    # without os.PathLike, where the isinstance check matches nothing).
    if isinstance(fname, getattr(os, "PathLike", ())):
        fname = os.fspath(fname)
    if (format == 'png'
        or (format is None
            and isinstance(fname, six.string_types)
            and fname.lower().endswith('.png'))):
        # PNG fast path: render through an AxesImage without building a Figure.
        image = AxesImage(None, cmap=cmap, origin=origin)
        image.set_data(arr)
        image.set_clim(vmin, vmax)
        image.write_png(fname)
    else:
        # Other formats: draw the array at 1:1 size on a frameless Figure.
        fig = Figure(dpi=dpi, frameon=False)
        FigureCanvas(fig)
        fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin,
                     resize=True)
        fig.savefig(fname, dpi=dpi, format=format, transparent=True)
Example #19
Source File: tools.py    From devito with MIT License 4 votes vote down vote up
def plot_field(field, xmin=0., xmax=2., ymin=0., ymax=2., zmin=None, zmax=None,
               view=None, linewidth=0):
    """
    Utility plotting routine for 2D data.

    Draws ``field`` as a viridis-coloured 3-D surface over a regular grid
    spanning [xmin, xmax] x [ymin, ymax] and shows the figure.

    Parameters
    ----------
    field : array_like
        2-D field data to plot; axis 0 maps to x and axis 1 to y.
    xmin, xmax : float, optional
        Extent of the x-axis.
    ymin, ymax : float, optional
        Extent of the y-axis.
    zmin, zmax : float, optional
        Limits of the z-axis; default to the field's minimum / maximum.
        A warning is emitted if only one is given and it excludes the data.
    view : tuple, optional
        Arguments unpacked into ``Axes3D.view_init`` to set the view point.
    linewidth : float, optional
        Line width of the surface mesh.

    Raises
    ------
    ValueError
        If a dimension's minimum is larger than its maximum.
    """
    field = np.asarray(field)
    if xmin > xmax or ymin > ymax:
        raise ValueError("Dimension min cannot be larger than dimension max.")
    if (zmin is not None and zmax is not None):
        if zmin > zmax:
            raise ValueError("Dimension min cannot be larger than dimension max.")
    elif(zmin is None and zmax is not None):
        if np.min(field) >= zmax:
            warning("zmax is less than field's minima. Figure deceptive.")
    elif(zmin is not None and zmax is None):
        if np.max(field) <= zmin:
            warning("zmin is larger than field's maxima. Figure deceptive.")
    x_coord = np.linspace(xmin, xmax, field.shape[0])
    y_coord = np.linspace(ymin, ymax, field.shape[1])
    fig = pyplot.figure(figsize=(11, 7), dpi=100)
    # Figure.gca() no longer accepts a projection argument in recent
    # matplotlib; create the 3-D axes explicitly.
    ax = fig.add_subplot(111, projection='3d')
    X, Y = np.meshgrid(x_coord, y_coord, indexing='ij')
    ax.plot_surface(X, Y, field, cmap=cm.viridis, rstride=1, cstride=1,
                    linewidth=linewidth, antialiased=False)

    # Enforce axis measures and set view if given
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if zmin is None:
        zmin = np.min(field)
    if zmax is None:
        zmax = np.max(field)
    ax.set_zlim(zmin, zmax)

    if view is not None:
        ax.view_init(*view)

    # Label axis
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')

    pyplot.show()
Example #20
Source File: tgasSelect.py    From gaia_tools with MIT License 4 votes vote down vote up
def plot_cmd(self,type='sf',cut=True):
        """
        NAME:
           plot_cmd
        PURPOSE:
           Plot the distribution of counts in the color-magnitude diagram
        INPUT:
           type= ('sf') Plot 'sf': selection function
                             'tgas': TGAS counts
                             '2mass': 2MASS counts
           cut= (True) cut to the 'good' part of the sky
        OUTPUT:
           Plot to output device (returns the bovy_plot.bovy_dens2d output)
        RAISES:
           ValueError if type is not 'sf', 'tgas' or '2mass'
        HISTORY:
           2017-01-17 - Written - Bovy (UofT/CCA)
        """
        # Validate up front, before the expensive histogramming.
        if type not in ('sf','tgas','2mass'):
            raise ValueError("type must be one of 'sf', 'tgas' or '2mass'; got %r" % (type,))
        # 0.1-wide bins over [min-0.05, max+0.05]; histogramdd needs an
        # integer bin count, so round the float ratio explicitly.
        jtbins= int(round((numpy.amax(_2mc[0])-numpy.amin(_2mc[0]))/0.1+1))
        nstar2mass, edges= numpy.histogramdd(\
            _2mc[:3].T,bins=[jtbins,3,_BASE_NPIX],
            range=[[numpy.amin(_2mc[0])-0.05,numpy.amax(_2mc[0])+0.05],
                   [-0.05,1.0],[-0.5,_BASE_NPIX-0.5]],weights=_2mc[3])
        findx= (self._full_jk > -0.05)*(self._full_jk < 1.0)\
            *(self._full_twomass['j_mag'] < 13.5)
        nstartgas, edges= numpy.histogramdd(\
            numpy.array([self._full_jt[findx],self._full_jk[findx],\
                             (self._full_tgas['source_id'][findx]\
                                  /2**(35.+2*(12.-numpy.log2(_BASE_NSIDE))))\
                             .astype('int')]).T,
            bins=[jtbins,3,_BASE_NPIX],
            range=[[numpy.amin(_2mc[0])-0.05,numpy.amax(_2mc[0])+0.05],
                   [-0.05,1.0],[-0.5,_BASE_NPIX-0.5]])
        if cut:
            nstar2mass[:,:,self._exclude_mask_skyonly]= numpy.nan
            nstartgas[:,:,self._exclude_mask_skyonly]= numpy.nan
        # Sum over sky pixels (last axis), skipping the excluded ones.
        nstar2mass= numpy.nansum(nstar2mass,axis=-1)
        nstartgas= numpy.nansum(nstartgas,axis=-1)
        if type == 'sf':
            pt= nstartgas/nstar2mass
            vmin= 0.
            vmax= 1.
            zlabel=r'$\mathrm{completeness}$'
        else:
            vmin= 0.
            vmax= 6.
            zlabel= r'$\log_{10}\mathrm{number\ counts}$'
            if type == 'tgas':
                pt= numpy.log10(nstartgas)
            else:
                pt= numpy.log10(nstar2mass)
        return bovy_plot.bovy_dens2d(pt,origin='lower',
                                     cmap='viridis',interpolation='nearest',
                                     colorbar=True,shrink=0.78,
                                     vmin=vmin,vmax=vmax,zlabel=zlabel,
                                     yrange=[edges[0][0],edges[0][-1]],
                                     xrange=[edges[1][0],edges[1][-1]],
                                     xlabel=r'$J-K_s$',
                                     ylabel=r'$J+\Delta J$')
Example #21
Source File: LSDMap_HillslopeMorphology.py    From LSDMappingTools with MIT License 4 votes vote down vote up
def PlotKsnAgainstRStar(DataDirectory, FilenamePrefix, PlotDirectory):
    """
    Function to plot median Ksn against R* for a series of basins

    Reads ``PlotDirectory + FilenamePrefix + '_basin_hillslope_data.csv'``,
    prints the linear regression of Rstar_median on mchi_median, plots the
    basin medians (coloured by basin key, with error bars) and the fitted
    line, and saves ``PlotDirectory + FilenamePrefix + '_ksn_vs_rstar.png'``.

    Note: DataDirectory is currently unused; the CSV is read from PlotDirectory.

    Author: FJC
    """

    # SMM: What generates this file?? I don't have it.
    input_csv = PlotDirectory+FilenamePrefix+'_basin_hillslope_data.csv'
    df = pd.read_csv(input_csv)

    # linregress
    slope, intercept, r_value, p_value, std_err = stats.linregress(df['mchi_median'],df['Rstar_median'])
    print(slope, intercept, r_value, p_value)
    x = np.linspace(0, 200, 100)
    new_y = slope*x + intercept

    # set up the figure
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5,5))

    # Keep the scatter's mappable so the colour bar uses exactly its colours.
    points = ax.scatter(df['mchi_median'], df['Rstar_median'], c=df['basin_keys'], s=50, edgecolors='k', zorder=100, cmap=cm.viridis)
    ax.errorbar(df['mchi_median'], df['Rstar_median'], xerr=[df['mchi_lower_err'], df['mchi_upper_err']], yerr=[df['Rstar_lower_err'], df['Rstar_upper_err']], fmt='o', ecolor='0.5',markersize=1,mfc='white',mec='k')
    ax.plot(x, new_y, c='0.5', ls='--')
    ax.set_xlim(0,100)

    ax.set_xlabel('$k_{sn}$')
    ax.set_ylabel('$R*$')

    plt.subplots_adjust(left=0.15,right=0.85, bottom=0.1, top=0.95)
    CAx = fig.add_axes([0.87,0.1,0.02,0.85])
    fig.colorbar(points, cax=CAx, orientation='vertical', label='Basin key')

    #save output
    fig.savefig(PlotDirectory+FilenamePrefix +"_ksn_vs_rstar.png", dpi=300)
    # Close (not just clear) the figure so repeated calls do not pile up open figures.
    plt.close(fig)


# This seems to do the same as the PlotEStarRStarWithinBasin function!
# However, it does not work here because the '_basin_hillslope_data.csv' file is not available.
Example #22
Source File: plot_hillslope_morphology.py    From LSDMappingTools with MIT License 4 votes vote down vote up
def PlotEStarRStar(Basin, Sc=0.71):
    """
    MDH

    Plot R* (= S/S_C) against E* for one basin on log-log axes, on top of
    the analytical relationship drawn by PlotEStarRStarTheoretical. Points
    and their error bars are coloured by flow length relative to the
    shortest flow length in the basin. The figure is saved to
    ``PlotDirectory + FilenamePrefix + "_<Basin:02d>_EStarRStar.png"``.

    Note: ``Sc`` is currently unused.
    """

    Data = CalculateEStarRStar(Basin)

    # setup the figure
    Fig = CreateFigure(AspectRatio=1.2)

    #choose colormap
    ColourMap = cm.viridis

    #Plot analytical relationship
    PlotEStarRStarTheoretical()

    # colour code by flow length, rescaled to [0, 1]
    MinFlowLength = Data.FlowLength.min()
    Data.FlowLength = Data.FlowLength-MinFlowLength
    MaxFlowLength = Data.FlowLength.max()
    if MaxFlowLength > 0:
        colours = (Data.FlowLength/MaxFlowLength)
    else:
        # All points share one flow length: avoid 0/0 = NaN colours.
        colours = Data.FlowLength*0.

    #plot the data
    plt.loglog()

    # Error bars with colours but faded (alpha)
    for i, row in Data.iterrows():
        EStarErr = np.array([[row.EStarLower],[row.EStarUpper]])
        RStarErr = np.array([[row.RStarLower],[row.RStarUpper]])
        colour = ColourMap(colours[i])
        plt.plot([row.EStar,row.EStar],RStarErr,'-', lw=1, color=colour, alpha=0.5,zorder=9)
        plt.plot(EStarErr,[row.RStar,row.RStar],'-', lw=1, color=colour, alpha=0.5,zorder=9)
        plt.plot(row.EStar,row.RStar,'o',ms=4,color=colour,zorder=32)

    # Finalise the figure (raw string: '\:' is a LaTeX spacing command)
    plt.xlabel(r'$E^*={{-2\:C_{HT}\:L_H}/{S_C}}$')
    plt.ylabel('$R^*=S/S_C$')
    plt.xlim(0.1,1000)
    plt.ylim(0.01,1.5)

    # add colour bar; attach it to the data axes explicitly, since a bare
    # ScalarMappable has no axes of its own
    m = cm.ScalarMappable(cmap=ColourMap)
    m.set_array(Data.FlowLength)
    cbar = plt.colorbar(m, ax=plt.gca())
    tick_locator = ticker.MaxNLocator(nbins=5)
    cbar.locator = tick_locator
    cbar.update_ticks()
    cbar.set_label('Distance to Outlet (m)')

    plt.suptitle("Basin "+str(Basin)+" Dimensionless Hillslope Morphology")
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(PlotDirectory+FilenamePrefix + "_" + "%02d" % Basin + "_EStarRStar.png", dpi=300)
    plt.close(Fig)
Example #23
Source File: plot_hillslope_morphology.py    From LSDMappingTools with MIT License 4 votes vote down vote up
def PlotLongProfileMChi(BasinID):
    """Plot one basin's long profile with channel segments coloured by M_chi.

    Loads channel data via ReadChannelData(DataDirectory, FilenamePrefix),
    keeps the rows of basin ``BasinID`` and, for each segment, plots raw
    (dashed) and segmented elevation against distance from the basin outlet
    in km. Segments are coloured by M_chi / max(M_chi) within the basin. The
    figure is saved to
    ``PlotDirectory + FilenamePrefix + "_<BasinID>_LongProfMChi.png"``.
    If the basin has no channel data, a message is printed and nothing is
    plotted.
    """

    # load the channel data
    ChannelData = ReadChannelData(DataDirectory, FilenamePrefix)

    # isolate basin data
    BasinChannelData = ChannelData[ChannelData.basin_key == BasinID]

    # ``count`` is a DataFrame method, so test emptiness explicitly.
    if BasinChannelData.empty:
        print("No Channel Data for Basin ID " + str(BasinID))
        return

    MinimumDistance = BasinChannelData.flow_distance.min()
    MaximumMChi = BasinChannelData.m_chi.max()

    # how many segments are we dealing with?
    Segments = BasinChannelData.segment_number.unique()

    # setup the figure
    Fig = CreateFigure()

    #choose colormap
    ColourMap = cm.viridis

    for Segment in Segments:
        # Select from this basin only, not from every basin's channels.
        SegmentData = BasinChannelData[BasinChannelData.segment_number == Segment]
        MChi = SegmentData.m_chi.unique()[0]

        #normalise distance by outlet distance
        Dist = SegmentData.flow_distance-MinimumDistance
        #plot, colouring segments
        Colour = MChi/MaximumMChi
        plt.plot(Dist/1000,SegmentData.elevation,'k--',dashes=(2,2), lw=0.5,zorder=10)
        plt.plot(Dist/1000, SegmentData.segmented_elevation, '-', lw=2, c=ColourMap(Colour),zorder=9)

    # Finalise the figure
    plt.xlabel('Distance (km)')
    plt.ylabel('Elevation (m)')
    plt.title('Basin ID ' + str(BasinID))
    plt.tight_layout()
    #add colourbar
    CAx = Fig.add_axes([0.15,0.8,0.4,0.05])
    m = cm.ScalarMappable(cmap=ColourMap)
    m.set_array(BasinChannelData.m_chi)
    # Segment colours are MChi/MaximumMChi, i.e. a 0..MaximumMChi scale.
    m.set_clim(0, MaximumMChi)
    plt.colorbar(m, cax=CAx,orientation='horizontal')
    CAx.set_xlabel(r'$M_{\chi}$ m$^{0.64}$')
    #save output
    plt.savefig(PlotDirectory+FilenamePrefix + "_" + str(BasinID) + "_LongProfMChi.png", dpi=300)
    plt.close(Fig)
Example #24
Source File: plot_hillslope_morphology.py    From LSDMappingTools with MIT License 4 votes vote down vote up
def PlotChiElevationMChi(BasinID):
    """Plot elevation against chi for one basin, colouring segments by M_chi.

    Loads channel data via ReadChannelData(DataDirectory, FilenamePrefix),
    keeps the rows of basin ``BasinID`` and, for each segment, plots raw
    (dashed) and segmented elevation against chi relative to the basin's
    minimum chi. Segments are coloured by M_chi / max(M_chi) within the
    basin. The figure is saved to
    ``PlotDirectory + FilenamePrefix + "_<BasinID>_ChiElevMChi.png"``.
    If the basin has no channel data, a message is printed and nothing is
    plotted.
    """

    # load the channel data
    ChannelData = ReadChannelData(DataDirectory, FilenamePrefix)

    # isolate basin data (must exist before it is checked below)
    BasinChannelData = ChannelData[ChannelData.basin_key == BasinID]

    # ``count`` is a DataFrame method, so test emptiness explicitly.
    if BasinChannelData.empty:
        print("No Channel Data for Basin ID " + str(BasinID))
        return

    MinimumChi = BasinChannelData.chi.min()
    MaximumMChi = BasinChannelData.m_chi.max()

    # how many segments are we dealing with?
    Segments = BasinChannelData.segment_number.unique()

    # setup the figure
    Fig = CreateFigure()

    #choose colormap
    ColourMap = cm.viridis

    for Segment in Segments:
        # Select from this basin only, not from every basin's channels.
        SegmentData = BasinChannelData[BasinChannelData.segment_number == Segment]
        MChi = SegmentData.m_chi.unique()[0]

        #normalise chi by outlet chi
        Chi = SegmentData.chi-MinimumChi
        #plot, colouring segments
        Colour = MChi/MaximumMChi
        plt.plot(Chi,SegmentData.elevation,'k--',dashes=(2,2), lw=0.5,zorder=10)
        plt.plot(Chi, SegmentData.segmented_elevation, '-', lw=2, c=ColourMap(Colour),zorder=9)

    # Finalise the figure
    plt.xlabel(r'$\chi$ (m)')
    plt.ylabel('Elevation (m)')
    plt.title('Basin ID ' + str(BasinID))
    plt.tight_layout()
    #add colourbar
    CAx = Fig.add_axes([0.15,0.8,0.4,0.05])
    m = cm.ScalarMappable(cmap=ColourMap)
    m.set_array(BasinChannelData.m_chi)
    # Segment colours are MChi/MaximumMChi, i.e. a 0..MaximumMChi scale.
    m.set_clim(0, MaximumMChi)
    plt.colorbar(m, cax=CAx,orientation='horizontal')
    CAx.set_xlabel(r'$M_{\chi}$ m$^{0.64}$')
    #save output
    plt.savefig(PlotDirectory+FilenamePrefix + "_" + str(BasinID) + "_ChiElevMChi.png", dpi=300)
    plt.close(Fig)