from . import *

from datasets.dataloader import Dataset
from datasets.dataset_utils import *
from datasets.dataset_config import DatasetConfig

# Process-wide Pillow setting: decode truncated/partially-written image files
# instead of raising an error. This affects every ImageFile opened in this
# process, not just the images loaded by this dataset.
ImageFile.LOAD_TRUNCATED_IMAGES = True

class Shift_Dataset(Dataset):
    """Segmentation dataset over SHIFT front-camera images and semseg labels.

    Sample ids are read from ``<list_path>/<split>.txt``, one per line. Each
    image is loaded from ``<data_path>/<split>/front/img/<id>`` and its label
    from ``<data_path>/<split>/front/semseg/<id with 'img'->'semseg',
    'jpg'->'png'>``.
    """

    # NOTE(review): a set holding only an empty string -- looks like a
    # placeholder; confirm whether anything reads it.
    original_labels = {
        ''
    }

    def __init__(self,
                 args,
                 data_path='./datasets/Shift',
                 list_path='./datasets/Shift/standard_split',
                 split='train',
                 incr_class_split = 'standard',
                 incr_class_step = 0,
                 crop_to=None,
                 resize_to=None,
                 make_val_size_div=False,
                 random_resize=None,
                 training=True,
                 resize_mask=True,
                 class_set='city19'):
        """Build the dataset and load the list of sample ids.

        Every constructor argument except ``self`` is forwarded unchanged to
        the parent ``Dataset.__init__``.

        Args:
            args: run options; only ``args.debug`` is read here. In debug
                mode, diagnostics are printed and unlabeled-image removal is
                skipped.
            data_path: root directory holding ``<split>/front/{img,semseg}``.
            list_path: directory holding ``<split>.txt``. If it contains
                'timeofday' or 'weather', its last path component becomes
                ``dataset_name``; otherwise ``dataset_name`` is 'shift'.
            split: split name; selects both the list file and the data subdir.
            incr_class_split, incr_class_step, training, class_set: forwarded
                to ``DatasetConfig.get_id2iid`` to build ``id2trainid``.
            crop_to, resize_to, make_val_size_div, random_resize, resize_mask:
                forwarded to the parent (used by its transforms).
        """
        # Capture the constructor arguments before any other local exists.
        # inspect.signature replaces inspect.getargspec, which was removed in
        # Python 3.11; a bound method's signature already omits ``self``.
        ctor_params = inspect.signature(self.__init__).parameters
        kwargs = {k: v for k, v in locals().items() if k in ctor_params and k != 'self'}

        # Set before super().__init__, presumably because the parent reads it.
        if 'timeofday' in list_path or 'weather' in list_path:
            self.dataset_name = os.path.split(list_path)[-1]
        else:
            self.dataset_name = 'shift'
        super().__init__(**kwargs)

        item_list_filepath = os.path.join(self.list_path, self.split + ".txt")
        with open(item_list_filepath) as item_list_file:
            # Drop blank lines (e.g. a trailing empty line): they would become
            # empty ids that fail when __getitem__ opens the file.
            self.items = [line for line in item_list_file if line.strip()]

        self.id2trainid = DatasetConfig.get_id2iid(dataset='shift', incr_class_split=incr_class_split, class_set=class_set, incr_class_step=incr_class_step, training=training,
                                                   convert_names_to=None)
        self.image_mean = np.array(DatasetConfig.mean_bgr['shift'], dtype=np.float32)

        # One-shot flag: __getitem__ prints paths/GT values for the first sample only.
        self.print_for_debug = bool(args.debug)

        if args.debug:
            first_item = self.items[0] if self.items else None
            print(f'DEBUG: {self.dataset_name.upper()} {self.split} -> item_list_filepath: {item_list_filepath} , first item: {first_item}')
            print(f'{len(self.items)} num images in {self.dataset_name.upper()} {self.split} set have been loaded')

        if not args.debug and split == 'train':
            self._remove_unlabeled_images()

    def set_id2trainid_all_seen_classes(self):
        """Rebuild ``id2trainid`` with ``training=False`` (all seen classes).

        Overrides the parent method to add the name-conversion option: class
        names are converted to 'cityscapes' unless ``class_set`` is
        'mapillary'.
        """
        # NOTE(review): __init__ passes dataset='shift', but this passes
        # self.dataset_name, which differs for 'timeofday'/'weather' list
        # paths. Confirm that DatasetConfig has entries for those names.
        self.id2trainid = DatasetConfig.get_id2iid(dataset=self.dataset_name, incr_class_split=self.incr_class_split, class_set=self.class_set,
                                                   incr_class_step=self.incr_class_step, training=False, convert_names_to='cityscapes' if self.class_set not in ('mapillary', ) else None)

    def __getitem__(self, item):
        """Load and transform one sample.

        Returns:
            ``(image, gt, item_id, (w, h))``:
            - ``gt`` is channel 0 of the transformed label;
            - ``item_id`` is the list entry without its line terminator;
            - ``(w, h)`` is the original PIL image size before transforms.
            If ``self.doing_label_check`` is set, returns only the raw label
            PIL image, without applying transforms.
        """
        # rstrip('\r\n') also handles list files with CRLF line endings.
        item_id = self.items[item].rstrip('\r\n')

        # image
        image_path = os.path.join(self.data_path, self.split, 'front', 'img', item_id)
        image = Image.open(image_path).convert("RGB")
        w, h = image.size

        # gt: same relative path with 'img' -> 'semseg' and 'jpg' -> 'png'
        gt_image_path = os.path.join(self.data_path, self.split, 'front', 'semseg',
                                     item_id.replace('img', 'semseg').replace('jpg', 'png'))
        gt_image = Image.open(gt_image_path)
        if self.doing_label_check:
            return gt_image

        # Random (training) augmentation only for the train split in training mode.
        if self.split == "train" and self.training:
            image, gt_image = self._train_sync_transform(image, gt_image)
        else:
            image, gt_image = self._val_sync_transform(image, gt_image, resize_mask=self.resize_mask)

        if self.print_for_debug:
            self.print_for_debug = False
            print(f'DEBUG: {self.dataset_name.upper()} {self.split} -> item: {item}, image_path: {image_path}')
            print(f'DEBUG: {self.dataset_name.upper()} {self.split} -> item: {item}, gt_path: {gt_image_path}')
            print(f'DEBUG: GT values in {gt_image[gt_image>=0].unique().sort()[0].tolist()}')
        # Assumes the transformed label is (H, W, C) with class ids in
        # channel 0 -- TODO confirm against the parent's transforms.
        return image, gt_image[:, :, 0], item_id, (w, h)



