How to implement next_batch for mini batch gradient descent in deep learning

Shakeratto 2018. 4. 2. 15:29


Full Code: https://www.kaggle.com/pedrolcn/deep-tensorflow-ccn-cross-validation
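
The helper class below plays the same role as the next_batch method of the classic TensorFlow MNIST tutorial dataset: it keeps a cursor (index_in_epoch) into the training set, and whenever a requested batch would run past the end of the data, it reshuffles the examples and starts a new epoch.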

import numpy as np

class TrainBatcher(object):

    # Class constructor: store the data and initialize the epoch cursor
    def __init__(self, examples, labels):
        self.labels = labels
        self.examples = examples
        self.index_in_epoch = 0
        self.num_examples = examples.shape[0]

    # Mini-batching method
    def next_batch(self, batch_size):
        start = self.index_in_epoch
        self.index_in_epoch += batch_size

        # Once all the training data has been used, reshuffle it.
        # Note: any leftover examples at the tail of the epoch are dropped.
        if self.index_in_epoch > self.num_examples:
            perm = np.arange(self.num_examples)
            np.random.shuffle(perm)
            self.examples = self.examples[perm]
            self.labels = self.labels[perm]
            # Start the next epoch
            start = 0
            self.index_in_epoch = batch_size
            assert batch_size <= self.num_examples
        end = self.index_in_epoch

        return self.examples[start:end], self.labels[start:end]

mnist = TrainBatcher(train_images, train_labels)
batch_xs, batch_ys = mnist.next_batch(BATCH_SIZE)
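
As a quick sanity check, here is a minimal, self-contained sketch of the batcher driving an epoch loop. The random stand-in data and the BATCH_SIZE and EPOCHS values are illustrative assumptions, not taken from the Kaggle kernel:

import numpy as np

# Illustrative stand-in data: 1000 flattened 28x28 images and one-hot labels
train_images = np.random.rand(1000, 784).astype(np.float32)
train_labels = np.eye(10)[np.random.randint(0, 10, size=1000)]

BATCH_SIZE = 128
EPOCHS = 3
steps_per_epoch = train_images.shape[0] // BATCH_SIZE

batcher = TrainBatcher(train_images, train_labels)
for epoch in range(EPOCHS):
    for step in range(steps_per_epoch):
        batch_xs, batch_ys = batcher.next_batch(BATCH_SIZE)
        # The training op would run here, e.g. in TensorFlow:
        # sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        assert batch_xs.shape == (BATCH_SIZE, 784)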

