from . import *

# Process-wide: silence numpy warnings for x/0 and 0/0. The per-class metrics in this
# module divide by per-class pixel counts that are zero for classes absent from the data;
# the resulting nan values are filtered out afterwards (np.nanmean / isnan checks).
# NOTE(review): this changes numpy's error state for the whole process, not just this module.
np.seterr(invalid='ignore', divide='ignore')


class Eval:
    """
    Accumulate a confusion matrix over batches of label maps and derive segmentation
    metrics (pixel accuracy, mean accuracy, IoU, precision) from it.

    ``confusion_matrix[i, j]`` counts pixels whose ground-truth class is ``i`` and whose
    predicted class is ``j`` (rows: true classes, columns: predicted classes).

    For incremental evaluation, ``update`` appends new class names and records the
    cumulative number of classes after each step in ``classes_per_step``; the per-step
    partial mIoU reported by ``Print_Every_class_Eval`` is derived from it.
    """

    def __init__(self, name_classes):
        """
        :param name_classes: sequence of class names; its length is the number of classes
        """
        self.name_classes = name_classes
        self.num_classes = len(self.name_classes)
        # cumulative class counts per incremental step (e.g. [0, n_old, n_old + n_new]); seeded by ``update``
        self.classes_per_step = None
        self.confusion_matrix = np.zeros((self.num_classes,) * 2)
        # exclusive upper bound used to slice per-class arrays in the Mean_* metrics (None keeps all classes)
        self.ignore_index = None

    # ---------------------------------------------------------------- per-class helpers
    def _per_class_recall(self):
        """Per-class TP / (TP + FN) (row sums are the true pixels); nan for classes with no GT pixel."""
        cm = self.confusion_matrix
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(cm) / cm.sum(axis=1)

    def _per_class_precision(self):
        """Per-class TP / (TP + FP) (column sums are the predicted pixels); nan for never-predicted classes."""
        cm = self.confusion_matrix
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(cm) / cm.sum(axis=0)

    def _per_class_iou(self):
        """Per-class TP / (TP + FN + FP); nan for classes absent from both GT and prediction."""
        cm = self.confusion_matrix
        tp = np.diag(cm)
        with np.errstate(divide='ignore', invalid='ignore'):
            # (TP + FN) + (TP + FP) - TP = TP + FN + FP
            return tp / (cm.sum(axis=1) + cm.sum(axis=0) - tp)

    # ---------------------------------------------------------------- metrics
    def Pixel_Accuracy(self):
        """
        Pixel accuracy = all pixels correctly predicted / total number of pixels.

        :return: accuracy in [0, 1]; 0 (with a printed warning) if no pixel has been accumulated
        """
        total = self.confusion_matrix.sum()
        if total == 0:
            print("Attention: pixel_total is zero!!!")
            return 0
        return np.diag(self.confusion_matrix).sum() / total

    def Mean_Pixel_Accuracy(self):
        """
        Mean pixel accuracy = mean over classes of TP / (TP + FN), i.e. the fraction of the true
        pixels of a class that are correctly detected, regardless of FP (pixels of other true classes
        predicted as this class). Predicting every pixel as one class gives that class 100%.

        Classes with no ground-truth pixel (nan) are skipped; only classes in ``[:ignore_index]`` count.
        """
        MPA = self._per_class_recall()
        return np.nanmean(MPA[:self.ignore_index])

    def Mean_Intersection_over_Union(self):
        """
        Mean IoU = mean over classes of TP / (TP + FN + FP).

        :return: tuple ``(MIoU, MIoU_no_ukw)``: the nan-mean IoU over ``[:ignore_index]`` and the
            nan-mean IoU skipping class 0 when 'unknown' is a class name (assumes 'unknown', when
            present, is class 0 — as ``update`` also does)
        """
        iou = self._per_class_iou()
        start = 1 if 'unknown' in self.name_classes else 0
        MIoU_no_ukw = np.nanmean(iou[start:])
        MIoU = np.nanmean(iou[:self.ignore_index])
        return MIoU, MIoU_no_ukw

    def Frequency_Weighted_Intersection_over_Union(self):
        """
        Frequency weighted IoU = sum over classes of IoU weighted by (TP + FN) / total pixels,
        i.e. by the fraction of pixels whose true class is the current one.
        Classes whose IoU is undefined (nan) contribute nothing.
        """
        cm = self.confusion_matrix
        tp = np.diag(cm)
        freq = cm.sum(axis=1)
        with np.errstate(divide='ignore', invalid='ignore'):
            weighted = freq * tp / (freq + cm.sum(axis=0) - tp)
            # np.nansum replaces np.sum(<generator>), which numpy has deprecated
            return np.nansum(weighted) / cm.sum()

    def Mean_Precision(self):
        """
        Mean precision = mean over classes of TP / (TP + FP), i.e. the mean fraction of predictions
        that are correct, regardless of the FN pixels that were not found (predicting only a few
        pixels, all correctly, gives 100%).

        Never-predicted classes (nan) are skipped; only classes in ``[:ignore_index]`` count.
        """
        Precision = self._per_class_precision()
        return np.nanmean(Precision[:self.ignore_index])

    def Print_Every_class_Eval(self, logger=None, wandb_logger=None, current_incr_step=None, curr_data_split=None, split=None, current_iter=None, dataset_name=None):
        """
        Print a per-class table (MPA, IoU, precision, GT ratio, prediction ratio, all in %), one
        partial mIoU per incremental step (when ``classes_per_step`` is set) and the total mIoU;
        optionally log the same values to wandb.

        :param logger: object with an ``info`` method used for output; ``print`` is used if falsy
        :param wandb_logger: if not None, ``wandb_logger.log(data, current_iter)`` is called
        :param current_incr_step: only used to build wandb keys
        :param curr_data_split: only used to build wandb keys
        :param split: only used to build wandb keys (title-cased); may be None
        :param current_iter: logged as ``global_step`` and passed as step to ``wandb_logger.log``
        :param dataset_name: only used to build wandb keys
        :return: OrderedDict of mIoU values in percent: one entry per incremental step keyed by the
            step index (shifted by -1 when an 'unknown' class exists, so the 'unknown'-only step is -1),
            plus ``'total'`` (mIoU skipping class 0 when 'unknown' is a class name)
        """
        printer = logger.info if logger else print
        n_dec_digits = 6

        def mean(values, c=1.):
            # mean of the non-nan entries scaled by c; 0 when every entry is nan
            kept = [v * c for v in values if not math.isnan(v)]
            return sum(kept) / len(kept) if kept else 0

        def fmt(value):
            return str(round(value * 100, n_dec_digits)) if not np.isnan(value) else 'nan'

        MIoU = self._per_class_iou()
        MPA = self._per_class_recall()
        Precision = self._per_class_precision()
        total_pixels = np.sum(self.confusion_matrix)
        with np.errstate(divide='ignore', invalid='ignore'):
            Class_ratio = self.confusion_matrix.sum(axis=1) / total_pixels
            Pred_ratio = self.confusion_matrix.sum(axis=0) / total_pixels

        row_fmt = '===>{:<17}\t{:>6}\t{:>6}\t{:>6}\t{:>6}\t{:>6}'
        printer(row_fmt.format('Classes', 'MPA', 'IoU', 'PC', 'Ratio', 'Pred_Ratio'))

        has_unknown = 'unknown' in self.name_classes
        wandb_data, partial_count = {'global_step': current_iter}, 0
        metrics_to_ret = OrderedDict()
        # ``split`` defaults to None; guard so the default no longer raises AttributeError
        split_name = split.title() if split is not None else split
        wandb_string = f'Eval Step {current_incr_step} ({curr_data_split}) on {dataset_name}/({split_name})'
        offset = 1 if has_unknown else 0

        for ind_class in range(len(MIoU)):
            printer(row_fmt.format(self.name_classes[ind_class], fmt(MPA[ind_class]), fmt(MIoU[ind_class]),
                                   fmt(Precision[ind_class]), fmt(Class_ratio[ind_class]), fmt(Pred_ratio[ind_class])))

            # the last class of an incremental step closes that step: report its partial mIoU
            if self.classes_per_step and ind_class + 1 in self.classes_per_step:
                stop = ind_class + 1
                step_start = self.classes_per_step[self.classes_per_step.index(stop) - 1]
                partial = mean(MIoU[step_start:stop], 100)
                printer('===>' + '{}: {}'.format('Partial mIoU', round(partial, n_dec_digits)))
                printer('-' * 62)
                wandb_data[f'{wandb_string} Partial-{partial_count} mIoU step'] = round(partial, 2)
                metrics_to_ret[partial_count - offset] = partial
                partial_count += 1

        start = 1 if has_unknown else 0
        s = ' (no bgr)' if has_unknown else ''
        total_miou = mean(MIoU[start:], 100)
        printer('===>' + '{}: {}'.format(f'Total mIoU{s}', round(total_miou, n_dec_digits)))
        if -1 not in metrics_to_ret and start == 1:
            metrics_to_ret[-1] = mean([MIoU[0]], 100)
        if 0 not in metrics_to_ret:
            metrics_to_ret[0] = total_miou
        metrics_to_ret['total'] = total_miou

        wandb_data[f'{wandb_string} Total mIoU{s}'] = round(total_miou, 2)

        if wandb_logger is not None:
            data = list(zip(range(len(MIoU)), self.name_classes[:len(MIoU)], MIoU, MPA, Precision, Class_ratio, Pred_ratio))
            columns = ['Class Index', 'Class Name', 'mIoU', 'mPA', 'Precision', 'Class Ratio', 'Pred Ratio']
            table = wandb.Table(data=data, columns=columns)
            for m in columns[2:]:
                wandb_data[f'{m} histogram'] = wandb.plot.scatter(table, x='Class Index', y=m, title=m)
            wandb_logger.log(wandb_data, current_iter)

        return metrics_to_ret

    # ---------------------------------------------------------------- accumulation
    def __generate_matrix(self, gt_image, pre_image):
        """
        Confusion matrix (rows: gt class, columns: predicted class) of one gt/prediction pair.

        Pixels whose ground truth lies outside ``[0, num_classes)`` (e.g. an ignore label) are skipped.
        Predictions are cast to int, so float label maps are accepted.

        :raises ValueError: if a counted pixel has a prediction outside ``[0, num_classes)``; such a
            value would otherwise be silently attributed to a wrong (gt, pred) cell
        """
        n = self.num_classes
        mask = (gt_image >= 0) & (gt_image < n)
        gt = gt_image[mask].astype('int')
        pred = np.asarray(pre_image)[mask].astype('int')
        if pred.size:
            lo, hi = pred.min(), pred.max()
            if lo < 0 or hi >= n:
                raise ValueError(f'predictions must lie in [0, {n}), got values in [{lo}, {hi}]')
        # encode each (gt, pred) pair as a single index gt * n + pred, then count
        count = np.bincount(n * gt + pred, minlength=n ** 2)
        return count.reshape(n, n)

    def add_batch(self, gt_image, pre_image):
        """
        Accumulate one ground-truth / prediction pair into the confusion matrix.

        :param gt_image: array of ground-truth class indices
        :param pre_image: array of predicted class indices, same shape as ``gt_image``
        """
        assert gt_image.shape == pre_image.shape, \
            f'gt and prediction shapes differ: {gt_image.shape} vs {pre_image.shape}'
        self.confusion_matrix += self.__generate_matrix(gt_image, pre_image)

    def update(self, name_classes):
        """
        Update the Eval object with new classes, re-initializing or updating its params.

        On the first call ``classes_per_step`` is seeded with the current class count ([0, n], or
        [0, 1, n] when an 'unknown' class is present); every call then appends the new cumulative
        class count. The confusion matrix is reset to the new size.

        :param name_classes: new class names to be appended
        """
        if not self.classes_per_step:
            self.classes_per_step = [0, 1, self.num_classes] if 'unknown' in self.name_classes else [0, self.num_classes]

        # build a new list instead of ``+=`` so the caller's list passed to __init__ is not mutated
        self.name_classes = list(self.name_classes) + list(name_classes)
        self.num_classes = len(self.name_classes)

        self.classes_per_step.append(self.num_classes)
        self.reset()

    def reset(self):
        """Clear the accumulated confusion matrix (keeps classes and step bookkeeping)."""
        self.confusion_matrix = np.zeros((self.num_classes,) * 2)

def softmax(k, axis=None):
    """
    Softmax of ``k`` along ``axis`` (over all elements when ``axis`` is None).

    The max along ``axis`` is subtracted before exponentiating. This leaves the result
    mathematically unchanged but prevents ``np.exp`` overflow (inf / inf -> nan) for large
    inputs. Empty input is returned as an empty float array, as before.

    :param k: array-like of scores
    :param axis: axis to normalize over; None normalizes over the whole array
    :return: ndarray of the same shape as ``k`` whose entries sum to 1 along ``axis``
    """
    k = np.asarray(k)
    if k.size:
        # size > 0 guarantees the reduced axis is non-empty, so np.max is well defined
        k = k - np.max(k, axis=axis, keepdims=True)
    exp_k = np.exp(k)
    return exp_k / np.sum(exp_k, axis=axis, keepdims=True)

