Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

pytorch - Customizing the batch with specific elements

I am a beginner with PyTorch. Strangely, I cannot find anything related to this, although it seems rather simple.

I want to build my batches from specific examples, for instance so that all examples in a batch have the same label, or so that a batch only contains examples from 2 classes.

How would I do that? It seems to me that the right place is the DataLoader rather than the Dataset, since the DataLoader, not the dataset, is responsible for building the batches. Is that right?

Is there a simple minimal example?

Question from: https://stackoverflow.com/questions/66065272/customizing-the-batch-with-specific-elements


1 Reply


TL;DR:

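  1. By default, DataLoader is driven by a sampler (SequentialSampler or RandomSampler) that yields single indices; it groups them into batches of batch_size for you (internally through a BatchSampler), so you don't need to provide a batch sampler.

  2. You can pass either a sampler or a batch sampler, not both: batch_sampler is mutually exclusive with sampler, batch_size, shuffle, and drop_last, and DataLoader raises a ValueError if you combine them.

  3. A sampler only yields the sequence of dataset indices, not the actual batches (batching is handled by the data loader, depending on batch_size).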
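To make point 3 concrete, here is a simplified, pure-Python model of how indices produced by a sampler are grouped into batches of batch_size. It is a sketch, not PyTorch's code: the function name `group_into_batches` is hypothetical, while PyTorch's own version of this step is `torch.utils.data.BatchSampler(sampler, batch_size, drop_last)`.

```python
from typing import Iterable, Iterator, List


def group_into_batches(sampler: Iterable[int], batch_size: int,
                       drop_last: bool = False) -> Iterator[List[int]]:
    """Hypothetical helper: chunk a stream of dataset indices into batches.

    The sampler only decides the order of the indices; batching is a
    separate step layered on top of it.
    """
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # a final, smaller batch is kept unless drop_last is set
    if batch and not drop_last:
        yield batch


# A "sequential sampler" over 10 elements is just range(10)
print(list(group_into_batches(range(10), batch_size=3)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(group_into_batches(range(10), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```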


To answer your initial question: working with a sampler on an iterable-style dataset doesn't seem to be possible (cf. the GitHub issue, still open at the time of writing). Also, read the related note in pytorch/dataloader.py.


Samplers (for map-style datasets):

That aside, if you switch to a map-style dataset, here are some details on how samplers and batch samplers work. You can access a dataset's underlying data by index, just like a list (since a map-style torch.utils.data.Dataset implements __getitem__). In other words, your dataset elements are all dataset[i], for i in [0, len(dataset) - 1].

Here is a toy dataset:

from torch.utils.data import Dataset

class DS(Dataset):
    def __getitem__(self, index):
        # element i of this toy dataset is simply i * 10
        return index * 10

    def __len__(self):
        return 10

In a typical use case you would just pass torch.utils.data.DataLoader the arguments batch_size and shuffle. By default, shuffle is False, which means the loader uses torch.utils.data.SequentialSampler. Otherwise (if shuffle is True), torch.utils.data.RandomSampler is used. The sampler defines how the data loader accesses the dataset, i.e. in which order it visits the indices.

The above dataset (DS) has 10 elements. The indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, and 9. They map to elements 0, 10, 20, 30, 40, 50, 60, 70, 80, and 90. So with a batch size of 2:

  • SequentialSampler: DataLoader(ds, batch_size=2) (implicitly shuffle=False), identical to DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds)). The dataloader will deliver tensor([0, 10]), tensor([20, 30]), tensor([40, 50]), tensor([60, 70]), and tensor([80, 90]).

  • RandomSampler: DataLoader(ds, batch_size=2, shuffle=True), identical to DataLoader(ds, batch_size=2, sampler=RandomSampler(ds)). The dataloader will sample randomly each time you iterate through it. For instance: tensor([50, 40]), tensor([90, 80]), tensor([0, 60]), tensor([10, 20]), and tensor([30, 70]). But the sequence will be different if you iterate through the dataloader a second time!
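The two samplers differ only in the order of the indices. This stdlib-only sketch mimics that without PyTorch: `ToyDS` stands in for DS (element i is i * 10), and the hypothetical `iterate` function plays the role of one pass over the DataLoader, returning plain lists instead of collated tensors. Passing a seeded `random.Random` is an assumption made here for reproducibility.

```python
import random
from typing import Iterator, List, Optional


class ToyDS:
    """Stand-in for DS: a map-style dataset where element i is i * 10."""

    def __getitem__(self, index: int) -> int:
        return index * 10

    def __len__(self) -> int:
        return 10


def iterate(ds: ToyDS, batch_size: int, shuffle: bool = False,
            rng: Optional[random.Random] = None) -> Iterator[List[int]]:
    """Hypothetical stand-in for one epoch over a DataLoader."""
    order = list(range(len(ds)))            # SequentialSampler order
    if shuffle:
        (rng or random).shuffle(order)      # RandomSampler-like permutation
    for start in range(0, len(order), batch_size):
        yield [ds[i] for i in order[start:start + batch_size]]


ds = ToyDS()
print(list(iterate(ds, batch_size=2)))
# [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]

rng = random.Random(0)
epoch1 = list(iterate(ds, batch_size=2, shuffle=True, rng=rng))
epoch2 = list(iterate(ds, batch_size=2, shuffle=True, rng=rng))
# both epochs contain every element exactly once, usually in a different order
```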


Batch sampler

Providing batch_sampler replaces batch_size, shuffle, sampler, and drop_last altogether (these options are mutually exclusive with it). It defines exactly which indices go into each batch, and in which order. For instance:

>>> DataLoader(ds, batch_sampler=[[1,2,3], [6,5,4], [7,8], [0,9]])

Will yield tensor([10, 20, 30]), tensor([60, 50, 40]), tensor([70, 80]), and tensor([ 0, 90]).
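Since batch_sampler only needs to be an iterable of index lists, a plain list works, but a small class is often handier: its `__iter__` is called again at every epoch, and defining `__len__` lets `len(dataloader)` work. A minimal sketch; the class name `FixedBatchSampler` and the helper `fetch` are hypothetical, and `fetch` only approximates what the DataLoader does for each batch before collating it into a tensor.

```python
from typing import Iterator, List, Sequence


class FixedBatchSampler:
    """Hypothetical batch sampler replaying the same batches every epoch."""

    def __init__(self, batches: Sequence[Sequence[int]]):
        self.batches = [list(b) for b in batches]

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)


def fetch(dataset, batch_sampler) -> List[List[int]]:
    # For each batch of indices, look up dataset[i] (no tensor collation here)
    return [[dataset[i] for i in batch] for batch in batch_sampler]


data = [i * 10 for i in range(10)]   # same elements as DS
bs = FixedBatchSampler([[1, 2, 3], [6, 5, 4], [7, 8], [0, 9]])
print(fetch(data, bs))
# [[10, 20, 30], [60, 50, 40], [70, 80], [0, 90]]
```

With PyTorch installed, `DataLoader(ds, batch_sampler=bs)` should accept the same object, since it only iterates over it and asks for its length.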


Batch sampling on the class

Let's say I want each batch to contain just 2 elements, drawn from the classes (the same class or different ones), and want to exclude any further examples, so that a batch never contains 3 examples.

Let's say you have a dataset with four classes. Here is how I would do it. First, keep track of dataset indices for each class.

from torch.utils.data import Dataset

class DS(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data

        # one list of dataset indices per class:
        # 0 = positive odd, 1 = positive even, 2 = negative odd, 3 = negative even
        self.indices = [[] for _ in range(4)]
        for i, x in enumerate(data):
            if x > 0 and x % 2: self.indices[0].append(i)
            if x > 0 and not x % 2: self.indices[1].append(i)
            if x < 0 and x % 2: self.indices[2].append(i)
            if x < 0 and not x % 2: self.indices[3].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

For example:

>>> ds = DS([1, 6, 7, -5, 10, -6, 8, 6, 1, -3, 9, -21, -13, 11, -2, -4, -21, 4])

Will give:

>>> ds.classes()
[[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
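The four `if` statements hard-code the class rule. When each element's label can be computed up front, the same per-class index lists can be built generically. A sketch with a hypothetical helper `indices_per_class`; the `label_of` function just reproduces DS's rule and is an assumption for this example.

```python
from typing import Callable, List, Sequence


def indices_per_class(data: Sequence[int], label_of: Callable[[int], int],
                      num_classes: int) -> List[List[int]]:
    """Hypothetical helper: list of dataset indices for each class label."""
    indices: List[List[int]] = [[] for _ in range(num_classes)]
    for i, x in enumerate(data):
        indices[label_of(x)].append(i)
    return indices


def label_of(x: int) -> int:
    # Same rule as DS: 0 = positive odd, 1 = positive even,
    # 2 = negative odd, 3 = negative even
    if x == 0:
        raise ValueError("0 has no class in this example")
    if x > 0:
        return 0 if x % 2 else 1
    return 2 if x % 2 else 3


data = [1, 6, 7, -5, 10, -6, 8, 6, 1, -3, 9, -21, -13, 11, -2, -4, -21, 4]
print(indices_per_class(data, label_of, 4))
# [[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
```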

Then, for the batch sampler, the easiest way is to create a list of the available class indices, repeating each class index once per dataset element of that class.

In the dataset defined above, we have 5 items from class 0, 5 from class 1, 5 from class 2, and 3 from class 3. Therefore we want to construct [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3]. We will shuffle it. Then, from this list and the dataset classes content (ds.classes()) we will be able to construct the batches.
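Here is that bookkeeping on its own, using the output of `ds.classes()`: build the label list, shuffle it, and pair consecutive labels with the `zip(*[iter(seq)] * 2)` idiom that the Sampler class relies on. The idiom pulls two items at a time from one shared iterator and silently drops a trailing odd item. The fixed seed is only there to make this illustration reproducible.

```python
import random

classes = [[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]

# one class label per dataset element of that class
labels = [label for label, members in enumerate(classes) for _ in members]
print(labels)
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3]

random.Random(0).shuffle(labels)

# zip(*[it] * 2) reads two items at a time from the same iterator
pairs = list(zip(*[iter(labels)] * 2))
print(len(pairs))  # 9 pairs of class labels, i.e. 9 batches of size 2

print(list(zip(*[iter([0, 1, 2, 3, 4])] * 2)))
# [(0, 1), (2, 3)]  (the trailing 4 is dropped)
```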

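import copy
import random

class Sampler():
    def __init__(self, classes):
        self.classes = classes

    def __iter__(self):
        # work on a copy: indices are popped from the per-class lists below
        classes = copy.deepcopy(self.classes)
        # one class label per dataset element, e.g. [0, 0, 0, 0, 0, 1, ..., 3]
        indices = [i for i, klass in enumerate(classes) for _ in klass]
        random.shuffle(indices)
        # consecutive pairs of class labels (a trailing odd label is dropped)
        grouped = zip(*[iter(indices)] * 2)

        res = []
        for a, b in grouped:
            # take one remaining dataset index from class a and one from class b
            res.append((classes[a].pop(), classes[b].pop()))
        return iter(res)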

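Note: deep-copying the lists is required since we pop elements from them. ds.classes() returns the dataset's own indices list, so without the copy the first epoch would empty it and later epochs would yield no batches.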
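A small stdlib-only demonstration of why the copy matters. The class names `PopWithoutCopy` and `PopWithCopy` are hypothetical, and they pair one index of class 0 with one of class 1 deterministically instead of using the random pairing.

```python
import copy
from typing import List, Tuple


class PopWithoutCopy:
    """Pairs one index from class 0 with one from class 1, mutating its input."""

    def __init__(self, classes: List[List[int]]):
        self.classes = classes

    def __iter__(self):
        return iter(self._pairs(self.classes))

    @staticmethod
    def _pairs(classes: List[List[int]]) -> List[Tuple[int, int]]:
        pairs = []
        while classes[0] and classes[1]:
            pairs.append((classes[0].pop(), classes[1].pop()))
        return pairs


class PopWithCopy(PopWithoutCopy):
    def __iter__(self):
        # pop from a deep copy so self.classes survives across epochs
        return iter(self._pairs(copy.deepcopy(self.classes)))


bad = PopWithoutCopy([[0, 2], [1, 3]])
print(list(bad), list(bad))    # [(2, 3), (0, 1)] []

good = PopWithCopy([[0, 2], [1, 3]])
print(list(good), list(good))  # [(2, 3), (0, 1)] [(2, 3), (0, 1)]
```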

A possible output of this sampler would be:

[(15, 14), (16, 17), (7, 12), (11, 6), (13, 10), (5, 4), (9, 8), (2, 0), (3, 1)]

At this point we can simply use torch.utils.data.DataLoader:

>>> dl = DataLoader(ds, batch_sampler=Sampler(ds.classes()))

Which could yield something like:

[tensor([ 4, -4]), tensor([-21,  11]), tensor([-13,   6]), tensor([9, 1]), tensor([  8, -21]), tensor([-3, 10]), tensor([ 6, -2]), tensor([-5,  7]), tensor([-6,  1])]
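The question also mentioned batches in which every example has the same label. That case fits the same pattern: shuffle each class's indices, cut them into chunks of batch_size, and shuffle the order of the resulting batches. A minimal sketch assuming the `ds.classes()` format above; `SameClassBatchSampler` is a hypothetical name, and the last chunk of each class may be smaller than batch_size.

```python
import random
from typing import Iterator, List, Optional


class SameClassBatchSampler:
    """Hypothetical batch sampler: every batch holds indices of a single class."""

    def __init__(self, classes: List[List[int]], batch_size: int,
                 seed: Optional[int] = None):
        self.classes = classes
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        batches = []
        for members in self.classes:
            shuffled = list(members)         # copy, keep the original order intact
            self.rng.shuffle(shuffled)
            for start in range(0, len(shuffled), self.batch_size):
                batches.append(shuffled[start:start + self.batch_size])
        self.rng.shuffle(batches)            # mix classes across the epoch
        return iter(batches)

    def __len__(self) -> int:
        # ceil(len(members) / batch_size) batches per class
        return sum(-(-len(m) // self.batch_size) for m in self.classes)


classes = [[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
sampler = SameClassBatchSampler(classes, batch_size=2, seed=0)
print(len(sampler))  # 3 + 3 + 3 + 2 = 11
```

With PyTorch, `DataLoader(ds, batch_sampler=SameClassBatchSampler(ds.classes(), batch_size=2))` would then yield batches whose elements all share one class.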

An easier approach

Here is another, easier approach. It samples with replacement, so it does not guarantee that every element of the dataset is returned in an epoch, and some elements may be returned more than once.

For each batch, first sample class_per_batch classes, then sample batch_size elements from these selected classes (for each element, first sample a class from that subset, then sample a data point from that class).

import random

class Sampler():
    def __init__(self, classes, class_per_batch, batch_size):
        self.classes = classes
        self.n_batches = sum(len(x) for x in classes) // batch_size
        self.class_per_batch = class_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        batches = []
        for _ in range(self.n_batches):
            # pick the classes allowed in this particular batch
            classes = random.sample(range(len(self.classes)), self.class_per_batch)
            batch = []
            for _ in range(self.batch_size):
                # pick one of those classes, then one of its elements (with replacement)
                klass = random.choice(classes)
                batch.append(random.choice(self.classes[klass]))
            batches.append(batch)
        return iter(batches)

    def __len__(self):
        return self.n_batches

You can try it this way:

>>> s = Sampler(ds.classes(), class_per_batch=2, batch_size=4)
>>> list(s)
[[16, 0, 0, 9], [10, 8, 11, 2], [16, 9, 16, 8], [2, 9, 2, 3]]

>>> dl = DataLoader(ds, batch_sampler=s)
>>> list(iter(dl))
[tensor([ -5,  -6, -21, -13]), tensor([ -4,  -4, -13, -13]), tensor([ -3, -21,  -2,  -5]), tensor([-3, -5, -4, -6])]
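When experimenting with such samplers, it helps to check which classes each batch actually contains. A small hypothetical helper, assuming the `ds.classes()` format; applied to the first two batches printed by `list(s)` above, it shows they draw from classes 0 and 2 only.

```python
from typing import Dict, List, Set


def classes_in_batches(batches: List[List[int]],
                       classes: List[List[int]]) -> List[Set[int]]:
    """Hypothetical helper: the set of class labels present in each batch."""
    label_of: Dict[int, int] = {
        idx: label for label, members in enumerate(classes) for idx in members
    }
    return [{label_of[idx] for idx in batch} for batch in batches]


classes = [[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
print(classes_in_batches([[16, 0, 0, 9], [10, 8, 11, 2]], classes))
# [{0, 2}, {0, 2}]
```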

...