import os
import socket
import itertools
from math import ceil
from pathlib import Path
from shutil import copyfile
from PIL import Image
from tqdm import tqdm
from tensorboardX import SummaryWriter

import torchvision
import torch.utils.data as data

import sys
sys.path.append(os.path.abspath('.'))

from datasets import Mapillary_Dataset, Cityscapes_Dataset, IDD_Dataset, BDD100k_Dataset, Shift_Dataset, DatasetConfig, StyleAugment, Joint_Dataset
from datasets.dataset_utils import inv_preprocess, decode_labels
from utils import *

# Every supported incremental data split, keyed by its '-'-joined name and mapping to the ordered tuple of
# domains visited step by step. Multi-domain groups accept any ordering of their members:
#   - 'continent' (Mapillary): ('EU', 'NA', 'AS', 'SA', 'OC', 'AF'), e.g. 'EU-NA-AS-SA-OC-AF'
#   - cross-dataset: ('bdd100k', 'Cityscapes', 'IDD')
#   - time of day (Shift): ('dawn_dusk', 'daytime', 'night')
# The single-domain entries 'Mapillary' and 'Shift' map to one-element tuples.
list_paths = {
    '-'.join(ordering): ordering
    for domain_group in (('EU', 'NA', 'AS', 'SA', 'OC', 'AF'),
                         ('bdd100k', 'Cityscapes', 'IDD'),
                         ('Mapillary',),
                         ('dawn_dusk', 'daytime', 'night'),
                         ('Shift',))
    for ordering in itertools.permutations(domain_group)
}

def is_single_dataset_split(split):
    """Return True if ``split`` contains every domain of a single-dataset group.

    Single-dataset groups are the Mapillary continents ('EU', 'NA', 'AS', 'SA', 'OC', 'AF')
    and the Shift times of day ('dawn_dusk', 'daytime', 'night'). Membership is a plain
    ``in`` test against ``split`` (a substring test when ``split`` is a string).
    """
    single_dataset_groups = (('EU', 'NA', 'AS', 'SA', 'OC', 'AF'),
                             ('dawn_dusk', 'daytime', 'night'))
    return any(all(domain in split for domain in group) for group in single_dataset_groups)

def get_single_dataset_split(split):
    """Return the split-list directory of a single-dataset split name.

    Gives 'Mapillary/splits/continent/' when ``split`` contains every continent code,
    'Shift/splits/timeofday/' when it contains every time-of-day tag, and None otherwise.
    The continent check is tried first.
    """
    candidates = (
        (('EU', 'NA', 'AS', 'SA', 'OC', 'AF'), 'Mapillary/splits/continent/'),
        (('dawn_dusk', 'daytime', 'night'), 'Shift/splits/timeofday/'),
    )
    for domains, list_dir in candidates:
        if all(domain in split for domain in domains):
            return list_dir
    return None

# Lower-case dataset key -> directory name, used under both args.data_root_path (images)
# and args.list_root_path (split lists, as '<dir>/splits/...').
data_paths = dict(
    mapillary='Mapillary',
    shift='Shift',
    cityscapes='Cityscapes',
    idd='IDD',
    bdd100k='bdd100k',
)

# Reference scores of baseline runs, indexed as
#   baseline_results[protocol][model][incremental step][domain] -> score
# with protocol in ('CIL', 'doubleCIL', 'reverseCIL'). Domain keys are lower-case dataset keys
# (as in data_paths), Shift time-of-day tags, or continent codes (as in list_paths).
# NOTE(review): the values look like mIoU percentages — TODO confirm the metric. This table is not
# referenced in the visible part of this file; check its consumer before editing numbers.
baseline_results = {
    # class-incremental protocol
    'CIL':{
        # NOTE(review): the 'dawn_dusk'/'daytime'/'night' entries repeat step 0's cityscapes/bdd100k/idd
        # values at every step, so they look like placeholders rather than measured results — confirm.
        'erfnet': {0: {'cityscapes': 85.152, 'bdd100k': 76.1, 'idd': 87.11,
                       'dawn_dusk': 85.152, 'daytime': 76.1, 'night': 87.11},
                   1: {'cityscapes': 72.9782, 'bdd100k': 63.34, 'idd': 68.85,
                       'dawn_dusk': 85.152, 'daytime': 76.1, 'night': 87.11},
                   2: {'cityscapes': 67.9116, 'bdd100k': 49.62, 'idd': 68.14,
                       'dawn_dusk': 85.152, 'daytime': 76.1, 'night': 87.11}},
        'deeplabv3-resnet101': {0: {'cityscapes': 83.6289008, 'bdd100k': 77.4019446, 'idd': 86.4280868},
                                1: {'cityscapes': 69.7301239, 'bdd100k': 64.0236616, 'idd': 70.7514114},
                                2: {'cityscapes': 67.6607726, 'bdd100k': 59.2175900, 'idd': 70.6390867}},
    },
    # six steps, keyed by continent code
    'doubleCIL':{
        'erfnet': {0: {'EU': 80.07899, 'NA': 83.8198015, 'AS': 77.5758835, 'SA': 78.2538505, 'OC': 80.2623745, 'AF': 81.373094},
                   1: {'EU': 82.05114, 'NA': 86.8074006, 'AS': 77.8153296, 'SA': 80.2398556, 'OC': 85.0574226, 'AF': 74.246452},
                   2: {'EU': 76.03736, 'NA': 77.6240577, 'AS': 73.0901071, 'SA': 69.1719885, 'OC': 75.0744991, 'AF': 60.479548},
                   3: {'EU': 72.68807, 'NA': 74.6430484, 'AS': 68.8469135, 'SA': 67.2882647, 'OC': 70.4244119, 'AF': 57.734623},
                   4: {'EU': 69.56538, 'NA': 71.9912233, 'AS': 67.1054072, 'SA': 64.4183680, 'OC': 63.7376157, 'AF': 54.313591},
                   5: {'EU': 64.38554, 'NA': 62.5620748, 'AS': 60.6720489, 'SA': 59.4489891, 'OC': 50.3552326, 'AF': 50.025090}},
    },
    'reverseCIL':{
        'erfnet': {0: {'cityscapes': 85.152, 'bdd100k': 76.1040000, 'idd': 87.1080000},
                   1: {'cityscapes': 70.2560, 'bdd100k': 49.3915385, 'idd': 74.8415385},
                   2: {'cityscapes': 67.9116, 'bdd100k': 49.6236842, 'idd': 68.1436842}},
    }
}


class Trainer:

    def __init__(self, args, ignore_index=-1, logger=None):
        """Set up losses, pseudo-labeler, data splits and bookkeeping for incremental training.

        :param args: run configuration namespace; ``args.ignore_label`` is overwritten with ``ignore_index``
        :param ignore_index: label value excluded from the CE losses and from the per-step class lists
        :param logger: logger used throughout training
        :raises AssertionError: on an invalid class set / single data split or incompatible loss options
        :raises ValueError: if ``args.incr_data_split`` is not a supported domain split
        :raises NotImplementedError: for unsupported loss configurations
        """
        self.args = args
        class_sets = DatasetConfig.name2iid()[self.args.incr_class_split]
        # the error message lists the valid class-set names for this class split
        assert self.args.class_set in class_sets, "Class set '{}' not valid, choose from '{}'".format(self.args.class_set, list(class_sets))
        self.name2iid = class_sets[self.args.class_set]
        self.args.ignore_label = ignore_index
        self.logger = logger
        self.cuda = self.args.device not in ('cpu',) and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')
        self.writer = None
        self.wandb_logger = None

        # loss definition
        if not self.args.use_ce_unbiased:
            self.ce_loss = nn.CrossEntropyLoss(weight=None, ignore_index=self.args.ignore_label).to(self.device)
        elif self.args.use_MiB_reimplemented:
            self.ce_loss = CE_loss_unbiased(device=self.device).to(self.device)
        else:
            self.ce_loss = UnbiasedCrossEntropy(weight_fn=None).to(self.device)

        # optional auxiliary unbiased CE, only available together with label inpainting
        self.ce_los_aux = None
        if self.args.label_inpainting:
            assert not self.args.use_ce_unbiased, "CE std must be used when label inpainting"
            if self.args.label_inpainting_aux_loss:
                self.ce_los_aux = UnbiasedCrossEntropy().to(self.device)

        # knowledge distillation is enabled by a positive weight on the FDA-style term or on the standard KD term
        self.use_kd_fda = self.args.lambda_kd_style > 0. and self.args.use_fda
        self.use_kd_std = self.args.lambda_distil > 0. if not isinstance(self.args.lambda_distil, list) else self.args.lambda_distil[0] > 0.
        self.use_kd_loss = self.use_kd_fda or self.use_kd_std
        if self.use_kd_loss:
            if isinstance(args.lambda_distil, list):
                raise NotImplementedError
            if self.args.use_MiB_reimplemented:
                self.distil_loss = Distillation_loss_unbiased(device=self.device).to(self.device)
            else:
                self.distil_loss = UnbiasedKnowledgeDistillationLoss().to(self.device)

            # pseudo-label distillation variants can use a plain CE on pseudo-labels instead of the KD loss
            self.distil_loss_ce = None
            if self.args.distil_type in ('pseudo_filter', 'pseudo_filter_style') and self.args.distill_with_ce:
                self.distil_loss_ce = nn.CrossEntropyLoss(ignore_index=self.args.ignore_label).to(self.device)

        if self.args.distil_curr_style:
            assert self.args.ce_curr_style, "Self-stylization on KD alone is not allowed"

        self.pseudo_labeler = Pseudo_Labeler(global_thresh=self.args.global_thresh, class_keep_fract=self.args.class_keep_fract, global_thresh_low=self.args.global_thresh_low)

        # loss logging
        self.loss_logger = LossLogger()

        self.old_model = None

        # dataset
        self.data_kwargs = {'incr_class_split': args.incr_class_split,
                            'crop_to': args.crop_to,
                            'resize_to': args.resize_to,
                            'random_resize': args.random_resize,
                            'class_set': args.class_set}
        self.dataloader_kwargs = {'num_workers': args.data_loader_workers,
                                  'pin_memory': args.pin_memory,
                                  'drop_last': True}

        self.available_datasets = {'mapillary': Mapillary_Dataset,
                                   'shift': Shift_Dataset,
                                   'cityscapes': Cityscapes_Dataset,
                                   'idd': IDD_Dataset,
                                   'bdd100k': BDD100k_Dataset}

        self.data_gen, self.train_dataloader, self.val_dataloader = None, None, None
        self.style_dataloader, self.data_gen_style = None, None
        self.train_steps_per_epoch, self.training_steps, self.num_epochs, self.validation_steps = 0, 0, 0, 0
        self.val_dls = {}  # to store references to dataloaders and avoid re-initializations

        self.curr_data_split_train, self.curr_data_split_val = None, None
        self.single_class_step = False

        # single-dataset mode (Mapillary continents or Shift times of day)
        if is_single_dataset_split(self.args.incr_data_split):
            par_path = get_single_dataset_split(self.args.incr_data_split)
            if self.args.single_data_split:
                # train on a single domain of the dataset
                assert self.args.single_data_split in list_paths[self.args.incr_data_split], "Single domain mode not valid, choose domain from '{}'".format(list_paths[self.args.incr_data_split])
                self.data_splits = [f'{par_path}/{self.args.single_data_split}']
            else:
                self.data_splits = [f'{par_path}/{c}' for c in list_paths[self.args.incr_data_split]]
        # multiple datasets mode
        elif self.args.incr_data_split in list_paths:
            if self.args.single_data_split:
                # train on a single dataset
                assert self.args.single_data_split in list_paths[self.args.incr_data_split], "Single domain mode not valid, choose domain from '{}'".format(list_paths[self.args.incr_data_split])
                self.data_splits = ['{}/splits/standard'.format(data_paths[fix_data_name(self.args.single_data_split.lower())])]
            else:
                self.data_splits = ['{}/splits/standard'.format(data_paths[fix_data_name(d.lower())]) for d in list_paths[self.args.incr_data_split]]
        else:
            raise ValueError("Domain split '{}' not supported".format(self.args.incr_data_split))

        # domain\class |   =1     |   >1
        #     =1       | standard |   CIL
        #     >1       |   DA     |  CL-UDA
        # If multiple class and domain steps, choose the min # of steps
        if len(self.name2iid) > 1 and len(self.data_splits) > 1:
            self.num_incr_steps = min(len(self.name2iid), len(self.data_splits))
        # Else single class or domain mode, so perform all the incremental or domain steps
        else:
            self.num_incr_steps = max(len(self.name2iid), len(self.data_splits))
            # If we have class OR domain steps =1 (if both =1 skip next 'if', if both >1 we should not be inside the 'else')
            if len(self.name2iid) != len(self.data_splits):

                # DA mode: domains share class set, but past domains are not available (DA without source data)
                if len(self.name2iid) == 1:
                    # build a new list instead of extending in place, so the object returned by
                    # DatasetConfig.name2iid() is never mutated
                    self.name2iid = self.name2iid * self.num_incr_steps
                    self.single_class_step = True  # used to init dataloaders, which should all have the same class set corresponding to the first class step (the only one)
                    if self.use_kd_loss or self.args.use_ce_unbiased:
                        raise NotImplementedError  # these losses expect incremental class sets

                # Incremental Learning mode: incremental class sets, but same domain for each step>0
                if len(self.data_splits) == 1:
                    self.data_splits = self.data_splits * self.num_incr_steps

        # step to stop is the index (from 0) of the last step to perform (included)
        if self.args.step_to_stop is not None and self.args.step_to_stop < self.num_incr_steps:
            self.num_incr_steps = self.args.step_to_stop + 1

        self.current_iter, self.current_epoch, self.current_incr_step = 0, 0, 0

        self.model, self.optimizer, self.Eval = None, None, None
        self.class_names, self.num_classes = [], 0
        self.unknown_label = None
        self.result_dict = OrderedDict()
        self.delta_results = OrderedDict()

        self.style_obj = None
        self.step_to_start_wandb, self.wandb_ckpt = 0, None
        self.wandb_old_ckpt, self.wandb_epochs = None, None
        self.wandb_save_model = self.args.wandb_save_model and self.args.use_wandb
        self.wandb_save_curr_model = self.args.wandb_save_curr_model_interval is not None and self.args.wandb_save_curr_model_interval > 0
        self.wandb_load_model = (self.args.wandb_load_model or self.args.wandb_load_curr_model) and self.args.use_wandb


    def model_init(self, step):
        """Build the model, update loss/evaluator class bookkeeping and create the optimizer for ``step``.

        When a fresh Trainer resumes at ``step > 0`` (no class names recorded yet), the class lists of all
        previous steps are first rebuilt from ``self.name2iid`` so losses and evaluator can be brought up to
        date. For ``step > 0`` the encoder may be loaded from the previous step's checkpoint, and the
        previous-step model is loaded (in eval mode) into ``self.old_model``.
        :param step: index of the incremental step being initialized
        :raises ValueError: if ``args.optim`` is not 'sgd', 'adam' or 'adamw' (case-insensitive)
        """
        # Update class names and total number
        if step > 0 and not self.class_names:  # if resuming training from step > 0, recover class names and number from previous steps
            # NOTE(review): classes_per_old_step is only bound on this resume path; the later "resume" branches
            # (losses and Eval) read it and assume they are only reached after it was built here
            classes_per_old_step = []
            for s in range(step):
                class_list = [k for k, v in self.name2iid[s].items() if v != self.args.ignore_label and k not in self.class_names]  # unknown class may be already in list
                self.num_classes += len(class_list)
                self.class_names += class_list
                classes_per_old_step.append(class_list)
        # update for current step
        class_list = [k for k, v in self.name2iid[step].items() if v != self.args.ignore_label and k not in self.class_names]  # unknown class may be already in list
        self.num_classes += len(class_list)
        self.class_names += class_list

        # Model init
        self.logger.info('  Getting model...')
        # ImageNet weights only when the encoder is not going to be loaded from the past step
        imagenet_pretrained = (step == 0 or not self.args.encoder_weights_from_past) and self.args.imagenet_pretrained
        logger_fn = lambda input_str: self.logger.info(f'    {input_str}')
        self.model, params = get_model(self.args, self.num_classes, imagenet_pretrained, logger_fn=logger_fn, device=self.device)
        self.logger.info('  Loading model on gpu...')
        self.model.to(self.device)

        if step > 0 and self.args.freeze_encoder:
            assert hasattr(self.model, 'freeze')
            self.model.freeze('encoder')

        # Update losses' params
        self.logger.info('  Updating loss params...')

        if self.args.use_ce_unbiased:
            self.logger.info('    CE unbiased loss...')
            if step > 0 and not self.ce_loss.new_classes:  # if resuming training from step > 0, set class number of previous steps
                for s in range(step):
                    self.ce_loss.update_classes(sum(len(c) for c in classes_per_old_step[:s+1]))
            self.ce_loss.update_classes(self.num_classes)

        if self.ce_los_aux:
            self.logger.info('    CE unbiased auxiliary loss...')
            if step > 0 and not self.ce_los_aux.new_classes:  # if resuming training from step > 0, set class number of previous steps
                for s in range(step):
                    self.ce_los_aux.update_classes(sum(len(c) for c in classes_per_old_step[:s+1]))
            self.ce_los_aux.update_classes(self.num_classes)

        if self.use_kd_loss:
            self.logger.info('    Distillation loss...')
            if 'unknown' in self.name2iid[step]:
                self.distil_loss.unknown_label = self.name2iid[step]['unknown']  # set unknown class index
            if step > 0 and not self.distil_loss.new_classes:  # if resuming training from step > 0, set class number of previous steps
                for s in range(step):
                    self.distil_loss.update_classes(sum(len(c) for c in classes_per_old_step[:s+1]))
            self.distil_loss.update_classes(self.num_classes)

        # Load model from checkpoint
        if step > 0:
            if self.args.encoder_weights_from_past:
                self.logger.info('Loading encoder into new model from past step...')
                self.load_checkpoint(step=step-1, network='encoder')
            else:
                self.logger.info('!!! WARNING !!!')
                self.logger.info('Encoder from past step has not been loaded')
            self.logger.info('Loading encoder and decoder into old model from past step...')
            logger_fn = lambda input_str: self.logger.info(f'  {input_str}')
            # the old model predicts every class seen so far except the current step's new ones
            self.old_model = get_model(self.args, self.num_classes-len(class_list), False, logger_fn=logger_fn, device=self.device)[0].to(self.device).eval()
            self.load_checkpoint(step=step-1, network='all', target_model=self.old_model)

        # Initialize optimizer
        optim_name = self.args.optim.lower()
        if optim_name == "sgd":
            self.optimizer = torch.optim.SGD(
                params=params,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )
        elif optim_name == "adam":
            self.optimizer = torch.optim.Adam(
                params=params,
                weight_decay=self.args.weight_decay
            )
        elif optim_name == "adamw":
            self.optimizer = torch.optim.AdamW(
                params=params,
                weight_decay=self.args.weight_decay
            )
        else:
            raise ValueError('Optimizer {} not supported'.format(self.args.optim))

        # Initialize or update eval obj
        self.logger.info('  Setting eval framework...')
        if step > 0 and not self.Eval:  # if resuming training from step > 0, initialize eval obj and set class names of previous steps
            self.Eval = Eval(classes_per_old_step[0])
            for s in range(1, step):
                # register the classes introduced at each earlier step s, in order
                self.Eval.update(classes_per_old_step[s])
        if self.Eval:
            self.Eval.update(class_list)
        else:
            self.Eval = Eval(class_list)

        # Loading pretrained checkpoint overwriting pre-existent weights
        if self.wandb_load_model and step == self.step_to_start_wandb-1 and self.wandb_ckpt is not None:
            self.logger.info('  Loading W&B checkpoint...')
            if self.wandb_old_ckpt is not None:
                # next step this will become the old model
                self.model.load_network(self.wandb_old_ckpt["model_state"], network='encoder', load_partial_decoder=True)
            else:
                self.model.load_network(self.wandb_ckpt["model_state"], network='encoder', load_partial_decoder=True)
            if step not in self.result_dict:
                self.result_dict[step] = {}
            self.result_dict[step][self.val_dataloader.dataset.dataset_name] = self.validate()  # otherwise by skipping training steps data will be missing

    def dataset_init(self, step):
        """
        Build the train and val dataloaders of one incremental step and derive its training schedule.
        :param step: index into ``self.data_splits`` of the step whose train/val data should be prepared
        """
        split_path = self.data_splits[step]
        split_parts = split_path.split('/')
        dataset_key = fix_data_name(split_parts[0].lower())

        # dataset kwargs and generator class for the current step
        self.data_kwargs['data_path'] = os.path.join(self.args.data_root_path, data_paths[dataset_key])
        self.data_kwargs['list_path'] = os.path.join(self.args.list_root_path, split_path)
        self.data_kwargs['incr_class_step'] = 0 if self.single_class_step else step
        self.data_gen = self.available_datasets[dataset_key]
        # Size-divisibility is only requested when Mapillary is the training dataset (not when it is only used to
        # evaluate domain generalization), with cropping enabled and the erfnet model
        self.data_kwargs['make_val_size_div'] = (self.args.crop_to is not None
                                                 and issubclass(self.data_gen, Mapillary_Dataset)
                                                 and self.args.model in ('erfnet',))
        self.curr_data_split_train = self.curr_data_split_val = split_parts[0] + '-' + split_parts[-1]

        # train and val dataloaders
        train_set = self.data_gen(self.args, split='train', **self.data_kwargs)
        self.train_dataloader = data.DataLoader(train_set, shuffle=True, batch_size=self.args.batch_size, **self.dataloader_kwargs)
        val_set = self.data_gen(self.args, split='val', training=False, resize_mask=not self.args.val_at_orig_size, **self.data_kwargs)
        self.val_dataloader = data.DataLoader(val_set, shuffle=False, batch_size=1, **self.dataloader_kwargs)
        self.validation_steps = len(val_set) + 1

        # stylize training images with the stored FDA styles from step 1 on (or at step 0 too if ce_curr_style)
        if self.args.use_fda and (step > 0 or self.args.ce_curr_style):
            self.train_dataloader.dataset.fda_style_fn = self.style_obj.apply_style

        # training schedule; dataloaders use drop_last=True, so len() counts only full batches
        self.train_steps_per_epoch = len(self.train_dataloader)
        if self.args.iter_max is not None:
            if self.args.epochs is not None:
                raise ValueError('Both num epochs and training steps set, choose one')
            self.training_steps = self.args.iter_max
            self.num_epochs = ceil(self.training_steps / self.train_steps_per_epoch)
        else:
            if self.args.epochs is None:
                raise ValueError('Both num epochs and training steps are not set')
            if self.args.epochs_incremental is None:
                self.args.epochs_incremental = self.args.epochs
            # step 0 uses args.epochs; later steps use args.epochs_incremental (per-step list or a single value)
            if step == 0:
                epochs = self.args.epochs
            elif isinstance(self.args.epochs_incremental, list):
                epochs = self.args.epochs_incremental[step - 1]
            else:
                epochs = self.args.epochs_incremental
            self.training_steps = epochs * self.train_steps_per_epoch
            self.num_epochs = epochs
        self.current_iter = 0
        self.current_epoch = 0

    def dataset_val_init(self, step):
        """
        Point the validation dataloader at the val set of the domain visited at incremental ``step``.
        Only 'list_path' and 'data_path' of ``self.data_kwargs`` are replaced; the other entries (e.g. the
        class step) are those left by the last ``dataset_init()`` call.
        :param step: index into ``list_paths[self.args.incr_data_split]`` of the domain to validate on
        """
        domain = list_paths[self.args.incr_data_split][step]
        if is_single_dataset_split(self.args.incr_data_split):
            data_split = f'{get_single_dataset_split(self.args.incr_data_split)}/{domain}'
        else:
            data_split = f'{domain}/splits/standard'
        split_parts = data_split.split('/')
        self.data_kwargs['list_path'] = os.path.join(self.args.list_root_path, data_split)
        self.curr_data_split_val = split_parts[0] + '-' + split_parts[-1]

        dataset_key = fix_data_name(split_parts[0].lower())
        self.data_kwargs['data_path'] = os.path.join(self.args.data_root_path, data_paths[dataset_key])
        self.data_gen = self.available_datasets[dataset_key]

        val_set = self.data_gen(self.args, split='val', training=False, resize_mask=not self.args.val_at_orig_size, **self.data_kwargs)
        self.val_dataloader = data.DataLoader(val_set, shuffle=False, batch_size=1, **self.dataloader_kwargs)
        self.validation_steps = len(val_set) + 1

    def dataset_dg_test(self, step):
        """
        Point the validation dataloader at the standard Mapillary val split to test domain generalization.
        A fresh kwargs dict is built, so ``self.data_kwargs`` is left untouched (and no 'make_val_size_div'
        entry is passed to the dataset).
        :param step: incremental step whose class set is used (0 in single-class-step mode)
        """
        dg_key = fix_data_name(self.args.dataset_dg_test)
        assert dg_key in ('mapillary',)
        data_split = 'Mapillary/splits/standard'
        dg_kwargs = dict(
            incr_class_split=self.args.incr_class_split,
            crop_to=self.args.crop_to,
            resize_to=self.args.resize_to,
            random_resize=self.args.random_resize,
            class_set=self.args.class_set,
            list_path=os.path.join(self.args.list_root_path, data_split),
            data_path=os.path.join(self.args.data_root_path, data_paths[dg_key]),
            incr_class_step=0 if self.single_class_step else step,
        )
        head, *_, tail = data_split.split('/')
        self.curr_data_split_val = head + '-' + tail
        self.data_gen = self.available_datasets[dg_key]
        val_set = self.data_gen(self.args, split='val', training=False, resize_mask=not self.args.val_at_orig_size, **dg_kwargs)
        self.val_dataloader = data.DataLoader(val_set, shuffle=False, batch_size=1, **self.dataloader_kwargs)
        self.validation_steps = len(val_set) + 1

    def dataset_joint_init(self):
        """
        Build single train/val dataloaders over the union (``Joint_Dataset``) of every incremental step's data,
        for joint training, and derive the schedule from ``args.iter_max`` or ``args.epochs``.
        Only the 'standard' class split without a single data split is supported.
        """
        assert self.args.incr_class_split in ('standard',) and self.args.single_data_split is None, f'{self.args.incr_class_split} - {self.args.single_data_split}'
        train_sets, val_sets, split_names = [], [], []
        for step in range(self.num_incr_steps):
            # reuse the per-step initializer and keep only the datasets it creates
            self.dataset_init(step)
            train_sets.append(self.train_dataloader.dataset)
            val_sets.append(self.val_dataloader.dataset)
            split_names.append(self.data_splits[step].split('/')[0])
        joint_train = Joint_Dataset(train_sets)
        joint_val = Joint_Dataset(val_sets)
        self.curr_data_split_train = self.curr_data_split_val = '-'.join(split_names) + '-' + self.data_splits[-1].split('/')[-1]

        self.train_dataloader = data.DataLoader(joint_train, shuffle=True, batch_size=self.args.batch_size, **self.dataloader_kwargs)
        self.val_dataloader = data.DataLoader(joint_val, shuffle=False, batch_size=1, **self.dataloader_kwargs)
        self.validation_steps = len(joint_val) + 1

        # training schedule; dataloaders use drop_last=True, so len() counts only full batches
        self.train_steps_per_epoch = len(self.train_dataloader)
        if self.args.iter_max is not None:
            if self.args.epochs is not None:
                raise ValueError('Both num epochs and training steps set, choose one')
            self.training_steps = self.args.iter_max
            self.num_epochs = ceil(self.training_steps / self.train_steps_per_epoch)
        else:
            if self.args.epochs is None:
                raise ValueError('Both num epochs and training steps are not set')
            if self.args.epochs_incremental is None:
                self.args.epochs_incremental = self.args.epochs
            # joint training is a single step, so only args.epochs applies
            self.num_epochs = self.args.epochs
            self.training_steps = self.num_epochs * self.train_steps_per_epoch

        self.current_iter = 0
        self.current_epoch = 0


    def train_init(self):
        """Run the full incremental training procedure.

        Logs the configuration and opens the TensorBoard writer. If enabled, it also sets up the W&B logger,
        optionally restoring a W&B checkpoint. It then prints the per-step data/class splits.
        For every incremental step from ``args.start_incr_step`` on, it:
        seeds RNGs, builds data and model, extracts the current domain's FDA style if enabled, trains,
        validates on the requested splits and saves a checkpoint.
        The TensorBoard writer is closed even when a step raises.
        """
        # display args details
        self.logger.info("Global configuration as follows:")
        for key, val in vars(self.args).items():
            self.logger.info("{:22} {}".format(key, val))

        # setup writer; everything after this runs inside try/finally so the writer is always closed
        self.writer = SummaryWriter(self.args.checkpoint_dir)
        try:
            if self.args.use_wandb:
                self.wandb_logger = WandB(self.args)
            if self.wandb_load_model:
                if self.args.wandb_load_curr_model:
                    self.step_to_start_wandb, self.wandb_ckpt, self.wandb_old_ckpt, self.wandb_epochs = self.wandb_logger.load_wandb_curr_model()
                else:
                    # search from the last step downwards (to args.start_incr_step) until a ckpt is found
                    steps = range(self.args.start_incr_step, self.num_incr_steps)[::-1]
                    if self.args.wandb_load_step is not None:
                        steps = [self.args.wandb_load_step]
                    self.step_to_start_wandb, self.wandb_ckpt = self.wandb_logger.load_wandb_model(steps=steps)

            # per-step summary; classes mapped to args.ignore_label are left out
            class_split = [['{}:{}'.format(v, k) for k, v in d.items() if v != self.args.ignore_label] for d in self.name2iid]
            for step in range(self.num_incr_steps):
                split_parts = self.data_splits[step].split('/')
                self.logger.info('Step{} ->'.format(step) + ' Data split: {}'.format(split_parts[0] + '-' + split_parts[-1]))
                self.logger.info('{:8}'.format('') + ' Class split: ' + ' '.join(' {}'.format(c) for c in class_split[step]))

            # choose cuda
            if self.cuda:
                current_device = torch.cuda.current_device()
                self.logger.info("This model will run on {}".format(torch.cuda.get_device_name(current_device)))
            else:
                self.logger.info("This model will run on CPU")

            # init fda object
            if self.args.use_fda:
                self.style_obj = StyleAugment(self.args.n_images_per_style, self.args.L, self.args.avg_style)

            for step in range(self.args.start_incr_step, self.num_incr_steps):

                # joint training handles all domains in a single step
                if self.args.use_joint_data and step > 0:
                    break

                # set seed at the beginning of each incremental step
                # NOTE(review): seed = step * args.seed, so step 0 always uses seed 0 whatever args.seed is — confirm intended
                set_seed(step * self.args.seed)

                if self.args.step_to_stop is not None and step > self.args.step_to_stop:
                    break

                self.logger.info('')
                self.logger.info('### INCREMENTAL STEP {} ###'.format(step))
                self.current_incr_step = step

                self.logger.info('Dataset initialization...')
                if self.args.use_joint_data:
                    self.dataset_joint_init()
                else:
                    self.dataset_init(step)   # both train and val sets of current dataset are initialized
                self.logger.info('Done')

                self.logger.info('Model initialization...')
                self.model_init(step)
                self.logger.info('Done')

                # extract style on current domain before training to allow for its use during current step
                if self.args.use_fda:
                    save_or_load_path = None
                    if self.args.avg_style and self.args.save_or_load_avg_style:
                        dataset_name = fix_data_name(list_paths[self.args.incr_data_split][step])
                        # NOTE(review): os.path.split() returns a (head, tail) pair, so [:-2] is always empty and the
                        # path always starts with './styles/' regardless of checkpoint_dir — confirm intent
                        save_or_load_path = f".{'/'.join(os.path.split(self.args.checkpoint_dir)[:-2])}/styles/{dataset_name}_L{self.args.L}_N{min(self.args.n_images_per_style,len(self.train_dataloader.dataset))}_avg.npy"
                    self.style_obj.add_style(self.train_dataloader.dataset, step, save_or_load_path=save_or_load_path, dom_unrel_styles=False)

                self.train_incremental_step(step)
                if self.args.incr_val_data in ['before', 'all']:
                    if self.args.incr_val_data == 'before':
                        val_steps = list(range(step))
                    else:
                        curr_data_split = self.args.single_data_split if self.args.single_data_split else list_paths[self.args.incr_data_split][step]
                        val_steps = [i for i, el in enumerate(list_paths[self.args.incr_data_split]) if el != curr_data_split]
                    if self.args.use_joint_data:
                        val_steps = [0] + val_steps  # we want to validate to the first dataset as well
                    for vs in val_steps:
                        self.dataset_val_init(step=vs)
                        if step not in self.result_dict:
                            self.result_dict[step] = {}
                        self.result_dict[step][self.val_dataloader.dataset.dataset_name] = self.validate(log_wandb_imgs=True)
                if self.args.dataset_dg_test is not None:
                    self.dataset_dg_test(step)
                    if step not in self.result_dict:
                        self.result_dict[step] = {}
                    self.result_dict[step][self.val_dataloader.dataset.dataset_name] = self.validate(log_wandb_imgs=True)

                if self.args.incr_class_split not in ('standard',):
                    self.compute_and_print_global_metrics(step)

                self.save_checkpoint(step=step)  # saving prototypes of new model
                if self.args.use_wandb:
                    self.wandb_logger.step_offset += self.current_iter
                if self.wandb_save_model and step >= self.step_to_start_wandb:
                    self.wandb_logger.save_wandb_model(self.model, step, self.args.checkpoint_dir)
        finally:
            self.writer.close()

    def train_incremental_step(self, step):
        """
        Train the current model for one incremental step.

        If the step has to be skipped (see ``skip_step``), the model is only validated on the
        current val dataset, unless a result for it is already stored in ``self.result_dict[step]``.
        Otherwise the model is trained for epochs ``self.current_epoch .. self.num_epochs-1``,
        combining (depending on ``self.args``):
          - CE loss on the original image and/or on the image with the current style;
          - distillation losses w.r.t. ``self.old_model`` (output level and feature level);
          - CE / KD losses on the images carrying the styles of previous steps (``use_fda``).
        Each loss is back-propagated right after the forward pass(es) it depends on; gradients
        accumulate until the single ``optimizer.step()`` of the iteration.
        At the end of each epoch, training images are logged and validation is run according to
        the configured intervals; the last epoch's validation metrics are stored in
        ``self.result_dict[step]``.
        :param step: current incremental step
        """

        if self.skip_step(step):
            if step in self.result_dict and self.val_dataloader.dataset.dataset_name in self.result_dict[step]:
                return  # we validated when loading current model (inside model_init())
            metrics_to_ret = self.validate(log_wandb_imgs=False)
            if step not in self.result_dict: self.result_dict[step] = {}
            self.result_dict[step][self.val_dataloader.dataset.dataset_name] = metrics_to_ret
            return

        self.logger.info('')

        self.logger.info('\nTraining for {} epoch(s) on {} data split...'.format(self.num_epochs, self.curr_data_split_train))

        # resume training of this step from the wandb ckpt, restoring epoch/iter counters from self.wandb_epochs
        if self.args.wandb_load_curr_model and self.step_to_start_wandb==step and self.wandb_epochs is not None:
            self.model.load_network(self.wandb_ckpt["model_state"], network='encoder', load_partial_decoder=True)
            self.current_epoch = self.wandb_epochs + 1  # we start from the next epoch
            self.current_iter = self.train_steps_per_epoch * self.current_epoch + 1  # Ex. if self.wandb_epochs=0, then self.current_epoch=1 => 1 (=self.current_epoch) epoch (idx=0) has been performed
            if self.args.debug:
                print('DEBUG: validate current model after wandb loading')
                self.validate()

        if 'unknown' in self.name2iid[step]:
            self.unknown_label = self.name2iid[step]['unknown']
        else:
            self.unknown_label = None

        # use CE loss in place of KD one when pseudo-labeling
        if self.use_kd_loss:
            distil_loss_fn_new = self.distil_loss_ce if self.distil_loss_ce is not None else self.distil_loss

        start_epoch = self.current_epoch  # if start from iter 'current_epoch' is that saved in ckpt
        for epoch in range(start_epoch, self.num_epochs):
            epoch_steps = min(self.train_steps_per_epoch, self.training_steps - self.current_iter)
            tqdm_epoch = tqdm(self.train_dataloader, total=epoch_steps,
                              mininterval=None, miniters=self.args.loss_log_step_interval, maxinterval=60000.,
                              desc="Training at Incremental Step {} - Data split {} - Epoch {} over {}".format(step, self.curr_data_split_train, epoch+1, self.num_epochs))
            self.Eval.reset()
            self.model.train()
            for train_iter, (x, y, tr_id, _) in enumerate(tqdm_epoch):

                # debug mode: go quickly through the whole training procedure
                if self.args.debug and train_iter==5: break

                if self.current_iter==0:
                    self.logger.info('First image: {}'.format(tr_id))
                self.poly_lr_scheduler(
                    optimizer=self.optimizer,
                    init_lr=self.args.lr if step==0 else self.args.lr_incremental,
                    iter=self.current_iter,
                    # to avoid negative lr when current iter > max_iter_poly_decay
                    max_iter=self.args.max_iter_poly_decay if (self.args.max_iter_poly_decay > self.training_steps or self.args.max_iter_poly_decay < 0) else self.training_steps,
                    power=self.args.poly_power,
                )

                if self.current_iter >= self.training_steps:
                    self.logger.info("Training stopped at iter {}!".format(self.current_iter))
                    break

                # with style transfer, x stacks along dim 1: [original, styles of steps 1..step, (current style)]
                if self.args.use_fda and (step>0 or self.args.ce_curr_style):
                    assert x.dim()==5
                    x_with_style = [x[:,k,...].to(self.device) for k in range(1,step+1)]
                    if self.args.ce_curr_style:
                        x_with_curr_style = x[:,step+1,...].to(self.device)
                    x = x[:,0,...]


                x, y = x.to(self.device), y.to(device=self.device, dtype=torch.long)
                y = torch.squeeze(y, 1)

                if self.args.label_inpainting and self.args.label_inpainting_type in ('old',): assert self.args.use_fda   # assert to be checked at step 0 to avoid error been raised with delay

                if self.old_model:
                    with torch.no_grad():

                        label_to_force = self.args.ignore_label
                        # pseudo-label from 'x' is needed for CE-inp (new) or KD (x)
                        if (self.use_kd_std and not self.args.distil_curr_style) or (self.args.label_inpainting and self.args.label_inpainting_type in ('new',)):
                            pseudo_lab_x, pseudo_soft_x, pseudo_mask_x = self.pseudo_labeler(
                                x, teacher=self.old_model, gt_new_classes=y, unknown_lab=self.unknown_label, lab_to_force=label_to_force)
                        # pseudo-label from 'x_with_curr_style' is needed KD (x_with_curr_style)
                        if self.use_kd_std and self.args.distil_curr_style:
                            pseudo_lab_x_curr, pseudo_soft_x_curr, pseudo_mask_x_curr = self.pseudo_labeler(
                                x_with_curr_style, teacher=self.old_model, gt_new_classes=y, unknown_lab=self.unknown_label, lab_to_force=label_to_force)

                        if self.args.use_fda:
                            if self.args.distil_type in ('pseudo_filter_all_styles',):
                                pseudo_lab, pseudo_soft, pseudo_mask = self.pseudo_labeler([x] + x_with_style, teacher=self.old_model, gt_new_classes=y,
                                                                                           unknown_lab=self.unknown_label, lab_to_force=label_to_force)
                            else:
                                pseudo_lab, pseudo_soft, pseudo_mask = self.pseudo_labeler(x_with_style, teacher=self.old_model, gt_new_classes=y,
                                                                                           unknown_lab=self.unknown_label, lab_to_force=label_to_force)
                            if self.use_kd_fda or self.args.lambda_kd_feat_style > 0.: out_old_with_style = [self.old_model(xs) for xs in x_with_style]
                            if self.use_kd_fda: pred_old_with_style = [o['out'] for o in out_old_with_style]
                            if self.args.lambda_kd_feat_style > 0.: feat_old_with_style = [o['features'] for o in out_old_with_style]

                        if self.args.label_inpainting:
                            if self.current_iter==0:
                                self.logger.info(f'Inpainting GT with predictions on {self.args.label_inpainting_type} images')
                            psl = pseudo_lab if self.args.label_inpainting_type in ('old',) else pseudo_lab_x
                            if self.unknown_label is not None:
                                y[y==self.unknown_label] = psl[y==self.unknown_label]

                        if (self.args.lambda_kd_feat > 0.) or (not self.args.distil_curr_style):
                            output_old = self.old_model(x)
                            pred_old_x, feat_old_x = output_old['out'], output_old['features']
                        if self.args.distil_curr_style:
                            output_old = self.old_model(x_with_curr_style)
                            pred_old_x_curr, feat_old_x_curr = output_old['out'], output_old['features']

                self.optimizer.zero_grad()
                # model
                output = self.model(x)
                pred, feat = output['out'], output['features']

                if self.current_iter==0:
                    self.logger.info('pred {}, feat {}, x {}'.format(float(pred.sum()), float(feat.sum()), float(x.sum())))

                # loss terms waiting for backward, keyed by forward pass: 0 -> model(x), k>0 -> model(x_with_style[k-1])
                loss_before_backward = {}

                ### CE loss (X) ###
                if not (self.args.ce_curr_style and self.args.lambda_ce_curr_style is None):  # if False, CE is computed on img with current style
                    ce_loss = self.ce_loss(pred, y)
                    if self.ce_los_aux:
                        ce_loss = self._mix_with_aux_ce_loss(ce_loss, pred, y)
                    loss_before_backward[0] = ce_loss
                    self.loss_logger['CE_loss'] = ce_loss.item()

                ### KD loss (X) ###
                if self.old_model:

                    ### KD feat level ###
                    if self.args.lambda_kd_feat > 0.:
                        kd_feat_loss = self.args.lambda_kd_feat * torch.norm(feat-feat_old_x, p=2)
                        self.loss_logger['KD_feat_loss'] = kd_feat_loss.item()
                        if 0 not in loss_before_backward: loss_before_backward[0] = kd_feat_loss
                        else: loss_before_backward[0] += kd_feat_loss

                    ### UDA feat level ###
                    if self.args.lambda_uda_feat > 0.:
                        # use feature output with style under torch.no_grad
                        assert self.args.use_fda
                        with torch.no_grad():
                            feat_with_style = [self.model(xs)['features'] for xs in x_with_style]
                        uda_feat_loss = 0
                        for fs in feat_with_style:
                            uda_feat_loss += self.args.lambda_uda_feat * torch.norm(feat - fs, p=2) / len(feat_with_style)
                        self.loss_logger['UDA_feat_loss'] = uda_feat_loss.item()
                        if 0 not in loss_before_backward: loss_before_backward[0] = uda_feat_loss
                        else: loss_before_backward[0] += uda_feat_loss

                    ### KD standard ###
                    if self.use_kd_std:

                        if self.args.distil_curr_style:

                            if 0 in loss_before_backward:
                                # Backward for model(x) (in case KD is applied with 'x_with_curr_style')
                                loss_before_backward[0].backward()
                                del loss_before_backward[0]
                            output_curr_style = self.model(x_with_curr_style)
                            pred_curr_style, feat_curr_style = output_curr_style['out'], output_curr_style['features']

                            inputs, pred_old_kd = pred_curr_style, pred_old_x_curr
                            pseudo_lab_x_kd, pseudo_mask_x_kd, pseudo_soft_x_kd = pseudo_lab_x_curr.detach().clone(), pseudo_mask_x_curr.clone(), pseudo_soft_x_curr.detach().clone()
                        else:
                            inputs, pred_old_kd = pred, pred_old_x
                            pseudo_lab_x_kd, pseudo_mask_x_kd, pseudo_soft_x_kd = pseudo_lab_x.detach().clone(), pseudo_mask_x.clone(), pseudo_soft_x.detach().clone()

                        if not isinstance(distil_loss_fn_new, nn.CrossEntropyLoss):
                            pseudo_mask_pre = pseudo_mask_x_kd.clone()
                            pseudo_lab_x_kd[y!=self.unknown_label] = self.unknown_label   # all new-class regions were unknown ones at the past step
                            pseudo_mask_x_kd = pseudo_mask_pre | (y!=self.unknown_label)
                        else:
                            pseudo_lab_x_kd[y!=self.unknown_label] = self.args.ignore_label   # all new-class regions are ignored when distilling with CE
                            pseudo_mask_x_kd = pseudo_lab_x_kd != self.args.ignore_label
                        # masked-out pixels get a valid class index (0) so that one_hot() does not fail on them
                        pseudo_lab_x_to_oh = pseudo_lab_x_kd.clone()
                        pseudo_lab_x_to_oh[~pseudo_mask_x_kd] = 0

                        if self.args.use_fda:
                            pseudo_lab_kd, pseudo_mask_kd = pseudo_lab.detach().clone(), pseudo_mask.clone()
                            if not isinstance(distil_loss_fn_new, nn.CrossEntropyLoss):
                                pseudo_lab_kd[y!=self.unknown_label] = self.unknown_label   # all new-class regions were unknown ones at the past step
                                pseudo_mask_kd = pseudo_mask | (y!=self.unknown_label)
                            else:
                                pseudo_lab_kd[y!=self.unknown_label] = self.args.ignore_label   # all new-class regions are ignored when distilling with CE
                                pseudo_mask_kd = pseudo_lab_kd != self.args.ignore_label
                            pseudo_lab_to_oh = pseudo_lab_kd.clone()
                            pseudo_lab_to_oh[~pseudo_mask_kd] = 0

                        kwargs = {
                            'std': {'inputs': inputs,
                                    'targets': pred_old_kd} if self.args.distil_type=='std' else None,
                            'pseudo_filter': {'inputs': inputs,
                                              'targets': F.one_hot(pseudo_lab_x_to_oh, pred_old_kd.size(1)).permute(0, 3, 1, 2)
                                              if not isinstance(distil_loss_fn_new, nn.CrossEntropyLoss) else pseudo_lab_x_kd.detach().clone(),
                                              'mask': pseudo_mask_x_kd,
                                              'is_target_output': False} if self.args.distil_type=='pseudo_filter' else None,
                            'soft_filter': {'inputs': inputs,
                                            'targets': pseudo_soft_x_kd,
                                            'mask': pseudo_mask_x_kd,
                                            'is_target_output': False} if self.args.distil_type=='soft_filter' else None,
                            'std_style': {'inputs': inputs,
                                          'targets': pseudo_soft,
                                          'is_target_output': False} if self.args.use_fda and self.args.distil_type=='std_style' else None,
                            'pseudo_filter_style': {'inputs': inputs,
                                                    'targets': F.one_hot(pseudo_lab_to_oh, pred_old_kd.size(1)).permute(0, 3, 1, 2)
                                                    if not isinstance(distil_loss_fn_new, nn.CrossEntropyLoss) else pseudo_lab_kd.detach().clone(),
                                                    'mask': pseudo_mask_kd,
                                                    'is_target_output': False} if self.args.use_fda and self.args.distil_type=='pseudo_filter_style' else None,
                            'soft_filter_style': {'inputs': inputs,
                                                  'targets': pseudo_soft,
                                                  'mask': pseudo_mask_kd,
                                                  'is_target_output': False} if self.args.use_fda and self.args.distil_type=='soft_filter_style' else None,
                            'pseudo_filter_all_styles': {'inputs': inputs,
                                                         'targets': F.one_hot(pseudo_lab_to_oh, pred_old_kd.size(1)).permute(0, 3, 1, 2)
                                                         if not isinstance(distil_loss_fn_new, nn.CrossEntropyLoss) else pseudo_lab_kd.detach().clone(),
                                                         'mask': pseudo_mask_kd,
                                                         'is_target_output': False} if self.args.use_fda and self.args.distil_type=='pseudo_filter_all_styles' else None,
                        }
                        if 'style' in self.args.distil_type:
                            assert self.args.use_fda, f'{self.args.distil_type} not available when style transfer is disabled'
                        kwargs_to_use = kwargs[self.args.distil_type]
                        if isinstance(distil_loss_fn_new, nn.CrossEntropyLoss):
                            kwargs_to_use = {'input':kwargs_to_use['inputs'], 'target':kwargs_to_use['targets']}
                        distil_loss = self.args.lambda_distil * distil_loss_fn_new(**kwargs_to_use)
                        self.loss_logger['KD_loss'] = distil_loss.item()
                        if 0 not in loss_before_backward: loss_before_backward[0] = distil_loss
                        else: loss_before_backward[0] += distil_loss

                if 0 in loss_before_backward and not (self.args.distil_curr_style and self.old_model and self.use_kd_std):  # if second term True (w\o 'not'), then backward() on 'x' has already been performed
                    # Backward for model(x) (in case KD is NOT applied with 'x_with_curr_style')
                    loss_before_backward[0].backward()
                    del loss_before_backward[0]

                if self.args.ce_curr_style:
                    assert self.args.use_fda, 'To use CE with current style, style extraction must be enabled'
                    if not (self.args.distil_curr_style and self.old_model and self.use_kd_std):  # if True (w\o 'not'), then 'pred_curr_style' has already been computed
                        output_curr_style = self.model(x_with_curr_style)
                        pred_curr_style, feat_curr_style = output_curr_style['out'], output_curr_style['features']
                    ce_loss = self.ce_loss(pred_curr_style, y)
                    if self.args.lambda_ce_curr_style is not None:
                        ce_loss *= self.args.lambda_ce_curr_style
                    if self.ce_los_aux:
                        ce_loss = self._mix_with_aux_ce_loss(ce_loss, pred_curr_style, y)
                    self.loss_logger['CE_loss_curr_style'] = ce_loss.item()
                    if 0 not in loss_before_backward: loss_before_backward[0] = ce_loss
                    else: loss_before_backward[0] += ce_loss
                    # Backward for model(x_with_curr_style)
                    loss_before_backward[0].backward()
                    del loss_before_backward[0]

                if self.use_kd_fda or self.args.lambda_ce_style>0. or self.args.lambda_kd_feat_style>0.: assert self.args.use_fda
                if self.args.use_fda and step>0 and (self.use_kd_fda or self.args.lambda_ce_style>0. or self.args.lambda_kd_feat_style>0.):
                    y_style = y.clone()
                    ce_losses_style = 0
                    kd_losses_style = 0
                    kd_feat_losses_style = 0
                    for idx, xs in enumerate(x_with_style):

                        output_xs = self.model(xs)
                        pred_xs, feat_xs = output_xs['out'], output_xs['features']
                        ### CE loss (XS) ###
                        if self.args.lambda_ce_style>0.:
                            ce_loss_style = self.args.lambda_ce_style / len(x_with_style) * self.ce_loss(pred_xs, y_style)
                            loss_before_backward[idx+1] = ce_loss_style
                            ce_losses_style += ce_loss_style.item()

                        ### KD loss (XS) ###
                        if self.use_kd_fda:
                            kd_loss_style = self.args.lambda_kd_style / len(x_with_style) * self.distil_loss(pred_xs, pred_old_with_style[idx])
                            if idx+1 not in loss_before_backward: loss_before_backward[idx+1] = kd_loss_style
                            else: loss_before_backward[idx+1] += kd_loss_style
                            kd_losses_style += kd_loss_style.item()

                        ### KD loss (feat) (XS) ###
                        if self.args.lambda_kd_feat_style > 0.:
                            kd_feat_loss_style = self.args.lambda_kd_feat_style * torch.norm(feat_xs - feat_old_with_style[idx], p=2)
                            if idx+1 not in loss_before_backward: loss_before_backward[idx+1] = kd_feat_loss_style
                            else: loss_before_backward[idx+1] += kd_feat_loss_style
                            kd_feat_losses_style += kd_feat_loss_style.item()

                        # Backward for model(xs)
                        loss_before_backward[idx+1].backward()
                        del loss_before_backward[idx+1]

                    if self.use_kd_fda: self.loss_logger['KD_loss_style'] = kd_losses_style
                    if self.args.lambda_ce_style>0.: self.loss_logger['CE_loss_style'] = ce_losses_style
                    if self.args.lambda_kd_feat_style>0.: self.loss_logger['KD_feat_loss_style'] = kd_feat_losses_style


                # optimizer
                if self.args.gradient_check: grad_check(self.model, logger=self.logger)
                if self.args.gradient_clip is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.gradient_clip, norm_type=2)
                self.optimizer.step()

                if np.isnan([float(l) for l in self.loss_logger.values() if l]).any():
                    print_str = f"Train iter {train_iter},  ID: {tr_id},  IMG: min {x.min()}, max {x.max()},  PRED: min {pred.min()}, max {pred.max()},  LABEL: min {y.min()}, max {y.max()}"
                    self.logger.info(print_str)
                    self.logger.info('')
                    self.logger.info('Norm of model weight tensors')
                    [print(f"{n} -> {p.data.norm()}") for n,p in self.model.named_parameters() if p.requires_grad]
                    self.logger.info('')
                    self.logger.info('Norm of model weight tensors with nan values')
                    [print(f"{n} -> {p.data.norm()}") for n,p in self.model.named_parameters() if p.requires_grad and torch.isnan(p.data.norm())]
                    raise ValueError(f'Loss is nan during training..., {[float(l) for l in self.loss_logger.values() if l]}')

                # eval update
                img = x.clone().cpu()
                pred = pred.data.cpu().numpy()
                label = y.cpu().numpy()
                argpred = np.argmax(pred, axis=1)
                tr_id_log = tr_id
                if self.old_model and self.args.incr_class_split not in ('standard',):
                    num_old_classes = pred_old_x.size(1) if not self.args.distil_curr_style else pred_old_x_curr.size(1)
                    argpred = np.where(argpred<num_old_classes, 0, argpred)  # when training class-incrementally, training gt maps don't have old classes
                self.Eval.add_batch(label, argpred)
                if self.current_iter % self.args.loss_log_step_interval == 0:
                    self.logger.info(str(self.loss_logger))
                    for name, elem in self.loss_logger.items():
                        if self.args.use_wandb: self.wandb_logger.log({f'Losses/{name}/s{step}':elem, 'global_step':self.current_iter}, self.current_iter)
                        name += '/step{}'.format(step)
                        self.writer.add_scalar(name, elem, self.current_iter)

                self.current_iter += 1


            self.current_epoch += 1
            tqdm_epoch.close()

            extra_data = {}
            if self.args.ce_curr_style:
                extra_data = {**extra_data, **{'Image with style (curr)': [x_with_curr_style.cpu()]}}
            if self.args.use_fda and self.old_model is not None:
                extra_data = {**extra_data, **{'Image with style':[xs.clone().cpu() for xs in x_with_style],
                                               'Pseudo Lab (Style)':pseudo_lab.clone().cpu().numpy(),
                                               'Soft Lab (Style)':(pseudo_lab.clone().cpu().numpy(), pseudo_soft.clone().cpu())}}
            if self.use_kd_std and self.old_model is not None:
                # best effort: some of these tensors may be undefined depending on the configuration
                try:
                    if not self.args.distil_curr_style:
                        extra_data = {**extra_data, **{'Pseudo Lab KD (New)':pseudo_lab_x_kd.clone().cpu().numpy(),
                                                       'Soft Lab KD (New)':(pseudo_lab_x.clone().cpu().numpy(), pseudo_soft_x_kd.clone().cpu()),
                                                       'Pseudo Lab KD (Style)':pseudo_lab_kd.clone().cpu().numpy()}}
                    else:
                        extra_data = {**extra_data, **{'Pseudo Lab KD (Curr Style)':pseudo_lab_x_kd.clone().cpu().numpy(),
                                                       'Soft Lab KD (Curr Style)':(pseudo_lab_x_curr.clone().cpu().numpy(), pseudo_soft_x_kd.clone().cpu()),
                                                       'Pseudo Lab KD (Style)':pseudo_lab_kd.clone().cpu().numpy()}}
                except Exception as err:  # do not swallow KeyboardInterrupt/SystemExit
                    print('###### WARNING: Error when logging data ######')
                    self.logger.info(f'Extra-data logging error: {err!r}')

            self.log(img, label, argpred, split_name='train', img_id=tr_id_log[0], extra_data=extra_data if extra_data else None,
                     log_wandb_imgs = epoch%self.args.wandb_img_log_interval==0 or epoch==self.num_epochs-1)
            if self.args.val_interval is None or (self.args.val_interval is not None and epoch % self.args.val_interval == 0) or epoch==self.num_epochs-1:
                metrics_to_ret = self.validate(log_wandb_imgs=epoch%self.args.wandb_img_log_interval==0 or epoch == self.num_epochs-1)
                if epoch==self.num_epochs-1:
                    if step not in self.result_dict: self.result_dict[step] = {}
                    self.result_dict[step][self.val_dataloader.dataset.dataset_name] = metrics_to_ret
            if self.args.val_all_data_interval is not None and epoch % self.args.val_all_data_interval == 0 and epoch<self.num_epochs-1:  # last epoch eval on other dataset is always done inside 'train_init()'
                self.validate_other_datasets(step=step, log_wandb_imgs=epoch%self.args.wandb_img_log_interval==0 or epoch==self.num_epochs-1)
            if self.args.dataset_dg_test is not None and self.args.test_dg_data_interval is not None \
                    and epoch % self.args.test_dg_data_interval == 0 and epoch<self.num_epochs-1:  # last epoch eval on other dataset is always done inside 'train_init()'
                self.validate_dg_dataset(step=step, log_wandb_imgs=epoch%self.args.wandb_img_log_interval==0 or epoch==self.num_epochs-1)

            if self.wandb_save_curr_model and step>=self.step_to_start_wandb and epoch>0 and (epoch+1) % self.args.wandb_save_curr_model_interval == 0:
                if self.args.debug:
                    print('DEBUG: validate model after wandb saving')
                    self.validate()
                self.wandb_logger.save_wandb_model(self.model, step, self.args.checkpoint_dir, epoch=epoch, is_mid_step=True)

    def _mix_with_aux_ce_loss(self, ce_loss, pred, y):
        """
        Blend ``ce_loss`` with the auxiliary CE loss (``self.ce_los_aux``) on the ignored pixels.

        The auxiliary target relabels the pixels marked ``ignore_label`` in ``y`` as
        ``self.unknown_label`` and marks every other pixel as ``ignore_label``. As a result the
        auxiliary criterion only supervises the regions the main CE skips. The two terms are
        weighted by the fraction of pixels each one covers.
        :param ce_loss: main CE loss, already computed on (pred, y)
        :param pred: model logits the losses are computed on
        :param y: ground-truth label map
        :return: the blended loss, or ``ce_loss`` unchanged if no pixel is labelled ``ignore_label``
        """
        n, n_ignore = y.numel(), (y==self.args.ignore_label).sum()
        if not n_ignore > 0:   # avoid computations if no ignore_label labeled pixels exist
            return ce_loss
        y_aux = y.clone()
        y_aux[y==self.args.ignore_label] = self.unknown_label
        y_aux[y!=self.args.ignore_label] = self.args.ignore_label
        ce_loss_aux = self.ce_los_aux(pred, y_aux)  # target must be y_aux, not y
        return ce_loss * (n-n_ignore)/n + ce_loss_aux * n_ignore/n  # sum the 2 contribution taking into account the # of pixels used in each loss




    def skip_step(self, step):
        """
        Tell whether training of incremental step ``step`` must be skipped.

        Two situations trigger a skip, each with its own warning in the log:
          - a wandb checkpoint belonging to a later step has been loaded;
          - ``args.step_to_start`` lies beyond ``step``.
        :param step: incremental step being considered
        :return: True when the step must be skipped, False otherwise
        """
        later_wandb_ckpt = (self.wandb_load_model
                            and self.step_to_start_wandb > step
                            and self.wandb_ckpt is not None)
        if later_wandb_ckpt:
            self.logger.info(f'WARNING: Checkpoint of later step was found, training at step {step} skipped')
            return True

        first_step = self.args.step_to_start
        if first_step <= step:
            return False
        self.logger.info(f'WARNING: Training of step {step} skipped, starting from step {first_step}')
        return True

    def poly_lr_scheduler(self, optimizer, init_lr=None, iter=None, max_iter=None, power=None):
        """
        Scheduler for learning rate applying polynomial decay:
            lr = init_lr * max(0, 1 - iter / max_iter) ** power
        The first param group receives ``lr``; when the optimizer has exactly 2 param groups,
        the second one receives ``10 * lr``.
        :param optimizer: the chosen optimizer
        :param init_lr: initial learning rate value (default: ``self.args.lr``)
        :param iter: current iteration (default: ``self.current_iter``)
        :param max_iter: final iteration corresponding to lr=0; if set to <= 0, then lr is kept fixed
            (default: ``self.args.max_iter_poly_decay``)
        :param power: power in the decaying scheme (default: ``self.args.poly_power``)
        """
        init_lr = self.args.lr if init_lr is None else init_lr
        iter = self.current_iter if iter is None else iter
        max_iter = self.args.max_iter_poly_decay if max_iter is None else max_iter
        power = self.args.poly_power if power is None else power
        if max_iter > 0:
            # Clamp at 0: past max_iter the base turns negative, which gives a complex lr for
            # fractional powers (or a growing lr again for even integer ones).
            new_lr = init_lr * max(0., 1 - float(iter) / max_iter) ** power
        else:
            new_lr = init_lr  # max_iter <= 0: keep lr fixed (and never divide by zero)
        optimizer.param_groups[0]["lr"] = new_lr
        if len(optimizer.param_groups) == 2:
            optimizer.param_groups[1]["lr"] = 10 * new_lr

    def save_checkpoint(self, step=None, filepath=None, save_iter=False):
        """
        Save the current training state into a ckpt file.

        The saved dict holds model and optimizer state dicts, the epoch/iteration counters to
        restart from, ``vars(self.args)`` and two prototype slots (currently always None).
        Destination, by priority:
          - ``filepath`` if given;
          - if ``save_iter``: ``<checkpoint_dir>/model_step{step}_iter{current_iter}.pth``, after
            removing every pre-existing ``model_step*_iter*.pth`` file in ``checkpoint_dir``;
          - otherwise ``<checkpoint_dir>/model_step{step}.pth``.
        :param step: current incremental step (used to build the ckpt filename)
        :param filepath: explicit filepath to save the ckpt to (takes precedence over the other options)
        :param save_iter: if saving an intermediate ckpt -> remove pre-existing intermediate ones
        """
        import glob  # local import: do not rely on wildcard imports to provide the module

        state = {'epoch': self.current_epoch,  # epoch to restart on
                 'iteration': self.current_iter,  # iter to restart on
                 'state_dict': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict(),
                 'args': vars(self.args),
                 'prototypes': None,
                 'prototypes_new': None}

        if not filepath:
            if save_iter:
                # remove all pre-existing intermediate ckpt
                # NOTE(review): the pattern matches intermediate ckpts of *every* step, not only `step`
                pattern = os.path.join(self.args.checkpoint_dir, 'model_step*_iter*.pth')
                for old_ckpt in glob.glob(pattern):
                    os.remove(old_ckpt)
                filepath = os.path.join(self.args.checkpoint_dir, 'model_step{}_iter{}.pth'.format(step, self.current_iter))
            else:
                filepath = os.path.join(self.args.checkpoint_dir, 'model_step{}.pth'.format(step))

        torch.save(state, filepath)
        self.logger.info('Checkpoint saved successfully at {}'.format(filepath))

    def load_checkpoint(self, step=None, filepath=None, network='all', target_model=None, start_from_iter=False):
        """
        Load a ckpt into the current model (or into ``target_model``).

        If the ckpt file cannot be opened (OSError), the problem is logged and nothing is loaded.
        :param step: step to look the ckpt for; when given, ``<checkpoint_dir>/model_step{step}.pth``
            is loaded and ``filepath`` is ignored
        :param filepath: filepath of the ckpt to load (used only when ``step`` is None)
        :param network: network layers to load, one of 'encoder', 'decoder', 'all'
        :param target_model: target model to load the weights on (default: ``self.model``)
        :param start_from_iter: when resuming from an intermediate ckpt, also restore the
            iteration/epoch counters stored in it
        """
        filepath = os.path.join(self.args.checkpoint_dir, 'model_step{}.pth'.format(step)) if step is not None else filepath
        self.logger.info('Loading checkpoint {}'.format(filepath))
        model = self.model if target_model is None else target_model

        # load the entire ckpt
        try:
            checkpoint = torch.load(filepath, map_location=self.device)
        except OSError:
            self.logger.info('No checkpoint exists from {}. Skipping...'.format(filepath))
            self.logger.info('**First time to train**')
            return

        assert network in ('encoder', 'decoder', 'all')
        # accept both full training ckpts (with 'state_dict') and bare state dicts
        model.load_network(checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint, network)

        # if starting from intermediate step, load new prototypes and update epoch and iter starting values
        if start_from_iter:
            assert 'prototypes_new' in checkpoint, 'Prototypes new missing from ckpt, impossible to restart from iter'
            self.current_iter = checkpoint['iteration']
            self.current_epoch = checkpoint['epoch']
            # save_checkpoint() stores this entry under 'prototypes_new' and may store None
            new_prototypes = checkpoint['prototypes_new']
            if new_prototypes is None:
                new_prototypes = []
            self.logger.info('New prototypes loaded for first {} classes  ->  {}'.format(len(new_prototypes), [p[0].item() if torch.is_tensor(p) else 'None' for p in new_prototypes]))

        self.logger.info('Checkpoint loaded successfully from {}, {} layers'.format(filepath, network))

    def log(self, x, label, argpred, split_name='train', img_id='', extra_data=None, use_logger=True, log_wandb_imgs=False):
        """
        Log evaluation metrics (and optionally sample images) for the given split.

        :param x: batch of input images; only the first ``self.args.show_num_images`` are visualised
        :param label: ground-truth label maps; for the wandb overlay it must support ``.astype``
                      (``validate`` passes a numpy array)
        :param argpred: predicted label maps, same layout as ``label``
        :param split_name: one of 'train', 'val', 'test' ('val' and 'test' both use ``self.val_dataloader``)
        :param img_id: identifier appended to the tensorboard image tags
        :param extra_data: optional dict of extra images to log to wandb; supported keys contain
                           'Image with style', 'Pseudo Lab' or 'Soft Lab'
        :param use_logger: pass ``self.logger`` to the per-class report when True
        :param log_wandb_imgs: log sample images to tensorboard, and to wandb when ``self.args.use_wandb``
        :return: whatever ``self.Eval.Print_Every_class_Eval`` returns
        :raises ValueError: if ``split_name`` is not recognised
        :raises NotImplementedError: if ``extra_data`` contains an unsupported key
        """
        split_names = {'train': 'TRAINING', 'val': 'VALIDATION', 'test': 'TESTING'}
        if split_name not in split_names:
            raise ValueError('Invalid split {}'.format(split_name))

        is_train = split_name == 'train'
        dataset = self.train_dataloader.dataset if is_train else self.val_dataloader.dataset
        curr_data_split = self.curr_data_split_train if is_train else self.curr_data_split_val

        # show train image on tensorboard
        images_inv = inv_preprocess(x, self.args.show_num_images, img_mean=dataset.image_mean,
                                    numpy_transform=self.args.numpy_transform)
        labels_colors = decode_labels(label, self.class_names, self.args.class_set, self.args.show_num_images, self.args.ignore_label)
        preds_colors = decode_labels(argpred, self.class_names, self.args.class_set, self.args.show_num_images, self.args.ignore_label)

        def np2pil(npa, mode=None):
            return torchvision.transforms.ToPILImage(mode=mode)(npa)

        prefix = 'step{}/{}/{}/{{}}/{}'.format(self.current_incr_step, split_name, curr_data_split, img_id)
        if log_wandb_imgs:
            for index, (img, lab, color_pred) in enumerate(zip(images_inv, labels_colors, preds_colors)):
                self.writer.add_image(prefix.format(index) + '/Images', img, self.current_epoch)
                self.writer.add_image(prefix.format(index) + '/Labels', lab, self.current_epoch)
                self.writer.add_image(prefix.format(index) + '/Preds', color_pred, self.current_epoch)

        dataset_name = dataset.dataset_name
        if self.args.use_wandb and log_wandb_imgs:
            wandb_string = f"Eval Step {self.current_incr_step} ({curr_data_split}) on {dataset_name} ({split_name.title()})"
            base_images = (np2pil(images_inv[0]), np2pil(labels_colors[0]), np2pil(preds_colors[0]))
            wandb_image_data = [wandb.Image(pil_img, caption=caption)
                                for caption, pil_img in zip(('Image', 'GT', 'Prediction'), base_images)]
            if extra_data is not None:
                for n, v in extra_data.items():
                    if 'Image with style' in n:  # stylized images
                        for idx, el in enumerate(v):
                            # separate name: images_inv must stay the original input for the overlay below
                            styled_inv = inv_preprocess(el, self.args.show_num_images, img_mean=dataset.image_mean,
                                                        numpy_transform=self.args.numpy_transform)
                            wandb_image_data.append(wandb.Image(np2pil(styled_inv[0]), caption=f'{n} ({idx})'))
                    elif 'Pseudo Lab' in n:
                        pseudo = decode_labels(v, self.class_names, self.args.class_set, self.args.show_num_images, self.args.ignore_label)
                        # B images, take only the first to plot
                        wandb_image_data.append(wandb.Image(np2pil(pseudo[0]), caption=n))
                    elif 'Soft Lab' in n:  # v expected as (pseudo label, softmax) of dims (BxHxW, BxCxHxW)
                        pseudo = decode_labels(v[0], self.class_names, self.args.class_set, self.args.show_num_images, self.args.ignore_label).cpu().numpy()
                        # detach + cpu so .numpy() also works for GPU tensors / tensors requiring grad
                        soft = v[1].detach().max(1)[0].unsqueeze(1).cpu().numpy()
                        pseudo_soft = torch.from_numpy(np.concatenate((pseudo[0], soft[0]), axis=0))  # (3+1) x H x W
                        # B images, take only the first to plot
                        wandb_image_data.append(wandb.Image(np2pil(pseudo_soft, mode="RGBA"), mode="RGBA", caption=n))
                    else:
                        # previously `return NotImplementedError(n)`, which silently skipped all metric logging
                        raise NotImplementedError(n)
            wandb_data = {wandb_string: wandb_image_data, 'global_step': self.current_iter}

            def wb_mask(bg_img, pred_mask, true_mask):
                class_labels = {i: n for i, n in enumerate(self.class_names)}
                return wandb.Image(bg_img, masks={
                    "prediction": {"mask_data": pred_mask, "class_labels": class_labels},
                    "ground truth": {"mask_data": true_mask, "class_labels": class_labels}})
            wandb_data[wandb_string + ' Image Overlay'] = wb_mask(np2pil(images_inv[0]), argpred[0].astype(np.uint8), label[0].astype(np.uint8))

            self.wandb_logger.log(wandb_data, self.current_iter)

        PA = self.Eval.Pixel_Accuracy()
        MPA = self.Eval.Mean_Pixel_Accuracy()
        MIoU, MIoU_no_ukw = self.Eval.Mean_Intersection_over_Union()
        FWIoU = self.Eval.Frequency_Weighted_Intersection_over_Union()
        metric_to_ret = self.Eval.Print_Every_class_Eval(
            self.logger if use_logger else None, self.wandb_logger,
            self.current_incr_step, curr_data_split,
            split_name, self.current_iter, dataset_name)

        self.logger.info('{} RESULTS'.format(split_names[split_name]))

        self.logger.info('Step:{}, Data split:{} Epoch:{}, PA:{:.5f}, MPA:{:.5f}, MIoU:{:.5f}, FWIoU:{:.5f}\n'.format(
            self.current_incr_step, curr_data_split, self.current_epoch, PA, MPA, MIoU, FWIoU))

        scalar_prefix = 'step{}/{}/{}/'.format(self.current_incr_step, split_name, curr_data_split)
        log_step = self.current_iter * self.args.batch_size
        self.writer.add_scalar('{}PA'.format(scalar_prefix), PA, log_step)
        self.writer.add_scalar('{}MPA'.format(scalar_prefix), MPA, log_step)
        self.writer.add_scalar('{}MIoU'.format(scalar_prefix), MIoU, log_step)
        self.writer.add_scalar('{}MIoU_no_ukw'.format(scalar_prefix), MIoU_no_ukw, log_step)
        self.writer.add_scalar('{}FWIoU'.format(scalar_prefix), FWIoU, log_step)

        return metric_to_ret

    def validate(self, mode='val', log_wandb_imgs=False):
        """
        Evaluate ``self.model`` on ``self.val_dataloader`` and log the metrics through ``self.log``.

        :param mode: when 'val', the model is switched to eval mode before iterating
        :param log_wandb_imgs: forwarded to ``self.log`` to log sample images of the last evaluated batch
        :return: the metrics returned by ``self.log``
        :raises RuntimeError: if no batch was evaluated (e.g. an empty validation loader)
        """
        size = '(original)' if self.args.val_at_orig_size else '(reduced)'
        self.logger.info('\nValidating on {} data split {} ...'.format(self.curr_data_split_val,size))
        self.Eval.reset()
        with torch.no_grad():
            tqdm_batch = tqdm(self.val_dataloader, total=self.validation_steps,
                              mininterval=None, miniters=self.validation_steps // 10, maxinterval=60000.,
                              desc="Validating at Incremental Step {} - Data split {} - Epoch {} over {}".format(self.current_incr_step, self.curr_data_split_val, self.current_epoch, self.num_epochs))
            if mode == 'val':
                self.model.eval()

            # data of the last *evaluated* batch, used for image logging; stays None if the loader is empty
            img = label = argpred = img_id = None
            try:
                for val_iter, (x, y, val_id, input_size) in enumerate(tqdm_batch):

                    # debug mode: go quickly through the whole training procedure
                    if self.args.debug and val_iter==5: break

                    x, y = x.to(self.device), y.to(device=self.device, dtype=torch.long)

                    # model
                    pred = self.model(x)['out']
                    y = torch.squeeze(y, 1)

                    # upsample input and its prediction map
                    if self.args.val_at_orig_size:
                        pred = F.softmax(pred, dim=1)
                        pred = F.interpolate(pred, size=input_size[::-1], mode='bilinear', align_corners=True)
                        x = F.interpolate(x, size=input_size[::-1], mode='bilinear', align_corners=True)

                    # eval
                    img = x.clone().cpu()
                    pred = pred.data.cpu().numpy()
                    label = y.cpu().numpy()
                    argpred = np.argmax(pred, axis=1)

                    self.Eval.add_batch(label, argpred)
                    # record the id here: after the debug `break`, val_id already belongs to a non-evaluated batch
                    img_id = val_id[0]
            finally:
                tqdm_batch.close()

            if argpred is None:
                raise RuntimeError('No validation batch was evaluated for data split {}'.format(self.curr_data_split_val))
            metric_to_ret = self.log(img, label, argpred, split_name='val', img_id=img_id, log_wandb_imgs=log_wandb_imgs)

        return metric_to_ret

    def validate_other_datasets(self, step, mode='val', log_wandb_imgs=False):
        """
        Run ``self.validate`` on the validation splits of other incremental steps.

        With ``self.args.incr_val_data == 'before'`` every step preceding ``step`` is evaluated;
        with ``'all'`` every step whose data split differs from the current one is evaluated;
        any other value makes this a no-op. Loader attributes are cached per step in
        ``self.val_dls[step]`` and the current split's attributes are restored at the end.
        """
        if self.args.incr_val_data not in ['before', 'all']:
            return  # no other datasets to validate on

        cached_attrs = ('val_dataloader', 'validation_steps', 'curr_data_split_val')

        if step not in self.val_dls:
            # remember the current split so it can be restored afterwards
            self.val_dls[step] = {step: {attr: getattr(self, attr) for attr in cached_attrs}}

        if self.args.incr_val_data == 'before':
            other_steps = range(step)
        else:
            data_splits = list_paths[self.args.incr_data_split]
            current_split = self.args.single_data_split or data_splits[step]
            other_steps = [idx for idx, split in enumerate(data_splits) if split != current_split]

        for other in other_steps:
            cached = self.val_dls[step].get(other)
            if cached is None:
                # first visit: build the loader and cache the resulting attributes
                self.dataset_val_init(step=other)
                self.val_dls[step][other] = {attr: getattr(self, attr) for attr in cached_attrs}
            else:
                for attr in cached_attrs:
                    setattr(self, attr, cached[attr])
            self.validate(mode=mode, log_wandb_imgs=log_wandb_imgs)

        # restore variables referring to current dataset
        for attr in cached_attrs:
            setattr(self, attr, self.val_dls[step][step][attr])

    def validate_dg_dataset(self, step, mode='val', log_wandb_imgs=False):
        """
        Run ``self.validate`` on the domain-generalisation test split for ``step``.

        The DG loader is built via ``self.dataset_dg_test(step)`` only on the first call for a given
        ``step``; its attributes are cached in ``self.val_dls[step]['dg']`` and reused afterwards.
        The current split's validation attributes are restored at the end.
        """
        if step not in self.val_dls:
            self.val_dls[step] = {step:{'val_dataloader':self.val_dataloader, 'validation_steps':self.validation_steps, 'curr_data_split_val':self.curr_data_split_val}}

        if 'dg' not in self.val_dls[step]:
            # build the DG loader only on a cache miss (previously it was rebuilt on every call and
            # then discarded in favour of the cached one), matching validate_other_datasets
            self.dataset_dg_test(step)
            self.val_dls[step]['dg'] = {'val_dataloader':self.val_dataloader, 'validation_steps':self.validation_steps, 'curr_data_split_val':self.curr_data_split_val}
        else:
            self.val_dataloader = self.val_dls[step]['dg']['val_dataloader']
            self.validation_steps = self.val_dls[step]['dg']['validation_steps']
            self.curr_data_split_val = self.val_dls[step]['dg']['curr_data_split_val']
        self.validate(mode=mode, log_wandb_imgs=log_wandb_imgs)

        # restore variables referring to current dataset
        self.val_dataloader = self.val_dls[step][step]['val_dataloader']
        self.validation_steps = self.val_dls[step][step]['validation_steps']
        self.curr_data_split_val = self.val_dls[step][step]['curr_data_split_val']

    def compute_and_print_global_metrics(self, step):
        """
        Log, for incremental ``step``, each evaluated domain's mIoU against the baseline plus aggregates.

        For every domain in ``self.result_dict[step]`` that also has a baseline value, the relative
        delta (in %) of its 'total' mIoU is stored in ``self.delta_results[step][domain]``.
        When ``self.args.incr_val_data`` is 'all' or 'before', the average delta over the domains of
        steps ``0..step`` that have a delta is stored under 'domain avg'.
        For ``step > 0`` a class-incremental (forgetting-style) metric compares, for each earlier step
        ``s``, entry ``s`` of ``self.result_dict[step]`` with entry ``s`` of ``self.result_dict[s]``.
        Aggregates are sent to ``self.wandb_logger`` when it is not None.
        """
        step_deltas = self.delta_results.setdefault(step, {})

        wandb_data = {'global_step': self.current_iter}

        self.logger.info('####################################')
        self.logger.info(f'Step {step} Metrics (mIoU)')
        base_res = baseline_results[self.args.incr_class_split][self.args.model][step]
        # column headers: keys of the first domain's entry, minus the last one (printed as total mIoU)
        first_entry = next(iter(self.result_dict[step].values()), {})
        header_keys = list(first_entry.keys())[:-1]
        self.logger.info(f"{'':15}: {''.join([f'{el:6}' if i > 0 else f'{el:4}' for i, el in enumerate(header_keys)])}")
        for ds_name, m_dict in self.result_dict[step].items():
            if ds_name not in base_res:
                continue
            delta = (m_dict['total'] - base_res[ds_name]) / base_res[ds_name]
            step_deltas[ds_name] = delta*100
            self.logger.info(f"{ds_name:15}: {''.join([f'{round(v,2):6}' for v in list(m_dict.values())[:-1]])} | "
                             f'total mIoU: {round(list(m_dict.values())[-1],2):5}% (baseline: {base_res[ds_name]:5}%) | delta: {round(delta*100,2):6}%')

        if self.args.incr_val_data in ('all', 'before'):
            seen_domains = [fix_data_name(list_paths[self.args.incr_data_split][k]) for k in range(step+1)]
            # a seen domain may lack a delta (not evaluated / no baseline): average over the available
            # ones instead of raising KeyError
            available = [d for d in seen_domains if d in step_deltas]
            if available:
                avg_delta = sum(step_deltas[d] for d in available) / len(available)
                step_deltas['domain avg'] = avg_delta
                wandb_data['Global Metrics/Avg Delta'] = avg_delta
                self.logger.info(f' Avg delta: {avg_delta:6}%')
                if len(available) < len(seen_domains):
                    missing = [d for d in seen_domains if d not in step_deltas]
                    self.logger.info(f' Avg delta computed without domains {missing} (no delta available)')
            else:
                self.logger.info(' Avg delta: no delta available for the domains seen so far, skipping...')

        if step > 0:
            try:
                class_acc = []
                for s in range(step):
                    class_acc_k = []
                    for k in range(s+1):
                        ds_k = fix_data_name(list_paths[self.args.incr_data_split][k])
                        if ds_k not in self.result_dict[step] or ds_k not in self.result_dict[s]:
                            continue
                        ref = self.result_dict[s][ds_k][s]
                        # check before dividing: a zero reference used to raise ZeroDivisionError and
                        # the whole metric was skipped by the except below
                        class_acc_k.append(0 if ref <= 0. else (self.result_dict[step][ds_k][s] - ref) / ref)
                    if class_acc_k:
                        class_acc.append(sum(class_acc_k)/len(class_acc_k))
                if class_acc:
                    # average over the steps that actually contributed (equals `step` when none was skipped)
                    avg_forgetting = sum(class_acc)/len(class_acc)*100
                    self.logger.info(f'Class-incremental metric:  {str([round(v*100,2) for v in class_acc]):20} | avg: {round(avg_forgetting,2):6}%')
                    wandb_data['Global Metrics/Avg Forgetting'] = avg_forgetting
            except Exception as e:
                # best effort: a failure here must not interrupt training
                self.logger.info(f"'{e}' was raised when computing class-incremental metric, skipping...")
        self.logger.info('####################################')

        if self.wandb_logger is not None:
            self.wandb_logger.log(wandb_data, self.current_iter)






