import itertools
from collections import defaultdict
import torch
from torch.utils.data.sampler import Sampler
from ..utils import common_functions as c_f
# Inspired by
# https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py
class HierarchicalSampler(Sampler):
    """Batch sampler over data with two-level hierarchical labels.

    Every batch mixes exactly ``super_classes_per_batch`` super labels
    (e.g. categories), and within each super label, whole classes are
    sampled until the per-super-label sub-batch is full. Iterating the
    sampler yields lists of element indices of length ``batch_size``.
    """

    def __init__(
        self,
        labels,
        batch_size,
        samples_per_class,
        batches_per_super_tuple=4,
        super_classes_per_batch=2,
        inner_label=0,
        outer_label=1,
    ):
        """
        labels: 2D array, where rows correspond to elements, and columns correspond to the hierarchical labels
        batch_size: because this is a BatchSampler the batch size must be specified
        samples_per_class: number of instances to sample for a specific class. set to "all" to use every element of a class
        batches_per_super_tuple: number of batches to create for a tuple of super labels (categories)
        inner_label: column index corresponding to classes
        outer_label: column index corresponding to the level of hierarchy used for the pairs
        """
        if torch.is_tensor(labels):
            labels = labels.cpu().numpy()
        self.batch_size = batch_size
        self.batches_per_super_tuple = batches_per_super_tuple
        self.samples_per_class = samples_per_class
        self.super_classes_per_batch = super_classes_per_batch

        # Sanity checks on the requested batch geometry.
        assert (
            self.batch_size % super_classes_per_batch == 0
        ), f"batch_size should be a multiple of {super_classes_per_batch}"
        # Number of elements contributed by each super label per batch.
        self.sub_batch_len = self.batch_size // super_classes_per_batch
        if self.samples_per_class != "all":
            assert self.samples_per_class > 0
            assert (
                self.sub_batch_len % self.samples_per_class == 0
            ), "batch_size // super_classes_per_batch should be a multiple of samples_per_class"

        all_super_labels = set(labels[:, outer_label])
        # Mapping: super label -> {class label -> [element indices]}
        self.super_image_lists = {slb: defaultdict(list) for slb in all_super_labels}
        for idx, instance in enumerate(labels):
            slb, lb = instance[outer_label], instance[inner_label]
            self.super_image_lists[slb][lb].append(idx)

        # All unordered tuples of super labels that can share a batch.
        # NOTE(review): if there are fewer than super_classes_per_batch super
        # labels, this is empty and the sampler yields no batches — confirm
        # that silent emptiness (rather than an error) is the intended contract.
        self.super_pairs = list(
            itertools.combinations(all_super_labels, super_classes_per_batch)
        )
        self.reshuffle()

    def __iter__(
        self,
    ):
        """Yield one batch (a list of element indices) at a time.

        Batches are rebuilt with fresh randomness on every epoch.
        """
        self.reshuffle()
        for batch in self.batches:
            yield batch

    def __len__(
        self,
    ):
        # Total batches = len(super_pairs) * batches_per_super_tuple.
        return len(self.batches)

    def reshuffle(self):
        """Rebuild ``self.batches`` with fresh class/element sampling.

        For each tuple of super labels, builds ``batches_per_super_tuple``
        batches; each batch takes ``sub_batch_len`` elements per super label,
        drawn class by class in random order.
        """
        batches = []
        for combinations in self.super_pairs:
            for _ in range(self.batches_per_super_tuple):
                batch = []
                for slb in combinations:
                    sub_batch = []
                    all_classes = list(self.super_image_lists[slb].keys())
                    c_f.NUMPY_RANDOM.shuffle(all_classes)
                    for cl in all_classes:
                        if len(sub_batch) >= self.sub_batch_len:
                            break
                        instances = self.super_image_lists[slb][cl]
                        # With "all", a class contributes every one of its elements.
                        samples_per_class = (
                            self.samples_per_class
                            if self.samples_per_class != "all"
                            else len(instances)
                        )
                        # Skip classes that would overflow the sub-batch; note the
                        # sub-batch may therefore end up shorter than sub_batch_len.
                        if len(sub_batch) + samples_per_class > self.sub_batch_len:
                            continue
                        sub_batch.extend(
                            c_f.safe_random_choice(instances, size=samples_per_class)
                        )
                    batch.extend(sub_batch)
                c_f.NUMPY_RANDOM.shuffle(batch)
                batches.append(batch)
        c_f.NUMPY_RANDOM.shuffle(batches)
        self.batches = batches