Python provider.rotate_point_cloud() Examples

The following are 14 code examples of provider.rotate_point_cloud(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module provider, or try the search function.
Example #1
Source File: ply_dataset.py    From chainer-pointnet with MIT License 5 votes vote down vote up
def get_example(self, i):
        """Return the i-th point cloud together with its label.

        When ``self.augment`` is truthy the sample is sliced out as a batch of
        one and passed through ``provider.rotate_point_cloud`` followed by
        ``provider.jitter_point_cloud``; otherwise the stored sample is used
        unchanged.  The cloud is cast to float32, its two axes are swapped and
        a trailing unit axis is appended, e.g. (num_point, k) becomes
        (k, num_point, 1).

        Raises AssertionError if the label dtype is not int32.
        """
        if not self.augment:
            cloud = self.data[i]
        else:
            # The provider helpers are fed a batch, so keep a leading axis of
            # length one and strip it again afterwards.
            batch = provider.rotate_point_cloud(self.data[i:i + 1, :, :])
            cloud = provider.jitter_point_cloud(batch)[0]

        # point data (num_point, k) --> (k, num_point, 1)
        cloud = cloud.astype(np.float32).transpose(1, 0)
        cloud = cloud[:, :, None]

        label = self.label[i]
        assert cloud.dtype == np.float32
        assert label.dtype == np.int32
        return cloud, label
Example #2
Source File: modelnet_h5_dataset.py    From dfc2019 with MIT License 5 votes vote down vote up
def _augment_batch_data(self, batch_data):
        """Run a batch of point clouds through the augmentation pipeline.

        The full batch is passed through ``provider.rotate_point_cloud`` and
        ``provider.rotate_perturbation_point_cloud``.  Channels 0..2 of the
        last axis are then scaled, shifted and jittered and written back into
        the rotated array, which is finally handed to
        ``provider.shuffle_points``; that call's result is returned.
        """
        cloud = provider.rotate_point_cloud(batch_data)
        cloud = provider.rotate_perturbation_point_cloud(cloud)

        # Only the first three channels take part in scale/shift/jitter
        # (presumably the xyz coordinates -- confirm against the data layout).
        xyz = provider.random_scale_point_cloud(cloud[:, :, 0:3])
        xyz = provider.shift_point_cloud(xyz)
        xyz = provider.jitter_point_cloud(xyz)

        cloud[:, :, 0:3] = xyz
        return provider.shuffle_points(cloud)
Example #3
Source File: modelnet_dataset.py    From dfc2019 with MIT License 5 votes vote down vote up
def _augment_batch_data(self, batch_data):
        """Run a batch of point clouds through the augmentation pipeline.

        Rotation and rotation perturbation use the ``*_with_normal`` provider
        variants when ``self.normal_channel`` is truthy and the plain variants
        otherwise.  Channels 0..2 of the last axis are then scaled, shifted
        and jittered and written back into the rotated array, which is finally
        handed to ``provider.shuffle_points``; that call's result is returned.
        """
        if not self.normal_channel:
            cloud = provider.rotate_point_cloud(batch_data)
            cloud = provider.rotate_perturbation_point_cloud(cloud)
        else:
            cloud = provider.rotate_point_cloud_with_normal(batch_data)
            cloud = provider.rotate_perturbation_point_cloud_with_normal(cloud)

        # Scale/shift/jitter touch only the first three channels; any further
        # channels keep the values produced by the rotation step.
        xyz = provider.random_scale_point_cloud(cloud[:, :, 0:3])
        xyz = provider.shift_point_cloud(xyz)
        xyz = provider.jitter_point_cloud(xyz)

        cloud[:, :, 0:3] = xyz
        return provider.shuffle_points(cloud)
Example #4
Source File: train_partseg.py    From scanobjectnn with MIT License 5 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch on the part-segmentation data and log statistics.

    ops: dict mapping from string to tf ops

    The training arrays come from ``data_utils.get_current_data_parts_h5``.
    Only full batches of BATCH_SIZE are used; every batch is rotated and
    jittered through ``provider`` before being fed to the graph, and the
    merged summary is written to ``train_writer``.  The mean loss and the
    per-point segmentation accuracy are reported through ``log_string``.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    current_data, current_label, current_parts = data_utils.get_current_data_parts_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_PARTS, NUM_POINT)
    current_label = np.squeeze(current_label)
    current_parts = np.squeeze(current_parts)

    num_batches = current_data.shape[0] // BATCH_SIZE

    seen = 0
    loss_total = 0
    seg_hits = 0

    for batch_idx in range(num_batches):
        lo = batch_idx * BATCH_SIZE
        hi = (batch_idx + 1) * BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        points = provider.rotate_point_cloud(current_data[lo:hi, :, :])
        points = provider.jitter_point_cloud(points)

        feed_dict = {
            ops['pointclouds_pl']: points,
            ops['labels_pl']: current_label[lo:hi],
            ops['parts_pl']: current_parts[lo:hi],
            ops['is_training_pl']: is_training,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'],
                   ops['loss'], ops['seg_pred']]
        summary, step, _, loss_val, seg_val = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)

        # Per-point predictions: argmax over the last (class) axis.
        seg_choice = np.argmax(seg_val, 2)
        seg_hits += np.sum(seg_choice == current_parts[lo:hi])

        seen += BATCH_SIZE
        loss_total += loss_val

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('seg accuracy: %f' % (seg_hits / (float(seen)*NUM_POINT)))
Example #5
Source File: train.py    From scanobjectnn with MIT License 5 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train the classifier for one epoch and log loss and accuracy.

    ops: dict mapping from string to tf ops

    The loader is ``data_utils.get_current_data_h5`` when TRAIN_FILE contains
    ".h5" and ``data_utils.get_current_data`` otherwise.  Only full batches of
    BATCH_SIZE are used; every batch is rotated and jittered through
    ``provider`` before being fed to the graph, and the merged summary is
    written to ``train_writer``.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    load = data_utils.get_current_data_h5 if ".h5" in TRAIN_FILE else data_utils.get_current_data
    current_data, current_label = load(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    current_label = np.squeeze(current_label)

    num_batches = current_data.shape[0] // BATCH_SIZE

    hits = 0
    seen = 0
    loss_total = 0

    for batch_idx in range(num_batches):
        window = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)

        # Augment batched point clouds by rotation and jittering
        points = provider.jitter_point_cloud(
            provider.rotate_point_cloud(current_data[window, :, :]))

        feed_dict = {
            ops['pointclouds_pl']: points,
            ops['labels_pl']: current_label[window],
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)

        # Predicted class = argmax over axis 1 of the network output.
        hits += np.sum(np.argmax(pred_val, 1) == current_label[window])
        seen += BATCH_SIZE
        loss_total += loss_val

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('accuracy: %f' % (hits / float(seen)))
Example #6
Source File: train_xyz.py    From SpiderCNN with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch, visiting the training files in random order.

    ops: dict mapping from string to tf ops

    Each file in TRAIN_FILES is loaded with
    ``provider.loadDataFile_with_normal`` (the normals are discarded), cut to
    the first NUM_POINT points, shuffled, and consumed in full batches of
    BATCH_SIZE.  Every batch is rotated and jittered through ``provider``
    before being fed to the graph, and the merged summary is written to
    ``train_writer``.  Mean loss and accuracy are logged once per file.

    A file that yields no full batch is logged and skipped; previously the
    per-file statistics divided by zero and aborted the epoch.
    """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn, file_idx in enumerate(train_file_idxs):
        log_string('----' + str(fn) + '-----')
        current_data, current_label, _ = provider.loadDataFile_with_normal(TRAIN_FILES[file_idx])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        if num_batches == 0:
            # Nothing to train on and nothing to average: avoid 0 / 0 below.
            log_string('skipped: %d samples is less than one batch of %d'
                       % (file_size, BATCH_SIZE))
            continue

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)

            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training}
            summary, step, _, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            # Predicted class = argmax over axis 1 of the network output.
            pred_val = np.argmax(pred_val, 1)
            total_correct += np.sum(pred_val == current_label[start_idx:end_idx])
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
Example #7
Source File: train.py    From SpiderCNN with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch on points plus normals, file by file in random order.

    ops: dict mapping from string to tf ops

    Each file in TRAIN_FILES is loaded with
    ``provider.loadDataFile_with_normal``; points and normals are cut to the
    first NUM_POINT points and the normals are reordered with the index
    returned by ``provider.shuffle_data`` so they stay aligned with the
    shuffled points.  Every full batch of BATCH_SIZE is rotated and jittered,
    concatenated with its normals along axis 2, passed through
    ``provider.random_point_dropout`` and fed to the graph.  Mean loss and
    accuracy are logged once per file.

    A file that yields no full batch is logged and skipped; previously the
    per-file statistics divided by zero and aborted the epoch.
    """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn, file_idx in enumerate(train_file_idxs):
        log_string('----' + str(fn) + '-----')
        current_data, current_label, normal_data = provider.loadDataFile_with_normal(TRAIN_FILES[file_idx])
        normal_data = normal_data[:, 0:NUM_POINT, :]
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, shuffle_idx = provider.shuffle_data(current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)
        # Apply the same permutation to the normals as was applied to the points.
        normal_data = normal_data[shuffle_idx, ...]

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        if num_batches == 0:
            # Nothing to train on and nothing to average: avoid 0 / 0 below.
            log_string('skipped: %d samples is less than one batch of %d'
                       % (file_size, BATCH_SIZE))
            continue

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            # NOTE(review): the normals are appended as loaded, without going
            # through the rotation applied to the points -- confirm intended.
            input_data = np.concatenate((jittered_data, normal_data[start_idx:end_idx, :, :]), 2)
            # random point dropout
            input_data = provider.random_point_dropout(input_data)

            feed_dict = {ops['pointclouds_pl']: input_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training}
            summary, step, _, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            # Predicted class = argmax over axis 1 of the network output.
            pred_val = np.argmax(pred_val, 1)
            total_correct += np.sum(pred_val == current_label[start_idx:end_idx])
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
Example #8
Source File: train_seg.py    From scanobjectnn with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train the joint classification/segmentation model for one epoch.

    ops: dict mapping from string to tf ops

    Data, labels and masks come from
    ``data_utils.get_current_data_withmask_h5``.  Only full batches of
    BATCH_SIZE are used; every batch is rotated and jittered through
    ``provider`` before being fed to the graph, and the merged summary is
    written to ``train_writer``.  Mean total / classification / segmentation
    losses and both accuracies are reported through ``log_string``.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    current_data, current_label, current_mask = data_utils.get_current_data_withmask_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_MASKS, NUM_POINT)
    current_label = np.squeeze(current_label)
    current_mask = np.squeeze(current_mask)

    num_batches = current_data.shape[0] // BATCH_SIZE

    cls_hits = 0
    seg_hits = 0
    seen = 0
    loss_total = 0
    cls_loss_total = 0
    seg_loss_total = 0

    for batch_idx in range(num_batches):
        window = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)

        # Augment batched point clouds by rotation and jittering
        points = provider.rotate_point_cloud(current_data[window, :, :])
        points = provider.jitter_point_cloud(points)

        feed_dict = {
            ops['pointclouds_pl']: points,
            ops['labels_pl']: current_label[window],
            ops['masks_pl']: current_mask[window],
            ops['is_training_pl']: is_training,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                   ops['pred'], ops['seg_pred'], ops['classify_loss'], ops['seg_loss']]
        (summary, step, _, loss_val, pred_val, seg_val,
         classify_loss, seg_loss) = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)

        # Class prediction: argmax over axis 1; per-point mask: argmax over axis 2.
        cls_hits += np.sum(np.argmax(pred_val, 1) == current_label[window])
        seg_hits += np.sum(np.argmax(seg_val, 2) == current_mask[window])

        seen += BATCH_SIZE
        loss_total += loss_val
        cls_loss_total += classify_loss
        seg_loss_total += seg_loss

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('classify mean loss: %f' % (cls_loss_total / float(num_batches)))
    log_string('seg mean loss: %f' % (seg_loss_total / float(num_batches)))
    log_string('accuracy: %f' % (cls_hits / float(seen)))
    log_string('seg accuracy: %f' % (seg_hits / (float(seen)*NUM_POINT)))
Example #9
Source File: train.py    From scanobjectnn with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, gmm, train_writer):
    """Train the GMM-conditioned classifier for one epoch and log statistics.

    ops: dict mapping from string to tf ops
    gmm: object whose ``weights_``, ``means_`` and ``covariances_`` attributes
         are fed to the graph (the covariances after an element-wise sqrt)

    The loader is ``data_utils.get_current_data_h5`` when TRAIN_FILE contains
    ".h5" and ``data_utils.get_current_data`` otherwise.  Each full batch of
    BATCH_SIZE goes through the augmentations enabled by the module-level
    ``augment_*`` flags, in the order scale, rotation, translation, jitter,
    outlier insertion.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    load = data_utils.get_current_data_h5 if ".h5" in TRAIN_FILE else data_utils.get_current_data
    current_data, current_label = load(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    current_label = np.squeeze(current_label)

    num_batches = current_data.shape[0] // BATCH_SIZE

    hits = 0
    seen = 0
    loss_total = 0

    for batch_idx in range(num_batches):
        lo = batch_idx * BATCH_SIZE
        hi = (batch_idx + 1) * BATCH_SIZE

        # Optional augmentations, each switched by its own flag.
        points = current_data[lo:hi, :, :]
        if augment_scale:
            points = provider.scale_point_cloud(points, smin=0.66, smax=1.5)
        if augment_rotation:
            points = provider.rotate_point_cloud(points)
        if augment_translation:
            points = provider.translate_point_cloud(points, tval=0.2)
        if augment_jitter:
            # default sigma=0.01, clip=0.05
            points = provider.jitter_point_cloud(points, sigma=0.01, clip=0.05)
        if augment_outlier:
            points = provider.insert_outliers_to_point_cloud(points, outlier_ratio=0.02)

        feed_dict = {
            ops['points_pl']: points,
            ops['labels_pl']: current_label[lo:hi],
            ops['w_pl']: gmm.weights_,
            ops['mu_pl']: gmm.means_,
            ops['sigma_pl']: np.sqrt(gmm.covariances_),
            ops['is_training_pl']: is_training,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']]
        summary, step, _, loss_val, pred_val = sess.run(fetches, feed_dict=feed_dict)

        train_writer.add_summary(summary, step)

        # Predicted class = argmax over axis 1 of the network output.
        hits += np.sum(np.argmax(pred_val, 1) == current_label[lo:hi])
        seen += BATCH_SIZE
        loss_total += loss_val

    log_string('mean loss: %f' % (loss_total / float(num_batches)))
    log_string('accuracy: %f' % (hits / float(seen)))
Example #10
Source File: train_seg.py    From scanobjectnn with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Run one epoch of joint classification and mask-segmentation training.

    ops: dict mapping from string to tf ops

    Data, labels and masks are fetched with
    ``data_utils.get_current_data_withmask_h5``.  Samples are consumed in full
    batches of BATCH_SIZE (a trailing partial batch is ignored); each batch is
    rotated and jittered through ``provider`` and the merged summary is
    written to ``train_writer``.  Three mean losses and two accuracies are
    reported through ``log_string`` at the end.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    current_data, current_label, current_mask = data_utils.get_current_data_withmask_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_MASKS, NUM_POINT)
    current_label = np.squeeze(current_label)
    current_mask = np.squeeze(current_mask)

    num_batches = current_data.shape[0] // BATCH_SIZE

    n_correct = 0
    n_correct_seg = 0
    n_seen = 0
    sum_loss = 0
    sum_classify_loss = 0
    sum_seg_loss = 0

    for batch_idx in range(num_batches):
        first = batch_idx * BATCH_SIZE
        last = (batch_idx + 1) * BATCH_SIZE
        batch_labels = current_label[first:last]
        batch_masks = current_mask[first:last]

        # Augment batched point clouds by rotation and jittering
        batch_points = provider.rotate_point_cloud(current_data[first:last, :, :])
        batch_points = provider.jitter_point_cloud(batch_points)

        feed_dict = {
            ops['pointclouds_pl']: batch_points,
            ops['labels_pl']: batch_labels,
            ops['masks_pl']: batch_masks,
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val, seg_val, classify_loss, seg_loss = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
             ops['pred'], ops['seg_pred'], ops['classify_loss'], ops['seg_loss']],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)

        # Class prediction: argmax over axis 1; per-point mask: argmax over axis 2.
        pred_cls = np.argmax(pred_val, 1)
        pred_seg = np.argmax(seg_val, 2)
        n_correct += np.sum(pred_cls == batch_labels)
        n_correct_seg += np.sum(pred_seg == batch_masks)

        n_seen += BATCH_SIZE
        sum_loss += loss_val
        sum_classify_loss += classify_loss
        sum_seg_loss += seg_loss

    log_string('mean loss: %f' % (sum_loss / float(num_batches)))
    log_string('classify mean loss: %f' % (sum_classify_loss / float(num_batches)))
    log_string('seg mean loss: %f' % (sum_seg_loss / float(num_batches)))
    log_string('accuracy: %f' % (n_correct / float(n_seen)))
    log_string('seg accuracy: %f' % (n_correct_seg / (float(n_seen)*NUM_POINT)))
Example #11
Source File: train_seg.py    From scanobjectnn with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Log a timestamp, then train the classification/segmentation model for one epoch.

    ops: dict mapping from string to tf ops

    Data, labels and masks are fetched with
    ``data_utils.get_current_data_withmask_h5``.  Only full batches of
    BATCH_SIZE are used; each is rotated and jittered through ``provider``
    before being fed to the graph, and the merged summary is written to
    ``train_writer``.  Mean total / classification / segmentation losses and
    both accuracies are reported through ``log_string``.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    log_string(str(datetime.now()))

    current_data, current_label, current_mask = data_utils.get_current_data_withmask_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_MASKS, NUM_POINT)
    current_label = np.squeeze(current_label)
    current_mask = np.squeeze(current_mask)

    num_batches = current_data.shape[0] // BATCH_SIZE

    correct_cls = 0
    correct_seg = 0
    seen = 0
    loss_acc = 0
    classify_loss_acc = 0
    seg_loss_acc = 0

    for batch_idx in range(num_batches):
        begin = batch_idx * BATCH_SIZE
        end = (batch_idx + 1) * BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        augmented = provider.jitter_point_cloud(
            provider.rotate_point_cloud(current_data[begin:end, :, :]))

        feed_dict = {
            ops['pointclouds_pl']: augmented,
            ops['labels_pl']: current_label[begin:end],
            ops['masks_pl']: current_mask[begin:end],
            ops['is_training_pl']: is_training,
        }
        outputs = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
             ops['pred'], ops['seg_pred'], ops['classify_loss'], ops['seg_loss']],
            feed_dict=feed_dict)
        # The third entry is the train_op result, which is not used.
        summary, step, _, loss_val, pred_val, seg_val, classify_loss, seg_loss = outputs
        train_writer.add_summary(summary, step)

        # Class prediction: argmax over axis 1; per-point mask: argmax over axis 2.
        correct_cls += np.sum(np.argmax(pred_val, 1) == current_label[begin:end])
        correct_seg += np.sum(np.argmax(seg_val, 2) == current_mask[begin:end])

        seen += BATCH_SIZE
        loss_acc += loss_val
        classify_loss_acc += classify_loss
        seg_loss_acc += seg_loss

    log_string('mean loss: %f' % (loss_acc / float(num_batches)))
    log_string('classify mean loss: %f' % (classify_loss_acc / float(num_batches)))
    log_string('seg mean loss: %f' % (seg_loss_acc / float(num_batches)))
    log_string('accuracy: %f' % (correct_cls / float(seen)))
    log_string('seg accuracy: %f' % (correct_seg / (float(seen)*NUM_POINT)))
Example #12
Source File: train.py    From scanobjectnn with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Log a timestamp, then train the classifier for one epoch.

    ops: dict mapping from string to tf ops

    The loader is ``data_utils.get_current_data_h5`` when TRAIN_FILE contains
    ".h5" and ``data_utils.get_current_data`` otherwise.  Only full batches of
    BATCH_SIZE are used; each is rotated and jittered through ``provider``
    before being fed to the graph, and the merged summary is written to
    ``train_writer``.  Mean loss and accuracy are reported through
    ``log_string``.

    NOTE(review): if there are fewer than BATCH_SIZE samples no batch runs and
    the final logging divides by zero.
    """
    is_training = True

    log_string(str(datetime.now()))

    # Pick the loader that matches the training file format.
    if ".h5" not in TRAIN_FILE:
        current_data, current_label = data_utils.get_current_data(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    else:
        current_data, current_label = data_utils.get_current_data_h5(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    current_label = np.squeeze(current_label)

    num_batches = current_data.shape[0] // BATCH_SIZE

    n_correct = 0
    n_seen = 0
    running_loss = 0
    for batch_idx in range(num_batches):
        first = batch_idx * BATCH_SIZE
        last = (batch_idx + 1) * BATCH_SIZE
        batch_labels = current_label[first:last]

        # Augment batched point clouds by rotation and jittering
        batch_points = provider.rotate_point_cloud(current_data[first:last, :, :])
        batch_points = provider.jitter_point_cloud(batch_points)

        feed_dict = {
            ops['pointclouds_pl']: batch_points,
            ops['labels_pl']: batch_labels,
            ops['is_training_pl']: is_training,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']]
        summary, step, _, loss_val, pred_val = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)

        # Predicted class = argmax over axis 1 of the network output.
        n_correct += np.sum(np.argmax(pred_val, 1) == batch_labels)
        n_seen += BATCH_SIZE
        running_loss += loss_val

    log_string('mean loss: %f' % (running_loss / float(num_batches)))
    log_string('accuracy: %f' % (n_correct / float(n_seen)))
Example #13
Source File: train.py    From PointCNN.Pytorch with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch, visiting the training files in random order.

    ops: dict mapping from string to tf ops

    Each file in TRAIN_FILES is loaded with ``provider.loadDataFile``, cut to
    the first NUM_POINT points, shuffled, and consumed in full batches of
    BATCH_SIZE.  Every batch is rotated and jittered through ``provider``
    before being fed to the graph, and the merged summary is written to
    ``train_writer``.  Mean loss and accuracy are logged once per file.

    A file that yields no full batch is logged and skipped; previously the
    per-file statistics divided by zero and aborted the epoch.
    """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn, file_idx in enumerate(train_file_idxs):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[file_idx])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        if num_batches == 0:
            # Nothing to train on and nothing to average: avoid 0 / 0 below.
            log_string('skipped: %d samples is less than one batch of %d'
                       % (file_size, BATCH_SIZE))
            continue

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training}
            summary, step, _, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            # Predicted class = argmax over axis 1 of the network output.
            pred_val = np.argmax(pred_val, 1)
            total_correct += np.sum(pred_val == current_label[start_idx:end_idx])
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
Example #14
Source File: train.py    From ldgcnn with MIT License 4 votes vote down vote up
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch, visiting the training files in random order.

    ops: dict mapping from string to tf ops

    Each file in TRAIN_FILES is loaded with ``provider.loadDataFile``, cut to
    the first NUM_POINT points, shuffled, and consumed in full batches of
    BATCH_SIZE.  Every batch is rotated, jittered, randomly scaled, rotation-
    perturbed and shifted through ``provider`` before being fed to the graph,
    and the merged summary is written to ``train_writer``.  Mean loss and
    accuracy are logged once per file.

    A file that yields no full batch is logged and skipped; previously the
    per-file statistics divided by zero and aborted the epoch.
    """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn, file_idx in enumerate(train_file_idxs):
        log_string('----' + str(fn) + '-----')
        # Load data and labels from the files.
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[file_idx])
        current_data = current_data[:, 0:NUM_POINT, :]
        # Shuffle the data in the training set.
        current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        if num_batches == 0:
            # Nothing to train on and nothing to average: avoid 0 / 0 below.
            log_string('skipped: %d samples is less than one batch of %d'
                       % (file_size, BATCH_SIZE))
            continue

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotating, jittering, shifting,
            # and scaling.
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            jittered_data = provider.random_scale_point_cloud(jittered_data)
            jittered_data = provider.rotate_perturbation_point_cloud(jittered_data)
            jittered_data = provider.shift_point_cloud(jittered_data)

            # Input the augmented point cloud and labels to the graph.
            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training}

            # Calculate the loss and accuracy of the input batch data.
            summary, step, _, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)

            train_writer.add_summary(summary, step)
            # Predicted class = argmax over axis 1 of the network output.
            pred_val = np.argmax(pred_val, 1)
            total_correct += np.sum(pred_val == current_label[start_idx:end_idx])
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))