from . import torch, F

class Pseudo_Labeler:
    """Turn teacher predictions into pseudo-labels, rejecting low-confidence pixels.

    A pixel keeps its predicted (argmax) label if any of the following holds:
      * its max class probability exceeds ``global_thresh``
        (only active when ``0 < global_thresh < 1``);
      * it is among the ``class_keep_fract`` most confident pixels predicted as
        its class in the same image (only active when ``0 < class_keep_fract < 1``);
      * multi-style (list) input only: every style predicts the same class for it
        and its max probability exceeds ``global_thresh_low``
        (only active when ``0 <= global_thresh_low <= 1``).
    Every other pixel is set to ``ignore_index``.
    """

    def __init__(self, ignore_index=-1, global_thresh=0.9, class_keep_fract=0.66, teacher=None, global_thresh_low=-1.):
        self.ignore_index = ignore_index            # label written to rejected pixels
        self.global_thresh = global_thresh          # confidence above which a pixel is always kept
        self.class_keep_fract = class_keep_fract    # per-class fraction of most confident pixels to keep
        self.teacher = teacher                      # callable whose result is indexed with ['out']
        self.global_thresh_low = global_thresh_low  # confidence floor for style-consistent pixels

    def set_teacher(self, model):
        """Replace the default teacher used when ``__call__`` gets no ``teacher``."""
        self.teacher = model

    def get_image_mask(self, prob, pseudo_lab):
        """Compute the keep-mask for a single image.

        Args:
            prob: per-class scores with the class dimension first (C x H x W);
                reduced with ``max`` over dim 0.
            pseudo_lab: predicted labels with the same spatial shape as
                ``prob.max(0)`` (H x W).

        Returns:
            Boolean tensor shaped like ``prob.max(0)``. It is True where the max score
            exceeds ``global_thresh``, or where the pixel is within the
            ``class_keep_fract`` most confident pixels of its predicted class.
        """
        max_prob = prob.detach().max(0)[0]
        if 0. < self.global_thresh < 1.:
            mask_prob = max_prob > self.global_thresh
        else:
            mask_prob = torch.zeros_like(max_prob, dtype=torch.bool)

        flat_prob = max_prob.flatten()
        flat_lab = pseudo_lab.flatten()
        flat_topk = torch.zeros_like(flat_prob, dtype=torch.bool)
        if 0. < self.class_keep_fract < 1.:
            for c in flat_lab.unique():
                mask_c = flat_lab == c
                # Keep the tensor arithmetic used originally so the rounding of k is unchanged.
                k = int(mask_c.sum() * self.class_keep_fract)
                if k == 0:
                    continue
                # Rank only the pixels of class c. Ranking a zero-filled copy of the whole
                # image would let the zeros outrank class pixels whose score is <= 0.
                idx_c = mask_c.nonzero(as_tuple=True)[0]
                top = torch.topk(flat_prob[idx_c], k=k)[1]
                flat_topk[idx_c[top]] = True
        return mask_prob | flat_topk.view_as(max_prob)

    def get_batch_mask(self, softmax, pseudo_lab):
        """Apply ``get_image_mask`` to every image of a batch.

        Args:
            softmax: B x C x H x W class probabilities.
            pseudo_lab: B x H x W predicted labels.

        Returns:
            Boolean B x H x W keep-mask.

        Raises:
            ValueError: if ``softmax`` is not 4-dimensional.
        """
        if softmax.dim() != 4:
            raise ValueError(f"expected a 4D B x C x H x W tensor, got shape {tuple(softmax.shape)}")
        return torch.stack([self.get_image_mask(pb, pl) for pb, pl in zip(softmax, pseudo_lab)], dim=0)

    def check_consistency(self, pseudo_lab, gt, unknown_lab, lab_to_force):
        """Overwrite pseudo-labels that contradict the ground truth of new classes.

        Where ``gt`` holds a real class (neither ``unknown_lab`` nor ``ignore_index``),
        an old class cannot be pseudo-labeled. Any such pixel whose pseudo-label is
        also neither ``unknown_lab`` nor ``ignore_index`` is set to ``lab_to_force``.
        ``pseudo_lab`` is modified in place and also returned.
        """
        mask = (gt != unknown_lab) & (gt != self.ignore_index) & (pseudo_lab != unknown_lab) & (pseudo_lab != self.ignore_index)
        pseudo_lab[mask] = lab_to_force
        return pseudo_lab

    def __call__(self, imgs, teacher=None, gt_new_classes=None, unknown_lab=None, lab_to_force=None):
        """Run the teacher and build pseudo-labels.

        Args:
            imgs: a batch accepted by the teacher, or a list of such batches
                (one per style). For a list, each pixel takes the softmax vector of
                the style that is most confident at that pixel.
            teacher: model used for this call instead of ``self.teacher``.
            gt_new_classes: optional ground truth. When given, ``check_consistency``
                is applied and ``unknown_lab`` / ``lab_to_force`` are required.
            unknown_lab, lab_to_force: see ``check_consistency``.

        Returns:
            ``(pseudo_lab, softmax, mask)``:
              * ``pseudo_lab`` — B x H x W labels, with ``ignore_index`` on rejected pixels;
              * ``softmax`` — B x C x H x W probabilities, fused across styles for list input;
              * ``mask`` — boolean B x H x W keep-mask. When ``gt_new_classes`` is given,
                it is recomputed as ``pseudo_lab != ignore_index``.

        Raises:
            ValueError: if no teacher is available.
        """
        if teacher is None:
            teacher = self.teacher
        if teacher is None:
            raise ValueError("no teacher model: pass `teacher` or call set_teacher() first")

        consistency = self.global_thresh_low is not None and 0. <= self.global_thresh_low <= 1.
        consistency_mask = None
        if isinstance(imgs, list):
            with torch.no_grad():
                softmax_values = torch.stack(
                    [F.softmax(teacher(img)['out'], dim=1) for img in imgs], dim=0)  # S x B x C x H x W
            S, B, C, H, W = softmax_values.size()
            # One reduction over the class dim yields both per-style confidence and argmax.
            max_values, labels_per_style = softmax_values.max(2)  # S x B x H x W each
            if consistency:
                # True where all styles predict the same class. This avoids building an
                # S x B x H x W x C int64 one-hot tensor just to count agreements.
                consistency_mask = (labels_per_style == labels_per_style[0]).all(0)  # B x H x W
            # Most confident style per pixel; take that style's full softmax vector.
            style_per_pixel = max_values.max(0)[1]  # B x H x W
            index = style_per_pixel[None, :, None].expand(1, B, C, H, W)
            softmax = torch.gather(softmax_values, 0, index).squeeze(0)  # B x C x H x W
            max_prob, pseudo_lab = softmax.max(1)
            if consistency:
                consistency_mask &= max_prob > self.global_thresh_low
        else:
            with torch.no_grad():
                softmax = F.softmax(teacher(imgs)['out'], dim=1)
            pseudo_lab = softmax.max(1)[1]

        mask = self.get_batch_mask(softmax, pseudo_lab)
        if consistency_mask is not None:
            mask |= consistency_mask
        pseudo_lab[~mask] = self.ignore_index

        if gt_new_classes is not None:
            assert unknown_lab is not None and lab_to_force is not None
            pseudo_lab = self.check_consistency(pseudo_lab, gt_new_classes, unknown_lab, lab_to_force)
            mask = pseudo_lab != self.ignore_index

        return pseudo_lab, softmax, mask
