""" This module contains all the different batching methods
that are implemented. Any new batching method needs to inherit
from :class:`bgd.batch.Batching` and to implement its abstract
methods (:obj:`_start` and :obj:`_next`); callers then use the
public wrappers :meth:`start` and :meth:`next`. """
# batch.py: Batching methods for neural networks
# author : Antoine Passemiers, Robin Petit

# Batching is exported as well: it is the documented base class that
# every new batching method must inherit from.
__all__ = ['Batching', 'SGDBatching']
from abc import ABCMeta, abstractmethod
import numpy as np
class BatchingNotStartedError(RuntimeError, AssertionError):
    """ Raised when :meth:`Batching.next` is called before
    :meth:`Batching.start`.

    It also derives from :class:`AssertionError` so that code written
    against the former ``assert``-based check keeps catching it.
    """


class Batching(metaclass=ABCMeta):
    """ Base class for generating batches.

    Subclasses implement :meth:`_start` (receive a dataset) and
    :meth:`_next` (produce one batch). Callers use the public wrappers
    :meth:`start` and :meth:`next`.

    Attributes:
        warm (bool):
            Whether a dataset has been passed to the
            batching algorithm. Once it is warm,
            method :meth:`next` can be called to retrieve
            a batch.
    """

    def __init__(self):
        self.warm = False

    def start(self, X, y):
        """ Wrapper method for providing a dataset
        to the batching algorithm.

        This *must* be called before method :meth:`next` is called.

        Args:
            X (:obj:`np.ndarray`): Input samples
            y (:obj:`np.ndarray`): Target values
        """
        self._start(X, y)
        # Only flagged after _start returns, so an exception raised by the
        # subclass does not mark a never-started object as warm.
        self.warm = True

    def next(self):
        """ Wrapper method for retrieving a batch from
        the provided dataset.

        Returns:
            Whatever the subclass's :meth:`_next` returns, typically:
                X (:obj:`np.ndarray`):
                    Batch samples.
                y (:obj:`np.ndarray`):
                    Batch target values.

        Raises:
            BatchingNotStartedError: If :meth:`start` has not been called.
        """
        # Explicit check rather than ``assert``: assertions are removed when
        # Python runs with -O, which would let _next run without a dataset.
        if not self.warm:
            raise BatchingNotStartedError(
                "start(X, y) must be called before next()")
        return self._next()

    @abstractmethod
    def _start(self, X, y):
        """ Wrapped method for providing a dataset
        to the batching algorithm. Subclasses must
        override this method."""
        pass

    @abstractmethod
    def _next(self):
        """ Wrapped method for retrieving a batch from
        the provided dataset. Subclasses must override
        this method."""
        pass
class SGDBatching(Batching):
    """ Stochastic Gradient Descent Batching algorithm.

    Each call to :meth:`start` begins a new pass over the dataset.
    Successive calls to :meth:`next` then return ``(X_batch, y_batch)``
    tuples of at most ``batch_size`` samples. Once every sample has been
    returned, a single ``None`` is returned to mark the end of the pass.
    Any further call raises :class:`StopIteration` until :meth:`start`
    is called again.

    Args:
        batch_size (int):
            Maximum number of samples per batch. The last batch
            of a pass may be smaller.
        shuffle (bool):
            Whether to randomly permute the samples (with
            ``np.random``) before splitting them into batches.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, batch_size, shuffle=True):
        super().__init__()
        # A zero step would make the batch loop in mini_batches invalid.
        if batch_size < 1:
            raise ValueError(
                "batch_size must be a positive integer, got {!r}".format(
                    batch_size))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batches = None

    def _start(self, X, y):
        """ Provide a dataset to the batching algorithm and begin a new pass.

        This is invoked through :meth:`start`, which *must* be called
        before method :meth:`next` is called.

        Args:
            X (:obj:`np.ndarray`): Input samples.
            y (:obj:`np.ndarray`): Target values, one per sample of ``X``.

        Raises:
            ValueError: If ``X`` and ``y`` do not have the same length.
        """
        if len(X) != len(y):
            raise ValueError(
                "X and y must have the same length, got {} and {}".format(
                    len(X), len(y)))
        # self.batch_size is intentionally left untouched. Clamping it to
        # len(X) here used to shrink it permanently (down to 0 for an empty
        # dataset), breaking later passes. Slicing past the end is safe anyway.
        self.batches = self.mini_batches(X, y)

    def _next(self):
        """ Retrieve the next batch of the current pass.

        Returns:
            tuple or None:
                ``(X_batch, y_batch)``, or ``None`` once every
                sample of the current pass has been returned.

        Raises:
            StopIteration: If called again after ``None`` was returned.
        """
        return next(self.batches)

    def mini_batches(self, X, y):
        """ Generator function that iteratively yields batches.

        When ``self.shuffle`` is false, batches are contiguous slices taken
        in dataset order. When it is true, the samples are randomly
        permuted first and gathered with integer-array indexing, so ``X``
        and ``y`` must support it (as :obj:`np.ndarray` does).

        Args:
            X (:obj:`np.ndarray`): Input samples.
            y (:obj:`np.ndarray`): Target values.

        Yields:
            tuple or None:
                ``(X_batch, y_batch)`` for each batch, followed by a
                single ``None`` marking the end of the pass.
        """
        n_samples = len(X)
        step = self.batch_size
        if self.shuffle:
            # Permute individual samples, not just the order of fixed
            # contiguous blocks, so batch composition changes every pass.
            order = np.random.permutation(n_samples)
            for begin in range(0, n_samples, step):
                idx = order[begin:begin + step]
                yield X[idx], y[idx]
        else:
            for begin in range(0, n_samples, step):
                yield X[begin:begin + step], y[begin:begin + step]
        # End-of-pass sentinel; the following next() raises StopIteration.
        yield None