from . import *

def grad_clipping(model):
    """Placeholder for gradient clipping; it is not implemented.

    Args:
        model: unused.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def grad_check(model, logger):
    """Report every trainable parameter whose gradient norm is NaN.

    Nothing is logged when no such parameter exists.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, parameter)`` pairs.
        logger: object exposing ``info(message)``; all output of this function goes through it.
    """
    nan_grad_list = []
    for name, param in model.named_parameters():
        # A trainable parameter has no gradient until it takes part in a backward pass;
        # skip it instead of failing on ``None.norm()``.
        if not param.requires_grad or param.grad is None:
            continue
        grad_norm = param.grad.norm()
        if torch.isnan(grad_norm):
            nan_grad_list.append((name, grad_norm))

    if nan_grad_list:
        logger.info('')
        logger.info('Parameters with exploding gradients')
        # Detail lines use the same logger as the header so they end up in the same sink.
        for name, grad_norm in nan_grad_list:
            logger.info(f"{name} -> {grad_norm}")

def fix_data_name(n):
    """Lower-case a dataset name and replace known aliases with their canonical name.

    Args:
        n: dataset name (string).

    Returns:
        The canonical name if the lower-cased input is a known alias, otherwise the lower-cased input.
    """
    aliases = {'gta':'gta5'}
    lowered = n.lower()
    return aliases.get(lowered, lowered)

def memory_check(log_string):
    """Print ``log_string`` followed by the peak and current CUDA memory allocated, in GB (bytes / 1024**3).

    ``torch.cuda.synchronize()`` is called first.

    Args:
        log_string: header line printed before the two memory figures.
    """
    bytes_per_gb = 1024 ** 3
    torch.cuda.synchronize()
    print(log_string)
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    print(' peak:', f'{peak_gb:.3f}', 'GB')
    current_gb = torch.cuda.memory_allocated() / bytes_per_gb
    print(' current', f'{current_gb:.3f}', 'GB')


# used only in train.py
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (``torch.random.manual_seed``) with ``seed``.

    Args:
        seed: integer seed handed to all three generators.
    """
    # Same order as before: stdlib, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_fn(seed)

def hash_det(s,l=8):
    """Return a deterministic numeric string derived from ``s``.

    The SHA-512 digest of the UTF-8 encoding of ``s`` is read as one big integer,
    written in decimal, and truncated to its first ``l`` characters.

    Args:
        s: string to hash.
        l: number of leading decimal digits to keep.

    Returns:
        String made of at most ``l`` decimal digits.
    """
    digest = hashlib.sha512(s.encode('utf-8')).digest()
    # Big-endian interpretation of the raw digest equals int(hexdigest, 16).
    as_decimal = str(int.from_bytes(digest, 'big'))
    return as_decimal[:l]


def save_results(data: Tuple[Dict,str], file_dir_path: str, extra_data: Optional[List] = None, wb_logging: bool = True) -> None:
    """
    Build a table indexed by (step, domain) with two-level columns (supercolumn, column) out of nested
    result dicts and write it to ``{file_dir_path}/results.xlsx``. When ``wb_logging`` is True the same
    table is also written to ``{wandb.run.dir}/results.xlsx`` and passed to ``wandb.save``.

    The dictionaries passed in ``data`` and ``extra_data`` are not modified.

    data: Tuple( dict of struct -> {step_idx: {domain_name: {col0_metric_name: col0_metric_value, ...}, ...}, ...} , supercol_name)
    file_dir_path: directory where results.xlsx is written
    extra_data: list of tuples -> [(data_dict, col_name (optional), supercol_name (optional)), ...]
                a bare data_dict (not wrapped in a tuple) is accepted too
                data_dict could be of the form: a) {step_idx: {domain_name: {col0_metric_name: col0_metric_value, ...}, ...}, ...} (col_name is not used)
                                                b) {step_idx: {domain_name: metric_value, ...}, ...} (col_name is required)
                                                c) {step_idx: metric_value, ...} (col_name is required)
                for type (c) the value is written on the first domain row of the step only, the other rows get None
    wb_logging: also store the file in the wandb run directory and upload it
    """

    def fill_gaps(dict_to_fill: Dict) -> Dict:
        """Give every (step, domain) row the same ordered set of columns, using None for missing ones.

        Dicts of type (b)/(c) are returned untouched."""
        all_keys = []
        # Walk the steps from last to first to collect the union of column names in a stable order
        for step,step_data in list(dict_to_fill.items())[::-1]:
            if not isinstance(step_data, dict): return dict_to_fill  # data_dict of type (c)
            for domain,domain_data in step_data.items():
                if not isinstance(domain_data, dict): return dict_to_fill  # data_dict of type (b)
                if set(all_keys).issubset(set(domain_data.keys())):
                    # this row has a superset of what was seen so far: adopt its ordering
                    all_keys = list(domain_data.keys())
                if set(domain_data.keys()) - set(all_keys):
                    all_keys += [k for k in domain_data.keys() if k not in all_keys]
        # Build fresh dicts so that the input is left untouched
        return {
            step: {domain: {k: domain_data.get(k) for k in all_keys} for domain,domain_data in step_data.items()}
            for step,step_data in dict_to_fill.items()
        }

    def merge(main_dict: Dict, dict_to_merge: Dict, col_name: Optional[str]) -> Dict:
        """Return a copy of ``main_dict`` (type (a)) extended with the columns found in ``dict_to_merge``.

        Steps that only exist in ``dict_to_merge`` are ignored."""
        # Copy down to the row level: rows are edited below and the caller's dicts must stay intact
        md_clone = {step: {domain: dict(row) for domain,row in domains.items()} for step,domains in main_dict.items()}
        dtm_clone = dict_to_merge.copy()
        for ko, vo in md_clone.items():  # ko is step idx
            if ko not in dtm_clone: continue
            for ki in vo:   # ki is domain name
                if not isinstance(dtm_clone[ko], dict):                             # case (c)
                    md_clone[ko][ki][col_name] = dtm_clone[ko]
                    # per-step scalar: only the first domain row carries it, the following rows get None
                    dtm_clone[ko] = None
                elif ki not in dtm_clone[ko]:                                       # case (a)/(b) with domain name missing
                    if col_name is not None: md_clone[ko][ki][col_name] = None
                elif not isinstance(dtm_clone[ko][ki], dict):                       # case (b)
                    md_clone[ko][ki][col_name] = dtm_clone[ko][ki]
                else:                                                               # case (a)
                    md_clone[ko][ki] = {**md_clone[ko][ki], **dtm_clone[ko][ki]}
            if isinstance(dtm_clone[ko], dict):
                for ki_dtm in dtm_clone[ko]:  # loop over all domains
                    if ki_dtm not in vo:
                        # domain only present in dict_to_merge: create a new row for it
                        if not isinstance(dtm_clone[ko][ki_dtm], dict):                         # case (b)
                            md_clone[ko][ki_dtm] = {col_name: dtm_clone[ko][ki_dtm]}
                        else:                                                                   # case (a)
                            # there is no existing row to merge with (reading it raised KeyError before)
                            md_clone[ko][ki_dtm] = dict(dtm_clone[ko][ki_dtm])
        return md_clone

    data, super_col_name = data
    # Column names are taken from the last domain of the last step
    columns = [(super_col_name,k) for k in [*[*data.values()][-1].values()][-1].keys()]
    if extra_data:
        for el in extra_data:
            # Normalise every entry to (data_dict, col_name, supercol_name)
            if not isinstance(el, tuple):
                el = (el, None, super_col_name)
            elif len(el)==1:
                el = (el[0], None, super_col_name)
            elif len(el)==2:
                el = (el[0], el[1], super_col_name)
            data_dict, col_name, sup_col_name = el
            data = merge(data, data_dict, col_name)
            if col_name is None:
                columns += [(sup_col_name,k) for k in [*[*data_dict.values()][-1].values()][-1].keys()]
            else:
                columns += [(sup_col_name, col_name)]

    data = fill_gaps(data)
    index = [(i,j) for i in data.keys() for j in data[i].keys()]
    data_df = [[*data[i][j].values()] for i,j in index]
    # 'tight' layout understood by pd.DataFrame.from_dict: explicit index / columns / data
    dataf = {
        'index': index,
        'columns': columns,
        'data': data_df,
        'index_names': [None, None],
        'column_names': [None, None],
    }

    df = pd.DataFrame.from_dict(dataf, orient='tight')
    file_path = f'{file_dir_path}/results.xlsx'
    df.to_excel(file_path)

    if wb_logging:
        file_path_wb = f'{wandb.run.dir}/results.xlsx'
        df.to_excel(file_path_wb)
        wandb.save(file_path_wb)

class WandB:
    """Wrapper around the ``wandb`` API used for run naming, metric logging and checkpoint save / restore."""

    def __init__(self, args):
        """Start a new wandb run.

        Args:
            args: parsed command-line namespace; it is stored, used to build the job / group names
                and passed to ``wandb.init`` as the run config.
        """
        self.args = args
        self.step_offset = 0  # added to ``current_iter`` in ``log``
        self._set_job_name()
        self._set_group_name()
        wid = wandb.util.generate_id()
        wandb.init(name=self.job_name, # job name
                   project='CL_UDA', # The name of the project where you're sending the new run
                   group=self.group_name,
                   entity='lttm',
                   resume='allow',
                   id=wid,
                   config=args)

    def _set_group_name(self):
        """Set ``self.group_name`` to '<data split>_<incr_class_split>_<class_set>'.

        The data split is ``single_data_split`` when it is truthy, ``incr_data_split`` otherwise.
        """
        if self.args.single_data_split: self.group_name = f'{self.args.single_data_split}'
        else: self.group_name = f'{self.args.incr_data_split}'
        self.group_name += f'_{self.args.incr_class_split}_{self.args.class_set}'

    def _set_job_name(self):
        """Set ``self.job_name`` from the main hyper-parameters, plus ``wandb_name`` as suffix when given."""
        def list2str(l):
            # Lists are rendered as 'a-b-c'; anything else is returned unchanged
            if not isinstance(l, list): return l
            return '-'.join([str(el) for el in l])
        self.job_name = f"{self.args.model}_lr{self.args.lr:.1E}_lrincr{self.args.lr_incremental:.1E}_ep{self.args.epochs}_epincr{list2str(self.args.epochs_incremental)}"\
                        f"_bs{self.args.batch_size}_res{list2str(self.args.resize_to)}_crop{list2str(self.args.crop_to)}_{'CEunb' if self.args.use_ce_unbiased else 'noCEunb'}" \
                        f"_{'unbCLSinit' if self.args.unbiased_cls_init else 'nounbCLSinit'}_{self.args.optim.lower()}_wd{self.args.weight_decay:.1E}"
        if self.args.wandb_name: self.job_name += f'_{self.args.wandb_name}'

    def log(self, data_dict, current_iter=None):
        """Forward ``data_dict`` to ``wandb.log``; when ``current_iter`` is given the wandb step is
        ``current_iter + self.step_offset``."""
        if current_iter is not None: wandb.log(data_dict, step=current_iter + self.step_offset)
        else: wandb.log(data_dict)

    def restore(self, run_id, wandb_path, step):
        """Download a checkpoint of run ``run_id`` and load it with ``torch.load``.

        Args:
            run_id: id of the wandb run owning the checkpoint.
            wandb_path: '<entity>/<project>' of that run.
            step: step index of the checkpoint ('step{step}.ckpt'); None selects 'curr.ckpt'.

        Returns:
            The loaded checkpoint, or None when a ValueError is raised while restoring / loading.
        """
        ckpt_name = f'step{step}.ckpt' if step is not None else 'curr.ckpt'
        try:
            print(f"Loading previous model {f'at step {step}' if step is not None else '(curr)'} from {run_id}...")
            ckpt_path = '/'.join(['checkpoints', ckpt_name])
            wandb.restore(ckpt_path, run_path=f"{wandb_path}/{run_id}", replace=True, root=os.getcwd())
            # NOTE(review): deletes files matching ckpt_name from the *current* run's cloud storage,
            # presumably so the restored copy is not kept under the new run - confirm.
            self.remove_file(ckpt_name)
            # NOTE: torch.load unpickles the file; only restore checkpoints coming from trusted runs.
            if not torch.cuda.is_available():
                return torch.load(ckpt_path, map_location=torch.device('cpu'))
            return torch.load(ckpt_path)
        except ValueError:
            print("===>>>  Warning: Ckpt missing  <<<===")
            return None

    @staticmethod
    def remove_file(name):
        """ Input: string to be contained inside the name of the file to be removed.

        Deletes every file of the current wandb run whose name contains ``name``; only prints a warning
        when no file matches."""
        wandb_files_on_cloud = wandb.Api().run(f'{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}').files()
        f_to_remove = [f for f in wandb_files_on_cloud if name in f.name]
        if not f_to_remove:
            print(f'WARNING: {name} to remove has not been found among wandb files of current run')
            return
        print(f'WARNING: {[f.name for f in f_to_remove]} will be deleted from the cloud')
        for f in f_to_remove:
            f.delete()

    @staticmethod
    def save_wandb_model(model, step, ckpt_dir, epoch=None, is_mid_step=False):
        """Save ``model.state_dict()`` under ``{ckpt_dir}/checkpoints`` and upload it with ``wandb.save``.

        Args:
            model: object exposing ``state_dict()``.
            step: step index stored in the checkpoint and used in the end-of-step file name.
            ckpt_dir: base directory; the file goes into its 'checkpoints' sub-directory.
            epoch: stored in the checkpoint when not None; required when ``is_mid_step`` is True.
            is_mid_step: True -> file 'curr.ckpt', False -> file 'step{step}.ckpt'.
        """
        if is_mid_step:
            assert epoch is not None, "When saving mid step, epoch is required"
            ckpt_name = 'curr.ckpt'
        else:
            ckpt_name = f'step{step}.ckpt'
        print(f"Saving model at {'end' if not is_mid_step else f'epoch {epoch}'} of step {step}...")
        state = {
            "model_state": model.state_dict(),
            "step": step,
        }
        if epoch is not None:
            state['epoch'] = epoch
        ckpt_path = os.path.join(ckpt_dir, 'checkpoints', ckpt_name)
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save(state, ckpt_path)
        wandb.save(ckpt_path, base_path=ckpt_dir, policy="now")
        print("Done")

    def load_wandb_curr_model(self):
        """Load the mid-step checkpoint ('curr.ckpt') of run ``args.wandb_id``.

        Returns:
            ``(step_to_start, ckpt, old_ckpt, epoch)`` where ``old_ckpt`` is the end-of-step checkpoint of
            the previous step (None at step 0). ``(0, None, None, None)`` is returned when either the
            mid-step checkpoint or the required previous-step checkpoint cannot be loaded.
        """
        assert self.args.wandb_id is not None, "When loading mid-step model wandb-id needs to be specified"
        wandb_path = f'{wandb.run.entity}/{wandb.run.project}'
        ckpt = self.restore(self.args.wandb_id, wandb_path, step=None)
        if ckpt is None:
            # restore() returns None when the checkpoint is missing: fall back to a fresh start
            return 0, None, None, None
        step_to_start = ckpt['step']
        old_ckpt = None
        if step_to_start>0:
            old_ckpt = self.get_pretrained_model(step_to_start-1)
            if not old_ckpt:
                return 0, None, None, None
        return step_to_start, ckpt, old_ckpt, ckpt['epoch']

    def load_wandb_model(self, steps=None):
        """Return the first available pretrained checkpoint among ``steps``.

        Args:
            steps: iterable of step indices, tried in the given order; None means no step.

        Returns:
            ``(step_to_start, ckpt)``: ``step_to_start`` is the found step + 1, or ``(0, None)`` when no
            checkpoint is found.
        """
        ckpt, step_to_start = None, 0
        for step in (steps if steps is not None else []):
            print(f"Checking if pretrained model of step {step} exists...")
            ckpt = self.get_pretrained_model(step)
            if ckpt is not None:
                step_to_start = step+1
                break
        if ckpt is not None: print(f"Found wandb pretrained checkpoint of step {step_to_start-1}")
        else: print("Ckpt not found, training from scratch")
        return step_to_start, ckpt

    def get_pretrained_model(self, step):
        """Look for a checkpoint of ``step`` produced by a run whose relevant arguments match the current ones.

        When ``args.wandb_id`` is set only that run is considered; otherwise every run of the current
        wandb group is scanned and the first one that matches and has a loadable checkpoint is used.

        Args:
            step: step index; only the steps listed in ``keys_to_consider`` (0 and 1) are supported.

        Returns:
            The loaded checkpoint, or None when the step is unsupported, the target run is not found,
            the arguments differ or no checkpoint could be restored.
        """
        group_name = wandb.run.group
        wandb_path = f'{wandb.run.entity}/{wandb.run.project}'
        # Arguments that must match for a checkpoint of a given step to be reusable
        keys_to_consider = OrderedDict({
            0: ["lr", "epochs", "iter_max", "max_iter_poly_decay", "poly_power", "weight_decay", "momentum", "optim", "step_to_start", "debug",
                "numpy_transform", "random_translation", "gaussian_blur", "random_mirror", "crop_to", "random_resize", "resize_to", "batch_size",
                "model", "encoder_weights_from_past", "incr_val_data", "single_data_split", "incr_data_split", "incr_class_split", "class_set",
                "gradient_clip", "use_MiB_reimplemented",
                "ce_curr_style", "lambda_ce_curr_style"],
            1: ["lr_incremental", "epochs_incremental", "use_ce_unbiased", "distil_type",
                "lambda_ce_style", "lambda_kd_style", "avg_style", "L", "n_images_per_style", "use_fda", "lambda_distil",
                "lambda_kd_feat", "lambda_kd_feat_style", "lambda_uda_feat",
                "freeze_encoder", "label_inpainting", "label_inpainting_type", "global_thresh", "class_keep_fract"]})
        # Values under which an argument missing on one side is not counted as a mismatch
        default_values = OrderedDict({
            0: {"step_to_start":[0], "ce_curr_style":[False], "lambda_ce_curr_style":[None,0.]},
            1: {}
        })
        # Each step also inherits the keys of all the previous steps
        for s in list(keys_to_consider)[:-1]:
            keys_to_consider[s+1] += keys_to_consider[s]

        if step not in keys_to_consider: return None
        pretrain_keys = list(set(vars(self.args)) & set(keys_to_consider[step]))
        new_run_params = {key: value for key, value in vars(self.args).items() if key in pretrain_keys}

        def normalize(v):
            """Turn a value into a canonical comparable form (numbers -> 'float' strings, containers -> lists)."""
            if isinstance(v, str):
                return str(float(v)) if v.isnumeric() else v
            elif isinstance(v, Number):
                return str(float(v))
            elif isinstance(v, dict):
                return [normalize(el) for el in v.items()]
            elif isinstance(v, Iterable):
                return [normalize(el) for el in v]
            else:
                return str(v)

        def dict_eq(d1, d2, return_diff=False):
            """Compare the values of the keys shared by d1 and d2 (a missing key is not a mismatch).

            Returns a bool, or ``(bool, mismatching_keys)`` when ``return_diff`` is True."""
            diff_list = []
            for k,v in d1.items():
                if k in d2 and not normalize(v)==normalize(d2[k]):
                    if not return_diff: return False
                    diff_list.append(k)
            if diff_list and return_diff:
                return False, diff_list
            elif return_diff:
                return True, []
            return True

        def is_default(d, diff_list):
            """Return the keys of ``diff_list`` that are NOT present in ``d`` with one of their default values."""
            not_def = []
            for el in diff_list:
                if el in d and el in default_values[step]:
                    if d[el] in default_values[step][el]: continue
                not_def.append(el)
            return not_def

        if self.args.wandb_id is not None:
            try:
                old_run = wandb.Api().run(path=f'{wandb_path}/{self.args.wandb_id}')
            except wandb.errors.CommError:
                print('<<<<<<<<<<<<<<  WARNING  >>>>>>>>>>>>>>')
                print(f'Args of target wandb run {self.args.wandb_id} was not found')
                print(' ')
                return None
            old_run_params = {key: value for key, value in old_run.config.items() if key in pretrain_keys}
            # Keys present on one side only, minus those that hold a default value on the side that has them
            diff = set(new_run_params.keys()) ^ set(old_run_params.keys())
            diff_wo_def = list(set(is_default(new_run_params,diff)) & set(is_default(old_run_params,diff)))
            is_eq, not_eq_list = dict_eq(new_run_params, old_run_params, return_diff=True)    # Checks for all keys in common between d1 and d2 if values are the same (key not found is NOT counted as mismatch)
            # reverseCIL and CIL have step 0 in common
            if step == 0 and 'incr_class_split' in not_eq_list and old_run_params['incr_class_split'] in ('CIL', 'reverseCIL') and self.args.incr_class_split in ('CIL', 'reverseCIL'):
                not_eq_list.remove('incr_class_split')
                is_eq = len(not_eq_list) == 0
            if not diff_wo_def and is_eq:
                ckpt = self.restore(self.args.wandb_id, wandb_path, step)
            else:
                ckpt = None
                print('!!!!!!!!!!!!!!  WARNING  !!!!!!!!!!!!!!')
                print(f'Args of target wandb run {self.args.wandb_id} at step {step} differ from those of the current run (diff w/o defaults: {diff_wo_def}, not_eq_list: {not_eq_list}): ckpt has not been loaded')
                print(' ')
            return ckpt

        # No explicit run id: scan the runs of the same group for an exact match of the relevant arguments
        for old_run in wandb.Api().runs(path=wandb_path):
            if old_run._attrs["group"] == group_name:
                old_run_params = {key: value for key, value in old_run.config.items() if key in pretrain_keys}
                diff = set(new_run_params.keys()) ^ set(old_run_params.keys())
                if not diff and dict_eq(new_run_params,old_run_params):
                    ckpt = self.restore(old_run.id, wandb_path, step)
                    if ckpt is not None: return ckpt
        return None

    @staticmethod
    def log_metrics(metrics_by_name):
        """Log every recorded value of every metric with ``wandb.log``.

        Args:
            metrics_by_name: ``{metric_name: {"steps": [...], "values": [...]}}``; numpy arrays are logged
                as ``wandb.Image``, any other value as is. The recorded steps are not forwarded to wandb.
        """
        for metric_name, metric_ptr in metrics_by_name.items():
            for _step, value in zip(metric_ptr["steps"], metric_ptr["values"]):
                if isinstance(value, np.ndarray):
                    wandb.log({metric_name: wandb.Image(value)})
                else:
                    wandb.log({metric_name: value})


class LossLogger:
    """Ordered ``name -> value`` store for loss terms, printable as 'name=value - name=value ...'."""

    def __init__(self):
        self.log_dict = OrderedDict()  # insertion-ordered mapping: loss name -> last value set

    def __setitem__(self, key, item):
        """Store one value, or several at once when both ``key`` and ``item`` are lists of equal length."""
        if isinstance(key, list) and isinstance(item, list):
            assert len(key) == len(item)
            for k,v in zip(key,item):
                self.log_dict[k] = v
        else:
            self.log_dict[key] = item

    def __str__(self):
        """Return the entries as 'name=value' joined by ' - ', values formatted with the '5f' spec.

        An empty logger gives an empty string. Each entry is formatted on its own, so names containing
        braces are printed literally."""
        return ' - '.join(f'{name}={value:5f}' for name, value in self.log_dict.items())

    def items(self):
        """Return the ``(name, value)`` view of the stored entries."""
        return self.log_dict.items()

    def values(self):
        """Return the view of the stored values."""
        return self.log_dict.values()

    def keys(self):
        """Return the view of the stored names."""
        return self.log_dict.keys()
