from . import *
from utils.misc import *


def safe_log(t, eps=1e-10):
    """Return the element-wise natural log of ``t + eps``.

    The ``eps`` offset keeps the result finite where ``t`` contains zeros
    (e.g. probabilities that underflowed).

    :param t: input tensor
    :param eps: non-negative offset added before taking the log
    :return: tensor of the same shape as ``t``
    """
    offset = t + eps
    return torch.log(offset)


class CE_loss_unbiased(nn.Module):
    """Unbiased cross-entropy from MiB.

    Channels of classes learned in previous steps ("old" classes, unknown
    class included) are collapsed into one channel holding their total
    probability. Pixels labelled with an old class are trained toward that
    merged channel, and new classes keep their own channels.

    :param ignore_label: label value excluded from the loss
    :param device: kept for backward compatibility; index tensors are created
        on the device of the prediction passed to ``forward``
    """

    def __init__(self, ignore_label=-1, device='cuda'):
        super().__init__()
        self.ignore_label = ignore_label
        self.device = device

        # old classes: learned in previous steps; new classes: current step.
        # Both are contiguous index ranges built by ``update_classes``.
        self.old_classes, self.new_classes = [], []

    def update_classes(self, num_classes):
        """
        Update old and new class lists
        :param num_classes: integer number of current classes
        """
        # To support UDA mode (when no new classes have been added, do not update).
        # Compare against every class seen so far (old + new): comparing only
        # against the old ones would, on a repeated call with the same count,
        # move all classes into ``old_classes`` and leave ``new_classes`` empty.
        if num_classes > len(self.old_classes) + len(self.new_classes):
            self.old_classes += self.new_classes
            self.new_classes = list(range(len(self.old_classes), num_classes))

    def forward(self, pred, gt):
        """
        Compute unbiased CE from MiB
        :param pred: BxCxHxW tensor, output of classifier before softmax
        :param gt: BxHxW tensor, semantic label
        :return: loss value
        """
        log_softmax = F.log_softmax(pred, dim=1)
        device = pred.device
        new_idx = torch.tensor(self.new_classes, dtype=torch.long, device=device)
        # select only channels corresponding to new classes: B x len(new) x H x W
        log_probs = [log_softmax.index_select(dim=1, index=new_idx)]
        if self.old_classes:
            n_old = len(self.old_classes)
            old_idx = torch.tensor(self.old_classes, dtype=torch.long, device=device)
            # log of the total probability of old (incl. unknown) classes,
            # computed stably as logsumexp of their log-probabilities: B x 1 x H x W
            log_probs.insert(0, torch.logsumexp(
                log_softmax.index_select(dim=1, index=old_idx), dim=1, keepdim=True))
            # remap labels so that idx0 => old&unseen classes and idx1... => new
            # classes; old classes are 0..n_old-1 (see update_classes).
            # The ignore label is kept as-is so nll_loss still skips it.
            is_old = (gt >= 0) & (gt < n_old)
            remapped = torch.where(is_old, torch.zeros_like(gt), gt - n_old + 1)
            gt = torch.where(gt == self.ignore_label, gt, remapped)

        log_softmax_balanced = torch.cat(log_probs, dim=1)
        return F.nll_loss(log_softmax_balanced, gt, ignore_index=self.ignore_label)

class UnbiasedCrossEntropy(nn.Module):
    """Unbiased cross-entropy loss (MiB formulation).

    Channel 0 of the log-probabilities is replaced by the log of the summed
    probability of all old classes (``0..old_cl-1``). Every target below
    ``old_cl`` is mapped to label 0, and new classes keep their own channels.

    :param old_cl: number of old classes (first new class index); also set by
        ``update_classes``
    :param reduction: passed to ``F.nll_loss``
    :param ignore_index: label value excluded from the loss
    :param weight_fn: optional callable ``(targets, argmax_pred, num_classes=C)``
        returning per-class weights for ``F.nll_loss``
    """

    def __init__(self, old_cl=None, reduction='mean', ignore_index=-1, weight_fn=None):
        super().__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.weight_fn = weight_fn

        # not from original implementation
        self.old_cl = old_cl
        self.old_classes, self.new_classes = [], []

    # not from original implementation
    def update_classes(self, num_classes):
        """Move the current new classes to the old ones when ``num_classes`` grew.

        A repeated call with an unchanged class count (UDA mode) keeps the
        current split; the comparison is therefore against old + new classes.

        :param num_classes: integer number of current classes
        """
        if num_classes > len(self.old_classes) + len(self.new_classes):
            self.old_classes += self.new_classes
            self.new_classes = list(range(len(self.old_classes), num_classes))
        self.old_cl = max(self.old_classes) + 1 if self.old_classes else 0

    def forward(self, inputs, targets):
        """Compute the unbiased cross-entropy.

        :param inputs: B x C x H x W logits
        :param targets: B x H x W integer labels
        :return: loss reduced according to ``self.reduction``
        """
        old_cl = self.old_cl
        outputs = torch.zeros_like(inputs)  # B, C (1+V+N), H, W
        den = torch.logsumexp(inputs, dim=1)                               # B, H, W       den of softmax
        outputs[:, 0] = torch.logsumexp(inputs[:, 0:old_cl], dim=1) - den  # B, H, W       p(O)
        outputs[:, old_cl:] = inputs[:, old_cl:] - den.unsqueeze(dim=1)    # B, N, H, W    p(N_i)

        labels = targets.clone()    # B, H, W
        # all old labels belong to zero; the ignore label must be excluded here,
        # otherwise a negative ignore_index (the default -1) would be remapped
        # to 0 and those pixels would be trained as old/background
        old_mask = (targets < old_cl) & (targets != self.ignore_index)
        labels[old_mask] = 0

        weight = None
        if self.weight_fn:
            weight = self.weight_fn(targets, inputs.max(1)[1], num_classes=inputs.size(1))
        return F.nll_loss(outputs, labels, ignore_index=self.ignore_index,
                          reduction=self.reduction, weight=weight)



class Distillation_loss_unbiased(nn.Module):
    """Unbiased distillation between a new model and a frozen old model.

    The new model's probabilities of the unknown class and of all new classes
    are summed into one channel, so its output matches the old model's label
    space: [unknown + new, old classes...].

    :param unknown_label: channel index of the unknown class in the new model
    :param device: kept for backward compatibility; index tensors are created
        on the device of ``pred_new``
    """

    def __init__(self, unknown_label=-1, device='cuda'):
        super().__init__()
        self.unknown_label = unknown_label
        self.device = device

        self.old_classes, self.new_classes = [], []

    def update_classes(self, num_classes):
        """Move current new classes to the old ones and define the next new range.

        The unknown label is never treated as a new class.

        :param num_classes: integer number of current classes
        """
        self.old_classes += self.new_classes
        start = max(self.old_classes) + 1 if self.old_classes else 0
        self.new_classes = list(range(start, num_classes))
        if self.unknown_label in self.new_classes:
            self.new_classes.remove(self.unknown_label)

    def forward(self, pred_new, pred_old, mask=None):
        """Cross-entropy of the merged new-model distribution against the old one.

        :param pred_new: logits of the new model (class channels on dim 1)
        :param pred_old: logits of the old model; channel 0 is expected to be
            the unknown class, followed by ``old_classes`` in order (this is the
            order of the concatenation below)
        :param mask: optional tensor multiplied element-wise with the per-element
            loss before averaging (must broadcast to it)
        :return: scalar loss (negative mean over all elements)
        """
        device = pred_new.device
        # the old model is a fixed teacher: no gradients flow into it
        softmax_old = F.softmax(pred_old.detach(), dim=1)
        softmax_new = F.softmax(pred_new, dim=1)
        # NOTE(review): index_select does not accept negative indices, so the
        # default unknown_label=-1 fails here; callers presumably pass a
        # non-negative label (e.g. 0) -- confirm.
        merge_idx = torch.tensor(self.new_classes + [self.unknown_label],
                                 dtype=torch.long, device=device)
        keep_idx = torch.tensor(self.old_classes, dtype=torch.long, device=device)
        # merge unknown and new classes to unknown class to match label distribution of old prediction
        merged = softmax_new.index_select(dim=1, index=merge_idx).sum(dim=1, keepdim=True)
        kept = softmax_new.index_select(dim=1, index=keep_idx)

        log_softmax_new = safe_log(torch.cat([merged, kept], dim=1))
        loss = softmax_old * log_softmax_new
        if mask is not None:
            loss = loss * mask.float()
        return -torch.mean(loss)

class UnbiasedKnowledgeDistillationLoss(nn.Module):
    """Unbiased knowledge-distillation loss (MiB formulation).

    The old model's background channel (0) is matched against the new model's
    background plus all new-class channels (log of their summed probability).
    Old classes ``1..C_old-1`` are matched channel by channel.

    :param reduction: 'mean', 'sum', or anything else for no reduction
    :param alpha: scale applied to ``targets`` before the softmax
    """

    def __init__(self, reduction='mean', alpha=1.):
        super().__init__()
        self.reduction = reduction
        self.alpha = alpha

        # not from original implementation
        self.old_classes, self.new_classes = [], []

    # not from original implementation
    def update_classes(self, num_classes):
        """Move current new classes to the old ones and define the next new range.

        :param num_classes: integer number of current classes
        """
        self.old_classes += self.new_classes
        start = max(self.old_classes) + 1 if self.old_classes else 0
        self.new_classes = list(range(start, num_classes))
        if hasattr(self, 'unknown_label') and self.unknown_label in self.new_classes:
            self.new_classes.remove(self.unknown_label)

    def forward(self, inputs, targets, mask=None, is_target_output=True):
        """Compute the unbiased distillation loss.

        :param inputs: B x C_new x H x W logits of the new model
        :param targets: B x C_old x H x W old-model logits (or probabilities when
            ``is_target_output`` is False); channel 0 is background
        :param mask: optional tensor multiplied with the per-pixel (B x H x W) loss
        :param is_target_output: apply softmax to ``targets`` when True
        :return: negated loss reduced according to ``self.reduction``
        """
        n_old = targets.shape[1]

        targets = targets * self.alpha

        # background of the old model <-> background + every new class of the new model
        new_bkg_idx = torch.tensor([0, *range(n_old, inputs.shape[1])], device=inputs.device)

        den = torch.logsumexp(inputs, dim=1)                          # B, H, W
        # explicit end index: the former ``1:-new_cl`` slice is empty when no
        # new classes were added (new_cl == 0)
        outputs_no_bkg = inputs[:, 1:n_old] - den.unsqueeze(dim=1)    # B, OLD_CL, H, W
        outputs_bkg = torch.logsumexp(inputs.index_select(dim=1, index=new_bkg_idx), dim=1) - den  # B, H, W

        # not from original implementation
        if is_target_output:
            labels = torch.softmax(targets, dim=1)                    # B, BKG + OLD_CL, H, W
        else:
            labels = targets

        # make the average on the classes 1/n_cl \sum{c=1..n_cl} L_c
        loss = (labels[:, 0] * outputs_bkg + (labels[:, 1:] * outputs_no_bkg).sum(dim=1)) / n_old

        if mask is not None:
            loss = loss * mask.float()

        if self.reduction == 'mean':
            return -torch.mean(loss)
        if self.reduction == 'sum':
            return -torch.sum(loss)
        return -loss

