from . import *

from datasets.dataloader import Dataset
from datasets.dataset_utils import *
from datasets.dataset_config import DatasetConfig

# Let PIL decode truncated / partially written image files instead of raising
# when their pixel data is loaded. This is a module attribute on PIL's ImageFile,
# so it affects every image opened in this process, not only IDD images.
ImageFile.LOAD_TRUNCATED_IMAGES = True

class IDD_Dataset(Dataset):
    """IDD segmentation dataset whose label ids are mapped using Cityscapes class names.

    Samples are listed in ``<list_path>/<split>.txt``. Each line holds an image
    path and a ground-truth path separated by a single space. Both paths are
    relative to ``data_path``. Label ids are remapped through
    ``DatasetConfig.get_id2iid(..., convert_names_to='cityscapes')``.
    """

    def __init__(self,
                 args,
                 data_path='./datasets/IDD',
                 list_path='./datasets/IDD/standard_split',
                 split='train',  # train->training, test->testing or val->validation
                 incr_class_split = 'standard',
                 incr_class_step = 0,
                 crop_to=None,
                 resize_to=None,
                 make_val_size_div=False,
                 random_resize=None,
                 training=True,
                 resize_mask=True,
                 class_set='city19'):
        """Forward all arguments to the parent ``Dataset``, then load the split list.

        Args:
            args: run configuration. Only ``args.debug`` is read here. The object
                itself is also forwarded to the parent ``Dataset``.
            data_path: root directory that the image/gt paths in the split file
                are joined onto.
            list_path: directory containing ``<split>.txt``.
            split: split name; also selects the training transform in
                ``__getitem__`` when it equals ``'train'``.
            incr_class_split, incr_class_step, class_set, training: passed to
                ``DatasetConfig.get_id2iid`` to build ``self.id2trainid``.
            crop_to, resize_to, make_val_size_div, random_resize, resize_mask:
                forwarded to the parent ``Dataset`` (``resize_mask`` is later used
                by the validation transform).
        """
        # Forward every __init__ argument (except self) to the parent by name.
        # inspect.getargspec was removed in Python 3.11; signature() is its
        # replacement and, on a bound method, already omits `self`.
        init_params = inspect.signature(self.__init__).parameters
        kwargs = {k: v for k, v in locals().items() if k in init_params and k != 'self'}
        self.dataset_name = 'idd'
        super().__init__(**kwargs)

        # Lines are kept verbatim (trailing newline included); they are parsed
        # lazily in __getitem__. The context manager closes the file handle.
        item_list_filepath = os.path.join(self.list_path, self.split + ".txt")
        with open(item_list_filepath) as item_file:
            self.items = list(item_file)

        self.id2trainid = DatasetConfig.get_id2iid(dataset='idd', incr_class_split=incr_class_split, class_set=class_set,
                                                   incr_class_step=incr_class_step, training=training,
                                                   convert_names_to='cityscapes')
        self.image_mean = np.array(DatasetConfig.mean_bgr['idd'], dtype=np.float32)

        # One-shot flag: the first __getitem__ call prints paths/labels, then clears it.
        self.print_for_debug = bool(args.debug)

        if args.debug:
            # Guard against an empty split file so the debug print cannot raise IndexError.
            first_item = self.items[0] if self.items else None
            print(f'DEBUG: {self.dataset_name.upper()} {self.split} -> item_list_filepath: {item_list_filepath} , first item: {first_item}')
            print(f'{len(self.items)} num images in {self.dataset_name.upper()} {self.split} set have been loaded')

        # Parent-class filtering of the training list; not run in debug mode.
        if not args.debug and split == 'train':
            self._remove_unlabeled_images()

    def set_id2trainid_all_seen_classes(self):
        """Rebuild ``self.id2trainid`` with ``training=False``, keeping Cityscapes name conversion.

        This overrides the parent method so that the ``convert_names_to='cityscapes'``
        option is also applied here.
        """
        self.id2trainid = DatasetConfig.get_id2iid(dataset=self.dataset_name, incr_class_split=self.incr_class_split,
                                                   class_set=self.class_set, incr_class_step=self.incr_class_step,
                                                   training=False, convert_names_to='cityscapes')

    def __getitem__(self, item):
        """Load one sample and apply the split-appropriate transform.

        Returns:
            ``(image, gt_image, name, (w, h))``.
            ``name`` is the image file name up to its first ``'.'``.
            ``(w, h)`` is the size of the RGB image as read from disk, before any
            transform.
            If ``self.doing_label_check`` is truthy, only the untransformed
            ground-truth PIL image is returned instead.
        """
        # strip() rather than strip('\n'): also drops the '\r' left by CRLF split
        # files, which would otherwise end up in the ground-truth path.
        id_img, id_gt = self.items[item].strip().split(' ')

        image_path = os.path.join(self.data_path, id_img)
        image = Image.open(image_path).convert("RGB")
        w, h = image.size

        gt_image_path = os.path.join(self.data_path, id_gt)
        gt_image = Image.open(gt_image_path)
        if self.doing_label_check:
            return gt_image

        # The training transform is used only for the train split with training=True;
        # every other case uses the validation transform.
        if self.split == "train" and self.training:
            image, gt_image = self._train_sync_transform(image, gt_image)
        else:
            image, gt_image = self._val_sync_transform(image, gt_image, resize_mask=self.resize_mask)

        if self.print_for_debug:
            self.print_for_debug = False
            print(f'DEBUG: {self.dataset_name.upper()} {self.split} -> item: {item}, image_path: {image_path}')
            print(f'DEBUG: {self.dataset_name.upper()} {self.split} -> item: {item}, gt_path: {gt_image_path}')
            # Assumes the transforms return gt_image as a torch tensor (uses .unique()/.sort()) — TODO confirm.
            print(f'DEBUG: GT values in {gt_image[gt_image>=0].unique().sort()[0].tolist()}')
        return image, gt_image, id_img.split('/')[-1].split('.')[0], (w, h)



