import random
from functools import partial, wraps

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image, ImageOps, ImageFilter

from datasets.dataset_config import DatasetConfig


def multi_img(stack=False):
    """
    Build a decorator that lets a single-image method also accept a list of images.

    The decorated method must be called as ``method(self, img, ...)`` with ``img`` passed positionally.
    If ``img`` is not a ``list`` the method is simply called once. If it is a ``list`` the method is
    called once per element, every call receiving the same extra positional and keyword arguments.

    When the call carries a ``mask`` keyword argument, each per-image result is expected to be an
    ``(img, mask)`` pair; the wrapper then returns ``(list_of_imgs, mask_of_first_result)``.

    :param stack: if True, per-image results of a list call are combined with ``torch.stack``,
                  otherwise they are returned as a plain list
    :return: the decorator to apply to the method
    """
    combine = torch.stack if stack else (lambda results: results)

    def decorator(fn):
        @wraps(fn)  # keep the wrapped method's __name__/__doc__ for debugging and introspection
        def wrapper(self, *args, **kwargs):
            img = args[0]
            if not isinstance(img, list):
                return fn(self, *args, **kwargs)

            # forward the trailing positional arguments too, as the single-image path does
            extra_args = args[1:]
            outputs = combine([fn(self, single, *extra_args, **kwargs) for single in img])
            if 'mask' in kwargs:  # outputs: [(img0,lab),(img1,lab),...]
                columns = list(zip(*outputs))  # [(img0,img1,...),(lab,lab,...)]
                return list(columns[0]), columns[1][0]
            return outputs
        return wrapper
    return decorator


class Dataset(data.Dataset):

    def __init__(self,
                 args,
                 data_path='./datasets/Cityscapes',
                 list_path='./datasets/Cityscapes/standard_split',
                 split='train',  # train->training, test->testing or val->validation
                 incr_class_split='standard',
                 incr_class_step=0,
                 crop_to=None,
                 resize_to=None,
                 make_val_size_div=False,
                 random_resize=None,
                 training=True,
                 resize_mask=True,
                 class_set='city19'):
        """
        Store the dataset configuration and derive the augmentation switches.

        :param args: run options; ``debug`` is always read, while ``random_mirror``, ``gaussian_blur``
                     and ``random_translation`` are read only when ``training`` is truthy
        :param data_path: stored as ``self.data_path``
        :param list_path: stored as ``self.list_path``
        :param split: split name, stored as ``self.split``
        :param incr_class_split: stored as ``self.incr_class_split``
        :param incr_class_step: stored as ``self.incr_class_step``
        :param crop_to: crop size, stored as ``self.crop_to``
        :param resize_to: resize size, stored as ``self.resize_to``
        :param make_val_size_div: stored as ``self.make_val_size_div``
        :param random_resize: stored as ``self.random_resize``
        :param training: when falsy, all three random augmentations are switched off
        :param resize_mask: stored as ``self.resize_mask``
        :param class_set: stored as ``self.class_set``

        NOTE(review): the debug messages read ``self.dataset_name``, which is not assigned here;
        presumably subclasses provide it before calling this constructor - confirm.
        """
        # where the data lives and which split is used
        self.args = args
        self.data_path = data_path
        self.list_path = list_path
        self.split = split
        if args.debug:
            print(f'DEBUG: {self.dataset_name} {self.split} dataset path is {self.list_path}')

        # incremental-class setup and geometric options
        self.incr_class_split = incr_class_split
        self.incr_class_step = incr_class_step
        self.crop_to = crop_to
        self.resize_to = resize_to
        self.make_val_size_div = make_val_size_div
        self.random_resize = random_resize
        if args.debug:
            print(f'DEBUG: {self.dataset_name} {self.split} dataset image size is {self.crop_to}')

        # random augmentations are only taken from args in training mode
        self.training = training
        if not self.training:
            self.random_mirror = False
            self.gaussian_blur = False
            self.random_translation = False
        else:
            self.random_mirror = args.random_mirror
            self.gaussian_blur = args.gaussian_blur
            self.random_translation = args.random_translation

        self.resize_mask = resize_mask
        self.class_set = class_set
        self.doing_label_check = False

        # FDA style transfer is disabled until configured from outside
        self.fda_style_extraction = False
        self.fda_style_fn = None

    def __len__(self):
        """Return how many entries ``self.items`` currently holds."""
        samples = self.items
        return len(samples)

    def __getitem__(self, item):
        """Fetch the sample at index ``item``; this class provides no implementation."""
        raise NotImplementedError()

    def set_id2trainid_all_seen_classes(self):
        """Replace ``self.id2trainid`` with the mapping ``DatasetConfig.get_id2iid`` returns for ``training=False``."""
        lookup = dict(
            dataset=self.dataset_name,
            incr_class_split=self.incr_class_split,
            class_set=self.class_set,
            incr_class_step=self.incr_class_step,
            training=False,
        )
        self.id2trainid = DatasetConfig.get_id2iid(**lookup)

    def _remove_unlabeled_images(self):
        """
        Drop from ``self.items`` every entry whose label map contains nothing but the 'unknown' class.

        The class table of the current incremental step is read from ``DatasetConfig.name2iid()``.
        If it has no 'unknown' entry nothing is removed. Otherwise the dataset is iterated with
        ``self.doing_label_check`` set to True, each yielded item is converted with
        ``_mask_transform`` and the entry is kept only if some label differs from the unknown one.

        NOTE(review): presumably subclasses' ``__getitem__`` return just the raw label map while
        ``doing_label_check`` is set - confirm against the subclasses.
        NOTE(review): labels that ``id2trainId`` fills with its ignore value also count as
        "not unknown" here - confirm this is intended.
        """
        current_classes = DatasetConfig.name2iid()[self.incr_class_split][self.class_set][self.incr_class_step]
        if 'unknown' not in current_classes:
            print('Skipping image removal: no unknown class found in class set')
            self.doing_label_check = False
            return

        print('Removing images lacking new class instances...')
        unknown_only = {current_classes['unknown']}
        self.doing_label_check = True
        try:
            indexes_to_keep = [
                index for index, gt_image in enumerate(self)
                if set(self._mask_transform(gt_image).unique().int().tolist()) - unknown_only
            ]
        finally:
            # never leave the dataset stuck in label-check mode if loading a sample fails
            self.doing_label_check = False

        if len(self.items) > len(indexes_to_keep):
            print('!!! WARNING !!!')
            print(f'{len(self.items)-len(indexes_to_keep)} images will be discarded due to lack of new class instances')
        self.items = [self.items[i] for i in indexes_to_keep]

    def id2trainId(self, label, reverse=False, ignore_label=-1):
        """
        Remap a label array through ``self.id2trainid``.

        :param label: array of label ids (anything supporting ``.shape`` and ``== scalar``)
        :param reverse: accepted but not used by this implementation
        :param ignore_label: value given to every position whose id is not a key of the mapping
        :return: new float32 array with the shape of ``label``
        """
        remapped = np.ones(label.shape, dtype=np.float32) * ignore_label
        for source_id, train_id in self.id2trainid.items():
            hits = label == source_id
            remapped[hits] = train_id
        return remapped

    def _train_sync_transform(self, img, mask):
        """
        Training pre-processing; random mirroring, resizing, cropping, FDA styling, gaussian
        blurring and random translation are applied, in this order, when enabled.

        :param img: input image
        :param mask: input gt map, or None when no label is available
        :return: ``(img, mask)`` after the final tensor transforms, or ``img`` alone when ``mask`` is
                 None; in FDA style-extraction mode the untransformed image is returned together
                 with the transformed mask right after the FDA step
        """
        if self.random_mirror:  # default = True
            img, mask = self._random_mirror(img, mask)

        if self.resize_to:
            # in training the mask is always resized and sizes are not rounded to a multiple
            img, mask = self._resize(img, mask, train=True, resize_mask=True, make_divisible=False)

        if self.crop_to:
            img, mask = self._crop(img, mask)

        if self.fda_style_fn is not None or self.fda_style_extraction:
            img = self._fda(img)
            if self.fda_style_extraction:
                return img, self._mask_transform(mask)

        # With an FDA style function the random parameters are drawn once here and passed
        # explicitly, so that every styled image gets the same transform.
        # NOTE(review): presumably fda_style_fn returns a list of styled images - confirm.
        if self.gaussian_blur:  # default = True
            blur_kwargs = {}
            if self.fda_style_fn is not None:
                blur_kwargs = {'to_apply': random.random(), 'radius': random.random()}
            img = self._gaussian_blur(img, **blur_kwargs)

        if self.random_translation:   # erfnet https://github.com/Eromera/erfnet_pytorch/blob/master/train/main.py
            shift_kwargs = {}
            if self.fda_style_fn is not None:
                shift_kwargs = {'transX': random.randint(-2, 2), 'transY': random.randint(-2, 2)}
            img, mask = self._random_translation(img, mask=mask, **shift_kwargs)

        # final transform; test for presence explicitly instead of relying on object truthiness
        if mask is not None:
            img, mask = self._img_transform(img), self._mask_transform(mask)
            return img, mask
        else:
            img = self._img_transform(img)
            return img

    def _val_sync_transform(self, img, mask, resize_mask=True):
        """
        Validation pre-processing: optional resize followed by the final tensor transforms.

        :param img: input image
        :param mask: input gt map
        :param resize_mask: forwarded to ``_resize``; set to False when validating at full resolution
        :return: ``(img, mask)`` as produced by ``_img_transform`` and ``_mask_transform``
        """
        if self.resize_to:
            img, mask = self._resize(
                img,
                mask,
                train=False,
                resize_mask=resize_mask,
                make_divisible=self.make_val_size_div,
            )

        # final transform
        img_out = self._img_transform(img)
        mask_out = self._mask_transform(mask)
        return img_out, mask_out


    ### TRANSFORMS ###

    @staticmethod
    def _random_mirror(img, mask):
        """
        Flip image and mask left-right together when a uniform random draw is below 0.5.

        :param img: input image
        :param mask: input gt map; left untouched when falsy
        :return: ``(img, mask)``, flipped or unchanged
        """
        do_flip = random.random() < 0.5
        if not do_flip:
            return img, mask
        mirrored_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mirrored_mask = mask.transpose(Image.FLIP_LEFT_RIGHT) if mask else mask
        return mirrored_img, mirrored_mask

    @staticmethod
    def _make_size_divisible(w, h, base=8):
        """
        Snap a (w, h) size to multiples of ``base``.

        The height is snapped first; the width is then recomputed from the snapped height using
        the original h/w ratio and snapped as well.

        :param w: width
        :param h: height
        :param base: the divisor both returned sizes are multiples of
        :return: ``(w_, h_)``
        """
        def _snap(n):
            # remainders below base//2 go down to the previous multiple, all others go up
            quotient, remainder = divmod(n, base)
            if remainder == 0:
                return n
            if remainder < base//2:
                return quotient * base
            return (quotient + 1) * base

        ratio = h/w
        h_ = _snap(h)
        w_ = _snap(round(h_/ratio))
        return w_, h_

    def _resize(self, img, mask, train: bool, resize_mask: bool, make_divisible: bool):
        """
        Resize the image (bicubic) and, optionally, the mask (nearest) to ``self.resize_to``.

        :param img: input image
        :param mask: input gt map; only resized when ``resize_mask`` is set and the mask is truthy
        :param train: enables the random scale jitter when ``self.random_resize`` is set
        :param resize_mask: whether the mask has to be resized too
        :param make_divisible: whether to snap the target size to multiples of 8
        :return: ``(img, mask)``
        """
        img_w, img_h = img.size  # PIL img.size = w,h
        resize_w, resize_h = self.resize_to

        # random scale jitter, applied only to the sides that were actually given
        if train and self.random_resize:
            factor = random.uniform(1-self.random_resize, 1+self.random_resize)
            resize_w = int(resize_w * factor) if resize_w else resize_w
            resize_h = int(resize_h * factor) if resize_h else resize_h

        assert resize_w or resize_h, 'Both resize sizes are not valid'

        # a missing side is derived from the other one, keeping the aspect ratio
        if not resize_w:
            resize_w = round(resize_h / img_h * img_w)
        elif not resize_h:
            resize_h = round(resize_w / img_w * img_h)

        if make_divisible:
            resize_w, resize_h = self._make_size_divisible(w=resize_w, h=resize_h, base=8)

        upscaling = img_w < resize_w or img_h < resize_h
        if upscaling:
            print(f'WARNING: resize sizes ({resize_w},{resize_h}) exceed original dimensions ({img_w},{img_h})')

        img = img.resize((int(resize_w), int(resize_h)), Image.BICUBIC)
        if resize_mask and mask:
            mask = mask.resize((int(resize_w), int(resize_h)), Image.NEAREST)
        return img, mask

    def _crop(self, img, mask):
        """
        Take a random crop of size ``self.crop_to`` from image and mask at the same position.

        Inputs smaller than the crop are first padded on the right/bottom: the image with 0,
        the mask with ``self.args.ignore_label``.

        :param img: input image
        :param mask: input gt map; left untouched when falsy
        :return: ``(img, mask)``
        """
        img_w, img_h = img.size  # PIL img.size = w,h
        crop_w = int(self.crop_to[0])
        crop_h = int(self.crop_to[1])
        assert crop_w and crop_h, 'Both crop sizes are not valid'

        too_small = img_w < crop_w or img_h < crop_h
        if too_small:
            print('WARNING: crop sizes ({},{}) exceed original dimensions ({},{})'.format(crop_w, crop_h, img_w, img_h))
            padw = max(crop_w - img_w, 0)
            padh = max(crop_h - img_h, 0)
            borders = (0, 0, padw, padh)
            img = ImageOps.expand(img, border=borders, fill=0)
            if mask:
                mask = ImageOps.expand(mask, border=borders, fill=self.args.ignore_label)

        # the size may have changed because of the padding
        img_w, img_h = img.size
        x1 = random.randint(0, img_w - crop_w)
        y1 = random.randint(0, img_h - crop_h)
        box = (x1, y1, x1 + crop_w, y1 + crop_h)
        img = img.crop(box)
        if mask:
            mask = mask.crop(box)
        return img, mask


    @multi_img(stack=False)
    def _gaussian_blur(self, img, to_apply=None, radius=None):
        """
        Blur the image with a PIL gaussian filter when ``to_apply`` is below 0.5 (gaussian blur as in PSP).

        PIL, Scipy and OpenCV radius-sigma relations:
        https://stackoverflow.com/questions/62968174/for-pil-imagefilter-gaussianblur-how-what-kernel-is-used-and-does-the-radius-par

        :param img: input image
        :param to_apply: decision value; drawn with ``random.random()`` when None
        :param radius: blur radius; drawn with ``random.random()`` (a number between 0.0 and 1.0) when None
        :return: the blurred or the unchanged image
        """
        to_apply = random.random() if to_apply is None else to_apply
        radius = random.random() if radius is None else radius
        return img.filter(ImageFilter.GaussianBlur(radius=radius)) if to_apply < 0.5 else img

    @multi_img(stack=False)
    def _random_translation(self, img, mask=None, transX=None, transY=None):  # erfnet https://github.com/Eromera/erfnet_pytorch/blob/master/train/main.py
        """
        Shift image and mask by a few pixels, keeping their size.

        Pixels uncovered by the shift are set to 0 in the image and to ``self.args.ignore_label``
        in the mask.

        :param img: input image
        :param mask: input gt map, or None to translate the image only
        :param transX: horizontal shift in pixels; drawn from [-2, 2] when None
        :param transY: vertical shift in pixels; drawn from [-2, 2] when None
        :return: ``(img, mask)``; ``mask`` is None when no mask was given
        """
        assert self.crop_to is None, 'Random translation not to be used when cropping'
        if transX is None: transX = random.randint(-2, 2)
        if transY is None: transY = random.randint(-2, 2)

        img = ImageOps.expand(img, border=(transX, transY, 0, 0), fill=0)
        img = img.crop((0, 0, img.size[0]-transX, img.size[1]-transY))
        if mask is None:
            # the default mask=None used to crash inside ImageOps.expand
            return img, None

        ignore_label = self.args.ignore_label
        mask = ImageOps.expand(mask, border=(transX, transY, 0, 0), fill=ignore_label)  # pad label filling with ignore_label
        mask = mask.crop((0, 0, mask.size[0]-transX, mask.size[1]-transY))

        # For a negative shift the uncovered strip lies on the right/bottom and is produced by
        # crop(), which fills it with 0 rather than with the expand() fill value: repaint it.
        mask_w, mask_h = mask.size
        if transX < 0:
            mask.paste(ignore_label, (mask_w + transX, 0, mask_w, mask_h))
        if transY < 0:
            mask.paste(ignore_label, (0, mask_h + transY, mask_w, mask_h))
        return img, mask

    def _fda(self, img):
        """
        Run the image through ``self.fda_style_fn``, unless style extraction is active.

        :param img: input image
        :return: ``img`` itself in style-extraction mode, otherwise the result of ``self.fda_style_fn(img)``
        """
        return img if self.fda_style_extraction else self.fda_style_fn(img)


    @multi_img(stack=True)
    def _img_transform(self, image):
        """
        Convert an image into a mean-subtracted, channel-reversed (C x H x W) float32 tensor.

        Only the numpy path (``self.args.numpy_transform`` truthy, default = True) is implemented.

        NOTE(review): ``self.image_mean`` is not assigned in this class; presumably subclasses set
        it, in BGR channel order - confirm.

        :param image: input image, convertible to an (H x W x C) array
        :return: torch tensor of shape (C x H x W)
        :raises NotImplementedError: when ``self.args.numpy_transform`` is falsy
        """
        if not self.args.numpy_transform:
            raise NotImplementedError

        pixels = np.asarray(image, np.float32)
        bgr = pixels[:, :, ::-1]  # change to BGR
        bgr -= self.image_mean  # in place, so the array stays float32
        chw = bgr.transpose((2, 0, 1)).copy()  # (C x H x W)
        return torch.from_numpy(chw)

    def _mask_transform(self, gt_image):
        """
        Convert a gt map into a float32 tensor of train ids.

        :param gt_image: input gt map, convertible to a numpy array
        :return: torch tensor holding the labels remapped by ``id2trainId``
        """
        raw_ids = np.asarray(gt_image, np.float32)
        train_ids = self.id2trainId(raw_ids)
        return torch.from_numpy(train_ids.copy())


