import torch
import numpy as np
from PIL import Image
from datasets.dataset_config import DatasetConfig

def flip(x, dim):
    """Return a copy of ``x`` with the order of elements reversed along ``dim``.

    Args:
      x: input tensor.
      dim: dimension to reverse; negative values count from the last dimension.

    Returns:
      A new tensor with the same shape, dtype and device as ``x``.

    Raises:
      IndexError: if ``dim`` is out of range for ``x`` (instead of silently
        returning ``x`` unflipped).
    """
    # torch.flip is the built-in equivalent of indexing with a reversed arange
    # and always returns a copy, like the advanced-indexing version did.
    return torch.flip(x, dims=[dim])

def inv_preprocess(imgs, num_images=1, img_mean=None, numpy_transform=False):
    """Inverse preprocessing of the batch of images.

    Optionally reverses the channel axis, adds the mean back and scales by
    1/255. The input tensor is never modified; a new tensor is returned.

    Args:
      imgs: a (N, C, H, W) batch of images or a single (C, H, W) image.
      num_images: number of leading images of a 4-D batch to return.
        Not applied to a single 3-D image.
      img_mean: vector of mean colour values (length C). It is reversed
        before being added back, so presumably it is stored in the opposite
        channel order to ``imgs`` — TODO confirm against the forward transform.
        ``None`` skips the mean restoration.
      numpy_transform: whether change RGB to BGR during img_transform; if set,
        the channel axis of ``imgs`` is reversed first.

    Returns:
      For 4-D input, the first ``num_images`` images; for 3-D input, the
      restored image. dtype and device follow ``imgs``.

    Raises:
      NotImplementedError: if ``imgs`` is neither 3-D nor 4-D.
    """
    if imgs.dim() == 4:
        # Only the leading num_images are returned, so skip work on the rest.
        imgs = imgs[:num_images]
        channel_dim = 1
    elif imgs.dim() == 3:
        channel_dim = 0
    else:
        raise NotImplementedError
    if numpy_transform:
        imgs = torch.flip(imgs, dims=[channel_dim])
    if img_mean is not None:
        # .copy() makes the reversed view contiguous for torch; matching dtype
        # and device avoids float64 promotion and CPU/CUDA mismatches.
        mean = torch.as_tensor(np.asarray(img_mean)[::-1].copy(),
                               dtype=imgs.dtype, device=imgs.device)
        # (C, 1, 1) broadcasts over both (C, H, W) and (N, C, H, W).
        imgs = imgs + mean.view(-1, 1, 1)
    return imgs / 255.

def decode_labels(mask, class_names, class_set='city19', num_images=1, ignore_index=-1):
    """Decode batch of segmentation masks into RGB colour images.

    Args:
      mask: result of inference after taking argmax; an (N, H, W) array or
        tensor of class indices.
      class_names: classes to predict; class index ``c`` is drawn with the
        colour of ``class_names[c]``.
      class_set: key of ``DatasetConfig.name2color`` selecting the colour table.
      num_images: number of images to decode from the batch (clamped to N).
      ignore_index: index drawn in white (255, 255, 255); a valid class index
        takes precedence if the two coincide.

    Returns:
      A float32 tensor of shape (num_images, 3, H, W) with values in [0, 1].

    Raises:
      AssertionError: if ``class_set`` is not a key of ``DatasetConfig.name2color``.
      ValueError: if ``mask`` is not 3-D, or a pixel is neither a class index
        in [0, len(class_names)) nor ``ignore_index``.
    """
    num_classes = len(class_names)
    assert class_set in DatasetConfig.name2color, "Class set '{}' not valid, choose from '{}'".format(class_set, list(DatasetConfig.name2color))
    name2color = DatasetConfig.name2color[class_set]
    # (num_classes, 3) lookup table; reshape keeps it 2-D when class_names is empty.
    palette = np.asarray([name2color[cn] for cn in class_names], dtype=np.uint8).reshape(-1, 3)

    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    n, h, w = mask.shape
    num_images = min(num_images, n)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        labels = mask[i]
        is_class = (labels >= 0) & (labels < num_classes)
        if np.issubdtype(labels.dtype, np.floating):
            # Only integral float values are usable as class indices.
            is_class &= labels == np.floor(labels)
        is_ignored = ~is_class & (labels == ignore_index)
        invalid = ~(is_class | is_ignored)
        if invalid.any():
            # Boolean indexing is row-major, so this is the first offending pixel.
            raise ValueError('Invalid class index {} in segmentation map'.format(labels[invalid][0]))
        outputs[i][is_class] = palette[labels[is_class].astype(np.intp)]
        outputs[i][is_ignored] = 255
    return torch.from_numpy(outputs.transpose([0, 3, 1, 2]).astype('float32')).div_(255.0)
