import bisect
import itertools

import torch.utils.data as data

class Joint_Dataset(data.Dataset):
    """Concatenate several map-style datasets into one indexable dataset.

    Global index ``i`` is served by the first sub-dataset whose cumulative
    length exceeds ``i``. The remainder is used as the local index into that
    sub-dataset.

    Attributes:
        dl_list: the sub-datasets, in concatenation order.
        dl_len: length of each sub-dataset, captured at construction time.
        image_mean: the ``image_mean`` attribute of the first sub-dataset.
        dataset_name: the sub-datasets' ``dataset_name`` values joined by '-'.
    """

    def __init__(self, dl_list):
        """Build the joint dataset.

        Args:
            dl_list: non-empty sequence of datasets. Each must support
                ``len()`` and integer indexing and expose ``dataset_name``;
                the first must also expose ``image_mean``.

        Raises:
            ValueError: if ``dl_list`` is empty.
        """
        if not dl_list:
            raise ValueError("Joint_Dataset requires at least one dataset")
        self.dl_list = dl_list
        self.dl_len = [len(dl) for dl in self.dl_list]
        # Running totals, e.g. lengths [3, 0, 2] -> [3, 3, 5]. A binary search
        # over these finds which sub-dataset owns a given global index.
        self._cum_len = list(itertools.accumulate(self.dl_len))
        self.image_mean = self.dl_list[0].image_mean
        self.dataset_name = '-'.join([dl.dataset_name for dl in self.dl_list])

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        # Use the lengths cached in __init__ so __len__ always agrees with the
        # index mapping used by __getitem__.
        return sum(self.dl_len)

    def __getitem__(self, item):
        """Return sample ``item`` of the concatenated dataset.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
                Raising IndexError (not another error) also lets plain
                ``for x in ds`` iteration terminate correctly.
        """
        size = len(self)
        idx = item + size if item < 0 else item
        if not 0 <= idx < size:
            raise IndexError(
                f"index {item} is out of range for dataset of length {size}")
        # bisect_right skips over empty sub-datasets (repeated running totals),
        # landing on the first sub-dataset whose running total exceeds idx.
        dl_idx = bisect.bisect_right(self._cum_len, idx)
        offset = self._cum_len[dl_idx - 1] if dl_idx > 0 else 0
        return self.dl_list[dl_idx][idx - offset]