Python numpy.fromfile() Examples

The following are 30 code examples showing how to use numpy.fromfile(). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. (Examples 5–10 call the baseband package's own `fromfile` class methods rather than `numpy.fromfile()` directly.)

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module numpy, or try the search function.

Example 1
Project: deep-learning-note   Author: wdxtub   File: utils.py    License: MIT License 7 votes vote down vote up
def parse_data(path, dataset, flatten):
    """Load one split of an MNIST-style IDX dataset from *path*.

    Args:
        path: directory holding ``<dataset>-labels-idx1-ubyte`` and
            ``<dataset>-images-idx3-ubyte``.
        dataset: ``'train'`` or ``'t10k'``.
        flatten: if true, images come back as ``(num, rows * cols)``;
            otherwise as ``(num, rows, cols)``.

    Returns:
        ``(imgs, new_labels)``: float32 images scaled into [0, 1] and a
        ``(num, 10)`` one-hot float label matrix.

    Raises:
        NameError: if *dataset* is neither ``'train'`` nor ``'t10k'``.
    """
    if dataset not in ('train', 't10k'):
        raise NameError('dataset must be train or t10k')

    # Label file: big-endian (magic, count) header, then one byte per label.
    labels_path = os.path.join(path, dataset + '-labels-idx1-ubyte')
    with open(labels_path, 'rb') as fh:
        _, n_labels = struct.unpack(">II", fh.read(8))
        raw_labels = np.fromfile(fh, dtype=np.int8)
        one_hot = np.zeros((n_labels, 10))
        one_hot[np.arange(n_labels), raw_labels] = 1

    # Image file: big-endian (magic, count, rows, cols) header, then uint8 pixels.
    images_path = os.path.join(path, dataset + '-images-idx3-ubyte')
    with open(images_path, 'rb') as fh:
        _, n_images, rows, cols = struct.unpack(">IIII", fh.read(16))
        pixels = np.fromfile(fh, dtype=np.uint8).reshape(n_images, rows, cols)
        imgs = pixels.astype(np.float32) / 255.0
        if flatten:
            imgs = imgs.reshape([n_images, -1])

    return imgs, one_hot
Example 2
Project: me-ica   Author: ME-ICA   File: io.py    License: GNU Lesser General Public License v2.1 7 votes vote down vote up
def _fread3_many(fobj, n):
    """Read 3-byte ints from an open binary file object.

    Parameters
    ----------
    fobj : file
        File descriptor
    n : int
        Number of 3-byte integers to read.

    Returns
    -------
    out : 1D array
        An array of 3 byte int, each assembled from big-endian bytes
        (most significant byte first).
    """
    # ``np.int`` was removed in NumPy 1.24. A fixed 64-bit integer type also
    # keeps the shifts below independent of the platform's default int size.
    b1, b2, b3 = np.fromfile(fobj, ">u1", 3 * n).reshape(-1,
                                                    3).astype(np.int64).T
    return (b1 << 16) + (b2 << 8) + b3
Example 3
Project: ArtGAN   Author: cs-chan   File: ingest_stl10.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def collectdata(self,):
        """Register every STL-10 unlabeled image in ``self.trainpairlist``.

        Reads ``unlabeled_X.bin`` from ``self.input_dir`` as raw uint8 data,
        reshapes it to ``(-1, 3, 96, 96)`` and transposes with axes
        ``(0, 3, 2, 1)``. For each image index ``idx``:

        - unless ``self.skipimg`` is set, the image is written to
          ``<self.unlabeldir>/<idx>.jpg``;
        - the relative path ``images/unlabeled/<idx>.jpg`` is always mapped to
          ``'labels/11.txt'``.
        """
        # print() with a single argument behaves the same on Python 2 and 3.
        print('Start Collect Data...')

        train_x_path = os.path.join(self.input_dir, 'unlabeled_X.bin')

        # The context manager closes the handle; the previous version leaked it.
        with open(train_x_path, 'rb') as train_xf:
            train_x = np.fromfile(train_xf, dtype=np.uint8)
        train_x = np.reshape(train_x, (-1, 3, 96, 96))
        train_x = np.transpose(train_x, (0, 3, 2, 1))

        # range() works on both Python 2 and 3 (xrange is Python-2 only).
        # The loop index doubles as the image id.
        for idx in range(train_x.shape[0]):
            if not self.skipimg:
                transform_and_save(img_arr=train_x[idx], output_filename=os.path.join(self.unlabeldir, str(idx) + '.jpg'))
            self.trainpairlist[os.path.join('images', 'unlabeled', str(idx) + '.jpg')] = 'labels/11.txt'

        print('Finished Collect Data...')
Example 4
Project: pointnet-registration-framework   Author: vinits5   File: plyfile.py    License: MIT License 6 votes vote down vote up
def _read(self, stream, text, byte_order):
        '''
        Read the actual data from a PLY file.

        ASCII input is handled by ``self._read_txt``. Binary input goes
        through ``self._read_bin`` when the element has list properties;
        otherwise all ``self.count`` records are loaded with one ``fromfile``
        call. If fewer than ``self.count`` records were read, ``self._data``
        is deleted and ``PlyParseError`` is raised; otherwise
        ``self._check_sanity()`` runs.
        '''
        if not text and not self._have_list:
            # Fixed-size records only: a single bulk read is enough.
            self._data = _np.fromfile(stream,
                                      self.dtype(byte_order),
                                      self.count)
        elif text:
            self._read_txt(stream)
        else:
            # Variable-length list properties must be parsed record by record.
            self._read_bin(stream, byte_order)

        n_read = len(self._data)
        if n_read < self.count:
            del self._data
            raise PlyParseError("early end-of-file", self, n_read)

        self._check_sanity()
Example 5
Project: baseband   Author: mhvk   File: test_mark4.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_payload_getitem_setitem(self, item):
        """Check reading and writing Mark4Payload samples through ``item``.

        NOTE(review): ``item`` is presumably supplied by a pytest
        parametrization that is not visible here.
        """
        # Read the frame header at offset 0xa88, then its payload.
        with open(SAMPLE_FILE, 'rb') as fh:
            fh.seek(0xa88)
            header = mark4.Mark4Header.fromfile(fh, ntrack=64, decade=2010)
            payload = mark4.Mark4Payload.fromfile(fh, header)
        sel_data = payload.data[item]
        # Indexing the payload must agree with indexing its decoded data.
        assert np.all(payload[item] == sel_data)
        # Build an independent payload from a copy of the raw words.
        payload2 = mark4.Mark4Payload(payload.words.copy(), header)
        assert payload2 == payload
        # Write negated samples into payload2. ``check`` holds the expected
        # full decode. This assumes payload.data returns a fresh array, so
        # editing ``check`` leaves ``payload`` untouched; the final equality
        # assertion relies on that.
        payload2[item] = -sel_data
        check = payload.data
        check[item] = -sel_data
        assert np.all(payload2[item] == -sel_data)
        assert np.all(payload2.data == check)
        assert payload2 != payload
        # Writing the original samples back must restore equality.
        payload2[item] = sel_data
        assert np.all(payload2[item] == sel_data)
        assert payload2 == payload
Example 6
Project: baseband   Author: mhvk   File: test_mark4.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_binary_file_reader(self):
        """Exercise Mark4FileReader frame location, header reading and repr."""
        with mark4.open(SAMPLE_FILE, 'rb', decade=2010, ntrack=64) as fh:
            # Expect two frames: one at 0xa88 and one 64*2500 bytes later
            # (presumably one frame length for 64 tracks -- confirm).
            locations = fh.locate_frames()
            assert locations == [0xa88, 0xa88+64*2500]
            # Parse the header directly from the file object...
            fh.seek(0xa88)
            header = mark4.Mark4Header.fromfile(fh, decade=2010, ntrack=64)
            # ...and check the reader's own read_header() gives the same result.
            fh.seek(0xa88)
            header2 = fh.read_header()
            current_pos = fh.tell()
            assert header2 == header
            # NOTE(review): 32 MHz is presumably this file's sample rate, so the
            # frame rate is 32 MHz / samples_per_frame.
            frame_rate = fh.get_frame_rate()
            assert abs(frame_rate
                       - 32 * u.MHz / header.samples_per_frame) < 1 * u.nHz
            # get_frame_rate() must leave the file position unchanged.
            assert fh.tell() == current_pos
            repr_fh = repr(fh)

        assert repr_fh.startswith('Mark4FileReader')
        assert 'ntrack=64, decade=2010, ref_time=None' in repr_fh
Example 7
Project: baseband   Author: mhvk   File: test_mark4.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_header_times(self):
        """Successive frame header times advance by exactly one frame duration."""
        with mark4.open(SAMPLE_FILE, 'rb', decade=2010, ntrack=64) as fh:
            fh.seek(0xa88)
            header0 = mark4.Mark4Header.fromfile(fh, ntrack=64, decade=2010)
            start_time = header0.time
            # Use frame size, since header adds to payload.
            # NOTE(review): the divisors presumably encode 2 bits per sample and
            # 8 channels -- confirm against the sample file's setup.
            samples_per_frame = header0.frame_nbytes * 8 // 2 // 8
            frame_rate = 32. * u.MHz / samples_per_frame
            frame_duration = 1. / frame_rate
            # Rewind to the first header and walk up to 100 frames, stopping at EOF.
            fh.seek(0xa88)
            for frame_nr in range(100):
                try:
                    frame = fh.read_frame()
                except EOFError:
                    break
                header_time = frame.header.time
                expected = start_time + frame_nr * frame_duration
                # Each header time must match start + n * duration to within 1 ns.
                assert abs(header_time - expected) < 1. * u.ns
Example 8
Project: baseband   Author: mhvk   File: test_mark4.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_header(self):
        """A 32-track header read from file equals one built with fromvalues."""
        with open(SAMPLE_32TRACK, 'rb') as fh:
            fh.seek(9656)
            header = mark4.Mark4Header.fromfile(fh, ntrack=32, decade=2010)

        # Rebuild the header from properties instead of raw keywords:
        # * time implies decade, bcd_unit_year, bcd_day, bcd_hour,
        #   bcd_minute, bcd_second and bcd_fraction;
        # * ntrack, samples_per_frame and bps define headstack_id,
        #   bcd_track_id, fan_out and magnitude_bit;
        # * nsb defines lsb_output and converter_id.
        values = dict(ntrack=32, samples_per_frame=80000, bps=2, nsb=2,
                      time=header.time, system_id=108)
        rebuilt = mark4.Mark4Header.fromvalues(**values)
        assert rebuilt == header
Example 9
Project: baseband   Author: mhvk   File: test_mark4.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_header(self):
        """A fan-out-2 32-track header equals one rebuilt with fromvalues."""
        with open(SAMPLE_32TRACK_FANOUT2, 'rb') as fh:
            fh.seek(17436)
            header = mark4.Mark4Header.fromfile(fh, ntrack=32, decade=2010)

        # Rebuild the header from properties instead of raw keywords:
        # * time implies decade, bcd_unit_year, bcd_day, bcd_hour,
        #   bcd_minute, bcd_second and bcd_fraction;
        # * ntrack, samples_per_frame and bps define headstack_id,
        #   bcd_track_id, fan_out and magnitude_bit;
        # * converters are copied from the file header because lsb_output
        #   and converter_id are somewhat non-standard here.
        values = dict(ntrack=32, samples_per_frame=40000, bps=2,
                      time=header.time, system_id=108,
                      converters=header.converters)
        rebuilt = mark4.Mark4Header.fromvalues(**values)
        assert rebuilt == header
Example 10
Project: baseband   Author: mhvk   File: test_mark4.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_header(self):
        """A 64-track header read from file equals one built with fromvalues."""
        with open(SAMPLE_64TRACK_FT, 'rb') as fh:
            fh.seek(124288)
            header = mark4.Mark4Header.fromfile(fh, ntrack=64, decade=2010)

        # Rebuild the header from properties instead of raw keywords:
        # * time implies decade, bcd_unit_year, bcd_day, bcd_hour,
        #   bcd_minute, bcd_second and bcd_fraction;
        # * ntrack and samples_per_frame define headstack_id, bcd_track_id
        #   and fan_out.
        values = dict(ntrack=64, samples_per_frame=40000,
                      time=header.time, system_id=114)
        # lsb_output, converter_id and magnitude_bit are unusual in this file,
        # so they are copied verbatim from the header read above.
        for field in ('lsb_output', 'converter_id', 'magnitude_bit'):
            values[field] = header[field]
        rebuilt = mark4.Mark4Header.fromvalues(**values)
        assert rebuilt == header
Example 11
Project: yatsm   Author: ceholden   File: stack_line_readers.py    License: MIT License 6 votes vote down vote up
def _read_row(self, row):
        """Read the same block of values from every file in ``self.files``.

        For each open file, seeks to element ``row * size[0] * size[1]``
        (counted in units of ``self.datatype``), reads ``size[0] * size[1]``
        values, reshapes them to ``self.size`` and transposes them.

        Returns:
            Array of shape ``(size[1], n_image, size[0])`` and dtype
            ``self.datatype``; slice ``[:, i, :]`` comes from ``self.files[i]``.
        """
        data = np.empty((self.size[1], self.n_image, self.size[0]),
                        self.datatype)

        # Element count and byte offset are the same for every file; compute
        # them once instead of once per iteration.
        count = self.size[0] * self.size[1]
        offset = np.dtype(self.datatype).itemsize * row * count

        for i, fid in enumerate(self.files):
            # An absolute seek gives the same position as the former relative
            # seek by ``offset - fid.tell()``, without the extra tell() call.
            fid.seek(offset)
            data[:, i, :] = np.fromfile(fid,
                                        dtype=self.datatype,
                                        count=count,
                                        ).reshape(self.size).T

        return data
Example 12
Project: CapsLayer   Author: naturomics   File: writer.py    License: Apache License 2.0 6 votes vote down vote up
def load_fashion_mnist(path, split):
    """Load Fashion-MNIST images and labels for one split.

    Args:
        path: directory containing the IDX files listed in
            ``MNIST_FILES[split]`` (image file first, label file second).
        split: split name, case-insensitive. ``"train"`` keeps the first
            55000 samples, ``"eval"`` keeps the rest, and any other split
            keeps every sample.

    Returns:
        A ``zip`` of ``(image, label)`` pairs. Each image is a float32
        vector of 784 pixels and each label an int32.
    """
    split = split.lower()
    image_file, label_file = [os.path.join(path, file_name) for file_name in MNIST_FILES[split]]

    def _select(arr):
        # The training files are divided into 55000 train / remaining eval samples.
        if split == "train":
            return arr[:55000]
        if split == "eval":
            return arr[55000:]
        return arr

    # IDX files are binary. Open them in 'rb' so no newline/EOF translation
    # can corrupt the data (text mode does this on Windows under Python 2).
    with open(image_file, 'rb') as fd:
        images = np.fromfile(file=fd, dtype=np.uint8)
    # Skip the 16-byte image header; each image is 28*28 = 784 pixels.
    images = _select(images[16:].reshape(-1, 784).astype(np.float32))

    with open(label_file, 'rb') as fd:
        labels = np.fromfile(file=fd, dtype=np.uint8)
    # Skip the 8-byte label header.
    labels = _select(labels[8:].astype(np.int32))

    return zip(images, labels)
Example 13
Project: CapsLayer   Author: naturomics   File: writer.py    License: Apache License 2.0 6 votes vote down vote up
def load_mnist(path, split):
    """Load MNIST images and labels for one split.

    Args:
        path: directory containing the IDX files listed in
            ``MNIST_FILES[split]`` (image file first, label file second).
        split: split name, case-insensitive. ``"train"`` keeps the first
            55000 samples, ``"eval"`` keeps the rest, and any other split
            keeps every sample.

    Returns:
        A ``zip`` of ``(image, label)`` pairs. Each image is a float32
        vector of 784 pixels and each label an int32.
    """
    split = split.lower()
    image_file, label_file = [os.path.join(path, file_name) for file_name in MNIST_FILES[split]]

    def _select(arr):
        # The training files are divided into 55000 train / remaining eval samples.
        if split == "train":
            return arr[:55000]
        if split == "eval":
            return arr[55000:]
        return arr

    # IDX files are binary. Open them in 'rb' so no newline/EOF translation
    # can corrupt the data (text mode does this on Windows under Python 2).
    with open(image_file, 'rb') as fd:
        images = np.fromfile(file=fd, dtype=np.uint8)
    # Skip the 16-byte image header; each image is 28*28 = 784 pixels.
    images = _select(images[16:].reshape(-1, 784).astype(np.float32))

    with open(label_file, 'rb') as fd:
        labels = np.fromfile(file=fd, dtype=np.uint8)
    # Skip the 8-byte label header.
    labels = _select(labels[8:].astype(np.int32))

    return zip(images, labels)
Example 14
Project: ibllib   Author: int-brain-lab   File: certification_protocol.py    License: MIT License 6 votes vote down vote up
def load_rf_mapping_stimulus(session_path, stim_metadata):
    """
    extract frames of rf mapping stimulus

    The raw uint8 stimulus file is found by globbing
    ``<session_path>/raw_behavior_data/<stim_data_file_name>``, which
    defaults to ``*RFMapStim.raw*``. The first match is used.

    :param session_path: absolute path of a session, i.e. /mnt/data/Subjects/ZM_1887/2019-07-10/001
    :type session_path: str
    :param stim_metadata: dictionary of stimulus/task metadata
    :type stim_metadata: dict
    :return: stimulus frames, or an empty array when the session has no
        receptive-field-mapping stimulus
    :rtype: np.ndarray of shape (n_frames, x_pix, y_pix); the file is read in
        Fortran order as (y_pix, x_pix, n_frames) and then transposed
    :raises IndexError: if no file matches the stimulus file pattern
    """
    idx_rfm = get_stim_num_from_name(stim_metadata['VISUAL_STIMULI'], 'receptive_field_mapping')
    if idx_rfm is None:
        return np.array([])

    stim_info = stim_metadata['VISUAL_STIM_%i' % idx_rfm]
    stim_filename = stim_info.get('stim_data_file_name', '*RFMapStim.raw*')
    pattern = os.path.join(session_path, 'raw_behavior_data', stim_filename)
    stim_file = glob.glob(pattern)[0]
    frame_array = np.fromfile(stim_file, dtype='uint8')
    y_pix, x_pix, _ = stim_info['stim_file_shape']
    stacked = np.reshape(frame_array, [y_pix, x_pix, -1], order='F')
    # Axis order [2, 1, 0]: (y, x, frame) -> (frame, x, y).
    return np.transpose(stacked, [2, 1, 0])
Example 15
Project: typhon   Author: atmtools   File: topography.py    License: MIT License 6 votes vote down vote up
def get_tile(name):
        """
        Return the elevation tile called *name*, downloading it if needed.

        The tile is cached as ``<NAME>.DEM`` (upper-cased) in the data
        directory. If that file is missing, ``SRTM30.download_tile(name)`` is
        called first.

        Args:
            name(str): The name of the tile.

        Returns:
            Big-endian int16 array of shape
            ``(SRTM30._tile_height, SRTM30._tile_width)``.
        """
        cache_dir = _get_data_path()
        dem_file = os.path.join(cache_dir, (name + ".dem").upper())
        if not os.path.exists(dem_file):
            SRTM30.download_tile(name)
        tile = np.fromfile(dem_file, dtype=np.dtype('>i2'))
        return tile.reshape(SRTM30._tile_height, SRTM30._tile_width)
Example 16
Project: typhon   Author: atmtools   File: catalogues.py    License: MIT License 6 votes vote down vote up
def from_xml(cls, xmlelement):
        """Loads a Sparse object from an existing file.

        ``nelem`` row indices, column indices and values are read either
        from ``xmlelement.binaryfp`` (little-endian int32, int32, float64, in
        that order) or, when no binary stream is attached, from the
        whitespace-separated text of the element's first three children.
        The result is ``cls((values, (rows, cols)), [nrows, ncols])``.
        """
        binaryfp = xmlelement.binaryfp
        nelem = int(xmlelement[0].attrib['nelem'])
        nrows = int(xmlelement.attrib['nrows'])
        ncols = int(xmlelement.attrib['ncols'])

        if binaryfp is not None:
            # Binary layout: all row indices, then all column indices, then values.
            rowindex, colindex, sparsedata = (
                np.fromfile(binaryfp, dtype=code, count=nelem)
                for code in ('<i4', '<i4', '<d'))
        else:
            rowindex, colindex = (
                np.fromstring(xmlelement[child].text, sep=' ').astype(int)
                for child in (0, 1))
            sparsedata = np.fromstring(xmlelement[2].text, sep=' ')

        return cls((sparsedata, (rowindex, colindex)), [nrows, ncols])
Example 17
Project: typhon   Author: atmtools   File: read.py    License: MIT License 6 votes vote down vote up
def Vector(elem):
        """Parse an ARTS ``Vector`` element into a 1-D float array.

        When ``elem.binaryfp`` is set, values are read from it as little-endian
        float64. Otherwise they are parsed from ``elem.text``.

        Raises:
            RuntimeError: if the number of values read differs from the
                element's ``nelem`` attribute.
        """
        nelem = int(elem.attrib['nelem'])
        if nelem == 0:
            return np.ndarray((0,))

        if elem.binaryfp is None:
            # sep=' ' seems to work even when separated by newlines, see
            # http://stackoverflow.com/q/31882167/974555
            arr = np.fromstring(elem.text, sep=' ')
        else:
            arr = np.fromfile(elem.binaryfp, dtype='<d', count=nelem)

        if arr.size == nelem:
            return arr
        raise RuntimeError(
            'Expected {:s} elements in Vector, found {:d}'
            ' elements!'.format(elem.attrib['nelem'],
                                arr.size))
Example 18
Project: typhon   Author: atmtools   File: read.py    License: MIT License 6 votes vote down vote up
def ComplexMatrix(elem):
        """Parse an ARTS ``ComplexMatrix`` element into a complex128 array.

        The shape comes from the names in ``dimension_names`` that appear as
        attributes of *elem*, taken in reversed order. Data is read from
        ``elem.binaryfp`` as complex128 values when set. Otherwise it is read
        from ``elem.text`` as whitespace-separated floats, where consecutive
        pairs form one complex value.
        """
        # turn dims around: in ARTS, [10 x 1 x 1] means 10 pages, 1 row, 1 col
        dimnames = [dim for dim in dimension_names
                    if dim in elem.attrib.keys()][::-1]
        dims = [int(elem.attrib[dim]) for dim in dimnames]
        # int() guarantees an integer count: np.prod of an empty dims list is
        # the float 1.0, which np.fromfile rejects as a count.
        nvalues = int(np.prod(dims))
        if nvalues == 0:
            flatarr = np.ndarray(dims, dtype=np.complex128)
        elif elem.binaryfp is not None:
            flatarr = np.fromfile(elem.binaryfp, dtype=np.complex128,
                                  count=nvalues)
            flatarr = flatarr.reshape(dims)
        else:
            flatarr = np.fromstring(elem.text, sep=' ', dtype=np.float64)
            # Reinterpret each float64 pair as one complex128 through a view.
            # NumPy documents assigning ``arr.dtype`` as discouraged.
            flatarr = flatarr.view(np.complex128).reshape(dims)
        return flatarr
Example 19
Project: Pointnet_Pointnet2_pytorch   Author: yanx27   File: plyfile.py    License: MIT License 6 votes vote down vote up
def _read(self, stream, text, byte_order):
        '''
        Read the actual data from a PLY file.

        ASCII input is handled by ``self._read_txt``. Binary input goes
        through ``self._read_bin`` when the element has list properties;
        otherwise all ``self.count`` records are loaded with one ``fromfile``
        call. If fewer than ``self.count`` records were read, ``self._data``
        is deleted and ``PlyParseError`` is raised; otherwise
        ``self._check_sanity()`` runs.
        '''
        if not text and not self._have_list:
            # Fixed-size records only: a single bulk read is enough.
            self._data = _np.fromfile(stream,
                                      self.dtype(byte_order),
                                      self.count)
        elif text:
            self._read_txt(stream)
        else:
            # Variable-length list properties must be parsed record by record.
            self._read_bin(stream, byte_order)

        n_read = len(self._data)
        if n_read < self.count:
            del self._data
            raise PlyParseError("early end-of-file", self, n_read)

        self._check_sanity()
Example 20
Project: DenseMatchingBenchmark   Author: DeepMotionAIResearch   File: load_flow.py    License: MIT License 6 votes vote down vote up
def load_flo(file_path):
    """
    Read .flo file in MiddleBury format
    Code adapted from:
    http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    Layout: a little-endian float32 magic number (202021.25), int32 width,
    int32 height, then ``2 * width * height`` float32 flow values. The
    explicit little-endian dtypes make this reader work on big-endian hosts
    as well.

    Args:
        file_path string: file path(absolute)
    Returns:
        flow (numpy.array): data of image in (Height, Width, 2) layout
    Raises:
        ValueError: if the magic number is wrong or the file is truncated.
    """
    with open(file_path, 'rb') as f:
        magic = np.fromfile(f, '<f4', count=1)
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if magic.size != 1 or magic[0] != 202021.25:
            raise ValueError('Invalid .flo file (bad magic number): %s' % file_path)
        header = np.fromfile(f, '<i4', count=2)
        if header.size != 2:
            raise ValueError('Truncated .flo header: %s' % file_path)
        # Convert element-wise: int() on a 1-element array is deprecated
        # since NumPy 1.25.
        w, h = (int(v) for v in header)
        flow = np.fromfile(f, '<f4', count=2 * w * h)

    if flow.size != 2 * w * h:
        # np.resize would silently repeat data to fill the shape; fail instead.
        raise ValueError('Truncated .flo data in %s: expected %d values, got %d'
                         % (file_path, 2 * w * h, flow.size))
    # Reshape data into 3D array (rows, columns, bands)
    return flow.reshape(h, w, 2)
Example 21
Project: PolarSeg   Author: edwardzhou130   File: dataset.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __getitem__(self, index):
        """Return the point cloud (and labels) for sample *index*.

        The scan at ``self.im_idx[index]`` is read as float32 rows of 4
        values. For the ``'test'`` image set, labels are all zero. Otherwise
        they are read as int32 from the matching ``labels/*.label`` file,
        reduced to their low 16 bits and mapped through
        ``self.learning_map``.

        Returns:
            ``(xyz, labels)`` with ``labels`` as uint8 of shape ``(N, 1)``,
            plus the scan's 4th column when ``self.return_ref`` is set.
        """
        scan_path = self.im_idx[index]
        points = np.fromfile(scan_path, dtype=np.float32).reshape((-1, 4))
        if self.imageset != 'test':
            label_path = scan_path.replace('velodyne', 'labels')[:-3] + 'label'
            labels = np.fromfile(label_path, dtype=np.int32).reshape((-1, 1))
            labels = labels & 0xFFFF  # keep only the low 16 bits
            labels = np.vectorize(self.learning_map.__getitem__)(labels)
        else:
            labels = np.expand_dims(np.zeros_like(points[:, 0], dtype=int), axis=1)

        sample = (points[:, :3], labels.astype(np.uint8))
        if self.return_ref:
            sample = sample + (points[:, 3],)
        return sample
Example 22
Project: Generative-Latent-Optimization-Tensorflow   Author: clvrai   File: download.py    License: MIT License 5 votes vote down vote up
def download_mnist(download_path):
    """Download the MNIST image files into ``<download_path>/mnist`` and convert them.

    Does nothing beyond printing a message when ``check_file(data_dir)``
    reports the data as present. Otherwise it:

    1. fetches and gunzips the train/test image files with ``curl``/``gzip``;
    2. decodes them into float64 arrays of shape ``(N, 28, 28, 1)``;
    3. passes both arrays to ``prepare_h5py``;
    4. deletes the decompressed raw files.
    """
    import os  # local import: only needed for os.remove below

    data_dir = osp.join(download_path, 'mnist')

    if check_file(data_dir):
        print('MNIST was downloaded.')
        return

    data_url = 'http://yann.lecun.com/exdb/mnist/'
    keys = ['train-images-idx3-ubyte.gz', 't10k-images-idx3-ubyte.gz']

    for k in keys:
        # The URL has no placeholders, so the former ``.format(**locals())``
        # was a no-op.
        url = data_url + k
        target_path = osp.join(data_dir, k)
        print('Downloading ', k)
        subprocess.call(['curl', url, '-o', target_path])
        print('Unzip ', k)
        subprocess.call(['gzip', '-d', target_path])

    def _load_images(file_name, count):
        # Binary mode plus a context manager: the previous version opened
        # these binary files in text mode and never closed them.
        with open(osp.join(data_dir, file_name), 'rb') as fd:
            loaded = np.fromfile(file=fd, dtype=np.uint8)
        # Skip the 16-byte IDX header. ``np.float`` (removed in NumPy 1.24)
        # was an alias of the builtin float, i.e. float64.
        return loaded[16:].reshape((count, 28, 28, 1)).astype(np.float64)

    num_mnist_train = 60000
    num_mnist_test = 10000

    train_image = _load_images('train-images-idx3-ubyte', num_mnist_train)
    test_image = _load_images('t10k-images-idx3-ubyte', num_mnist_test)

    prepare_h5py(train_image, test_image, data_dir)

    # Remove the decompressed raw files. Like the former ``rm -f``, failures
    # are ignored.
    for k in keys:
        try:
            os.remove(osp.join(data_dir, k[:-3]))
        except OSError:
            pass
Example 23
Project: PSMNet   Author: JiaRenChang   File: readpfm.py    License: MIT License 5 votes vote down vote up
def readPFM(file):
    """Read a PFM (portable float map) image.

    Args:
        file: path of the ``.pfm`` file.

    Returns:
        ``(data, scale)``. ``data`` is a float32 array, flipped vertically,
        of shape ``(height, width, 3)`` for colour (``PF``) files or
        ``(height, width)`` for greyscale (``Pf``) files. ``scale`` is the
        absolute value of the header's scale factor.

    Raises:
        Exception: if the header is not a valid PFM header.
    """
    # Header lines are decoded explicitly: on Python 3, readline() on a
    # binary file returns bytes, and comparing those with 'PF' is always
    # False. The context manager also closes the file, which the previous
    # version leaked.
    with open(file, 'rb') as fh:
        header = fh.readline().rstrip().decode('ascii', 'replace')
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', fh.readline().decode('ascii', 'replace'))
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        scale = float(fh.readline().rstrip().decode('ascii'))
        if scale < 0:  # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian

        data = np.fromfile(fh, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # PFM scanlines are stored bottom-to-top.
    data = np.flipud(data)
    return data, scale
Example 24
Project: PSMNet   Author: JiaRenChang   File: readpfm.py    License: MIT License 5 votes vote down vote up
def readPFM(file):
    """Read a PFM (portable float map) image.

    Args:
        file: path of the ``.pfm`` file.

    Returns:
        ``(data, scale)``. ``data`` is a float32 array, flipped vertically,
        of shape ``(height, width, 3)`` for colour (``PF``) files or
        ``(height, width)`` for greyscale (``Pf``) files. ``scale`` is the
        absolute value of the header's scale factor.

    Raises:
        Exception: if the header is not a valid PFM header.
    """
    # The PFM header is plain ASCII, so decode it directly instead of
    # guessing the encoding with chardet. 'replace' makes non-ASCII junk fall
    # through to the "Not a PFM file." error rather than a decode error.
    # The context manager also closes the file, which the previous version
    # leaked.
    with open(file, 'rb') as fh:
        header = fh.readline().rstrip().decode('ascii', 'replace')
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', fh.readline().decode('ascii', 'replace'))
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        scale = float(fh.readline().rstrip().decode('ascii'))
        if scale < 0:  # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian

        data = np.fromfile(fh, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # PFM scanlines are stored bottom-to-top.
    data = np.flipud(data)
    return data, scale
Example 25
Project: pymesh   Author: taxpon   File: stl.py    License: MIT License 5 votes vote down vote up
def __load_binary(fh):
        """Read the triangle records of a binary STL body from *fh*.

        *fh* must be positioned at the triangle-count field. Returns up to
        ``count`` records of ``Stl.stl_dtype``.

        Raises:
            ValueError: if the declared triangle count is not below
                ``Stl.MAX_COUNT``.
        """
        # The STL triangle count is a little-endian unsigned 32-bit integer.
        # "<I" fixes byte order and signedness. With the previous native
        # signed "i", a corrupt negative count passed the size check and made
        # fromfile read the whole remaining file.
        count, = struct.unpack("<I", fh.read(Stl.COUNT_SIZE))
        # Explicit check instead of ``assert``, which ``python -O`` strips.
        # Without it a hostile count could request a very large read.
        if count >= Stl.MAX_COUNT:
            raise ValueError(
                'File too large, got {} triangles which exceeds the maximum of {}' .format(
                    count, Stl.MAX_COUNT
                ))
        return numpy.fromfile(fh, Stl.stl_dtype, count=count)
Example 26
Project: pointnet-registration-framework   Author: vinits5   File: plyfile.py    License: MIT License 5 votes vote down vote up
def _read_bin(self, stream, byte_order):
        '''
        Read a single value of this property from a binary stream.

        Raises StopIteration when the stream holds no further value.
        '''
        try:
            # fromfile returns an empty array at EOF, so [0] raises IndexError.
            first = _np.fromfile(stream, self.dtype(byte_order), 1)[0]
        except IndexError:
            raise StopIteration
        return first
Example 27
Project: pointnet-registration-framework   Author: vinits5   File: plyfile.py    License: MIT License 5 votes vote down vote up
def _read_bin(self, stream, byte_order):
        """Read one list-property value from a binary stream.

        The value is a length prefix followed by that many items. Raises
        StopIteration if the prefix is missing or fewer items than announced
        are available.
        """
        len_t, val_t = self.list_dtype(byte_order)

        try:
            n = _np.fromfile(stream, len_t, 1)[0]
        except IndexError:
            # No length prefix left in the stream.
            raise StopIteration

        items = _np.fromfile(stream, val_t, n)
        if len(items) >= n:
            return items
        # The stream ended before all announced items could be read.
        raise StopIteration
Example 28
Project: pyscf   Author: pyscf   File: m_openmx_mat.py    License: Apache License 2.0 5 votes vote down vote up
def fromfile(self, f, out=None, dtype=float):
    """ Read from an open file f

    Fills ``res[ct_AN, h_AN, i, 0:c]`` for every ct_AN in ``1..natoms``,
    every h_AN in ``0..FNAN[ct_AN]`` and every i in
    ``range(Total_NumOrbs[ct_AN])``. Each row holds ``c`` consecutive
    float64 values read from *f*, where
    ``c = Total_NumOrbs[natn[ct_AN, h_AN]]``.

    Args:
      f: open binary file positioned at the first value.
      out: optional preallocated array to fill. When omitted, a zeroed array
        of shape ``(natoms+1, FNAN_mx+1, Total_NumOrbs_mx, Total_NumOrbs_mx)``
        is allocated.
      dtype: dtype of the newly allocated array (unused when ``out`` is
        given). Values are always read from the file as float64. The default
        ``float`` is exactly the old ``np.float`` alias; ``np.float`` was
        removed in NumPy 1.24 and broke this definition at import time.

    Returns:
      The filled array.
    """
    if out is None:
      res = np.zeros((self.natoms+1, self.FNAN_mx+1, self.Total_NumOrbs_mx, self.Total_NumOrbs_mx), dtype=dtype)
    else :
      res = out

    for ct_AN in range(1,self.natoms+1):
      for h_AN in range(0,self.FNAN[ct_AN]+1):
        for i in range(self.Total_NumOrbs[ct_AN]):
          c = self.Total_NumOrbs[self.natn[ct_AN,h_AN]]
          res[ct_AN,h_AN,i,0:c] = np.fromfile(f, count=c)

    return res
Example 29
Project: MobileNetv2-SSDLite   Author: PINTO0309   File: load_caffe_weights.py    License: MIT License 5 votes vote down vote up
def load_data(net):
    """Copy raw float32 weight dumps from ``output/*.dat`` into a Caffe net.

    For every entry of ``net.params`` whose value is a ``BlobVec``:

    - Keys without "mbox" that start with "conv", "Conv" or "layer" get
      batch-norm statistics (``/bn``), gamma/beta (``/scale``), or weights
      and optional biases, depending on the key suffix.
    - Keys ending in ``mbox_loc`` / ``mbox_conf`` get the
      ``BoxPredictor_<i>`` weights and biases, where ``i`` is the layer's
      index in ``box_layers``.
    - Any other key is reported as an error.
    """
    def _read(prefix, suffix, shape=None):
        # Raw float32 dump; reshaped only where the target blob needs it.
        values = np.fromfile(prefix + suffix, dtype=np.float32)
        return values if shape is None else values.reshape(shape)

    def _load_box_predictor(blobs, layer, predictor):
        # Box predictors are numbered by the layer's position in box_layers.
        prefix = 'output/BoxPredictor_' + str(box_layers.index(layer)) + '_' + predictor
        blobs[0].data[...] = _read(prefix, '_weights.dat', blobs[0].data.shape)
        blobs[1].data[...] = _read(prefix, '_biases.dat')

    # Iterating the mapping directly (instead of the Python-2-only
    # iterkeys()) and using print() keep this valid on Python 2 and 3.
    for key in net.params:
        blobs = net.params[key]
        if type(blobs) is not caffe._caffe.BlobVec:
            continue
        print(key)
        if key.find('mbox') == -1 and key.startswith(("conv", "Conv", "layer")):
            print('conv')
            if key.endswith("/bn"):
                prefix = 'output/' + key.replace('/', '_')
                blobs[0].data[...] = _read(prefix, '_moving_mean.dat')
                blobs[1].data[...] = _read(prefix, '_moving_variance.dat')
                blobs[2].data[...] = np.ones(blobs[2].data.shape, dtype=np.float32)
            elif key.endswith("/scale"):
                prefix = 'output/' + key.replace('scale', 'bn').replace('/', '_')
                blobs[0].data[...] = _read(prefix, '_gamma.dat')
                blobs[1].data[...] = _read(prefix, '_beta.dat')
            else:
                prefix = 'output/' + key.replace('/', '_')
                blobs[0].data[...] = _read(prefix, '_weights.dat', blobs[0].data.shape)
                if len(blobs) > 1:
                    blobs[1].data[...] = _read(prefix, '_biases.dat')
        elif key.endswith("mbox_loc"):
            _load_box_predictor(blobs, key.replace("_mbox_loc", ""), 'BoxEncodingPredictor')
        elif key.endswith("mbox_conf"):
            _load_box_predictor(blobs, key.replace("_mbox_conf", ""), 'ClassPredictor')
        else:
            print("error key " + key)
Example 30
Project: MobileNetv2-SSDLite   Author: PINTO0309   File: load_caffe_weights.py    License: MIT License 5 votes vote down vote up
def load_weights(path, shape=None):
    """Read a raw float32 weight file, optionally reshape it, then delete it.

    Args:
        path: file holding the raw float32 values.
        shape: optional target shape; with ``None`` the flat array is returned.

    Returns:
        The loaded float32 array.

    Note:
        The file at *path* is removed once it has been read (and reshaped).
    """
    weights = np.fromfile(path, dtype=np.float32)
    if shape is not None:
        weights = weights.reshape(shape)
    os.unlink(path)
    return weights