# Source code for macrosynergy.learning.forecasting.torch.samplers.timeseries_sampler
import torch
import torch.nn as nn
from torch.utils.data import Sampler
import numbers
class TimeSeriesSampler(Sampler):
    """
    Batch sampler for datasets indexed by time, to ensure that batches are comprised of
    samples from contiguous time periods.

    Every batch is a run of consecutive integer indices into the dataset, so the
    samples are contiguous in time provided the dataset itself is ordered by time
    (this ordering is assumed, not checked).

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The PyTorch dataset to sample from. It must implement ``__len__``.
    batch_size : int
        Number of samples per batch.
    shuffle : bool, optional
        Whether to shuffle the order of batches. Default is True. Only the order
        of the batches is shuffled; indices within a batch stay consecutive.
    aggregate_last : bool, optional
        Whether to aggregate the last batch with the previous one if it has length
        smaller than batch_size. Default is True.
    drop_last : bool, optional
        Whether to drop the last batch if it has length smaller than batch_size.
        Default is False. If the dataset yields a single batch only, that batch is
        kept even when it is shorter than batch_size.

    Raises
    ------
    TypeError
        If `dataset` is not a ``torch.utils.data.Dataset`` or has no ``__len__``,
        if `batch_size` is not an integer, or if any flag is not a boolean.
    ValueError
        If `batch_size` is smaller than 1, or if `aggregate_last` and `drop_last`
        are both True.

    Notes
    -----
    The batches are computed once, at construction time, from ``len(dataset)``.
    With ``aggregate_last=True`` the final batch can therefore hold up to
    ``2 * batch_size - 1`` indices.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        shuffle=True,
        aggregate_last=True,
        drop_last=False,
    ):
        # Checks
        self._check_init_params(
            dataset,
            batch_size,
            shuffle,
            aggregate_last,
            drop_last,
        )

        # Attributes
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.aggregate_last = aggregate_last
        self.drop_last = drop_last
        self.dataset_size = len(dataset)

        # Determine batches
        self.batches = self._create_batches(
            self.batch_size,
            self.dataset_size,
            self.aggregate_last,
            self.drop_last,
        )

    def _create_batches(self, batch_size, dataset_size, aggregate_last, drop_last):
        """
        Create the list of batches.

        Parameters
        ----------
        batch_size : int
            Number of consecutive indices per batch.
        dataset_size : int
            Total number of indices to partition.
        aggregate_last : bool
            Merge a short final batch into the batch before it.
        drop_last : bool
            Discard a short final batch.

        Returns
        -------
        list of list of int
            Consecutive index runs covering ``0 .. dataset_size - 1``, minus a
            dropped final batch if any. Empty when `dataset_size` is 0.
        """
        batches = [
            list(range(start, min(start + batch_size, dataset_size)))
            for start in range(0, dataset_size, batch_size)
        ]

        # A short final batch is only merged or dropped when another batch
        # precedes it; a lone short batch is always returned unchanged.
        last_is_short = len(batches) > 1 and len(batches[-1]) < batch_size
        if last_is_short:
            if aggregate_last:
                # Aggregation takes precedence: once merged, the final batch is
                # no longer short, so nothing is left to drop.
                batches[-2].extend(batches.pop())
            elif drop_last:
                batches.pop()

        return batches

    def __iter__(self):
        """
        Generator for batch indices.

        Yields
        ------
        list of int
            The dataset indices of one batch. The batch order is a fresh random
            permutation (drawn with ``torch.randperm``) on every iteration when
            `shuffle` is True, and the construction order otherwise.
        """
        if self.shuffle:
            batch_indices = torch.randperm(len(self.batches)).tolist()
        else:
            batch_indices = range(len(self.batches))

        for idx in batch_indices:
            yield self.batches[idx]

    def __len__(self):
        """Returns number of batches"""
        return len(self.batches)

    def _check_init_params(
        self,
        dataset,
        batch_size,
        shuffle,
        aggregate_last,
        drop_last,
    ):
        """
        Validate the constructor arguments.

        Raises
        ------
        TypeError
            If an argument has the wrong type, or `dataset` has no ``__len__``.
        ValueError
            If `batch_size` is below 1 or both `aggregate_last` and `drop_last`
            are True.
        """
        # dataset
        if not isinstance(dataset, torch.utils.data.Dataset):
            raise TypeError("dataset must be a torch.utils.data.Dataset instance.")
        # The batches are derived from len(dataset), so fail early with a clear
        # message instead of the generic "has no len()" error.
        if not hasattr(dataset, "__len__"):
            raise TypeError("dataset must implement __len__.")

        # batch_size
        if not isinstance(batch_size, numbers.Integral):
            raise TypeError("batch_size must be an integer.")
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1.")

        # shuffle
        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean.")

        # aggregate_last
        if not isinstance(aggregate_last, bool):
            raise TypeError("aggregate_last must be a boolean.")

        # drop_last
        if not isinstance(drop_last, bool):
            raise TypeError("drop_last must be a boolean.")

        # The two strategies for a short final batch are mutually exclusive.
        if aggregate_last and drop_last:
            raise ValueError("aggregate_last and drop_last cannot both be True.")