from . import *

def str2bool(v):
    """
    Convert a command-line token into a boolean (meant to be used as argparse ``type``).

    Matching is case-insensitive and ignores surrounding whitespace.
    Values that are already ``bool`` are returned unchanged, so the converter can
    also be applied to programmatic defaults.

    :param v: string such as 'yes'/'true'/'t'/'y'/'1' or 'no'/'false'/'f'/'n'/'0', or a bool
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the value is not one of the accepted spellings
    """
    # A real boolean has no .lower(); pass it through instead of crashing.
    if isinstance(v, bool):
        return v

    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')

def str2str_none_num(v,t=float):
    """
    Parse a string that may encode ``None``, a number, or free text.

    :param v: input string
    :param t: callable used for the conversion attempt (``float`` by default)
    :return: ``None`` if ``v`` spells 'none' (any letter case), ``t(v)`` if that
             conversion succeeds, otherwise ``v`` itself unchanged
    """
    if v.lower() == 'none':
        return None

    # Only ValueError means "not convertible": fall back to the raw string.
    try:
        converted = t(v)
    except ValueError:
        return v
    return converted

def str2list(tp):
    """
    Build a converter for comma-separated command-line values.

    :param tp: element type handed to ``str2str_none_num`` as ``t``
    :return: function mapping a string to a list with one converted entry per
             comma-separated field, or to the bare entry when there is no comma
    """
    def conv(s):
        items = []
        for field in s.split(','):
            items.append(str2str_none_num(field, t=tp))
        # A single field is unwrapped rather than returned as a 1-element list.
        if len(items) > 1:
            return items
        return items[0]
    return conv

def str2dictlist(tp):
    """
    Build a converter for colon-separated groups of comma-separated values.

    :param tp: element type forwarded to ``str2list``
    :return: function mapping a string to a dict ``{group index: parsed group}``
             (each group parsed by ``str2list(tp)``), or to the single parsed
             group when the string contains no colon
    """
    def conv(s):
        parse_group = str2list(tp)
        groups = {}
        for idx, chunk in enumerate(s.split(':')):
            groups[idx] = parse_group(chunk)
        # With only one group, hand back its value instead of a 1-entry dict.
        if len(groups) > 1:
            return groups
        return groups[0]
    return conv




def add_train_args(arg_parser):
    """
    Register every training command-line option on the given parser.

    Boolean options use ``str2bool`` as their type, so they expect an explicit
    value on the command line (e.g. ``--debug true``) rather than acting as flags.
    Options typed with ``str2list(...)`` accept comma-separated values.

    :param arg_parser: parser object exposing ``add_argument`` (e.g. ``argparse.ArgumentParser``)
    :return: the same parser, with the options added
    """

    arg_parser.add_argument('--debug', type=str2bool, default=False, help="whether to do debug or not")
    arg_parser.add_argument('--step_to_stop', type=int, default=2, help="last step to perform (included)")
    arg_parser.add_argument('--step_to_start', type=int, default=0, help="first step to perform (included)")
    arg_parser.add_argument('--use_MiB_reimplemented', type=str2bool, default=False, help="whether to use MiB's reimplementation")
    arg_parser.add_argument('--deterministic', type=str2bool, default=False, help="whether to enable deterministic run")

    ### WandB related args ###
    arg_parser.add_argument('--use_wandb', type=str2bool, default=False, help="whether to enable wandb logging")
    arg_parser.add_argument('--wandb_name', type=str, default=None, help="additional postfix to add to wandb job name")
    arg_parser.add_argument('--wandb_img_log_interval', type=int, default=20, help="interval in epochs between wandb image logging")
    arg_parser.add_argument('--wandb_save_model', type=str2bool, default=True, help="if to save models to wandb")
    arg_parser.add_argument('--wandb_load_model', type=str2bool, default=False, help="if to load models from wandb")
    arg_parser.add_argument('--wandb_save_curr_model_interval', type=int, default=None, help="when to save models to wandb mid step, leave to None or <0 to disable")
    arg_parser.add_argument('--wandb_load_curr_model', type=str2bool, default=False, help="if to load mid step models from wandb")
    arg_parser.add_argument('--wandb_id', type=str, default=None, help="wandb run id with highest priority when searching for pretrained ckpt")
    arg_parser.add_argument('--wandb_key', type=str, default=None, help="wandb user key")
    arg_parser.add_argument('--wandb_load_step', type=int, default=None, help="wandb step of which to load ckpt")

    ### Loss related args ###
    arg_parser.add_argument('--use_ce_unbiased', type=str2bool, default=True, help="whether to use ce unbiased loss")
    arg_parser.add_argument('--lambda_distil', type=str2list(float), default=0., help="lambda distillation")  # default lambda=10.
    arg_parser.add_argument('--distil_type', type=str, default='std', choices=['std','pseudo_filter','soft_filter','std_style','pseudo_filter_style','soft_filter_style','pseudo_filter_all_styles'], help="type of distillation")
    arg_parser.add_argument('--distill_with_ce', type=str2bool, default=False, help="use CE loss when distilling with pseudo-labels")
    arg_parser.add_argument('--distil_curr_style', type=str2bool, default=False, help="use KD on self-stylized domain")
    # Pseudo-labeling
    arg_parser.add_argument('--global_thresh', type=float, default=0.9, help="kd pseudo labeling global confidence threshold")
    arg_parser.add_argument('--class_keep_fract', type=float, default=0.66, help="kd pseudo labeling per-class fraction of pixels to keep")
    arg_parser.add_argument('--global_thresh_low', type=float, default=-1., help="kd pseudo labeling global low confidence threshold (consistency)")
    # Feature-level losses
    arg_parser.add_argument('--lambda_kd_feat', type=float, default=0., help="lambda kd feat level")
    arg_parser.add_argument('--lambda_kd_feat_style', type=float, default=0., help="lambda kd feat level with stylized images")
    arg_parser.add_argument('--lambda_uda_feat', type=float, default=0., help="lambda loss uda feat-level")
    # Stylization
    arg_parser.add_argument('--use_fda', type=str2bool, default=False, help='Whether to use FDA or not')
    arg_parser.add_argument('--n_images_per_style', type=int, default=1, help="number of images per style")
    arg_parser.add_argument('--L', type=float, default=0.001, help="to control size of style fft amp window")
    arg_parser.add_argument('--avg_style', type=str2bool, default=False, help='Whether to use avg style or not')
    arg_parser.add_argument('--save_or_load_avg_style', type=str2bool, default=False, help='Whether to load/save avg style or not')
    arg_parser.add_argument('--lambda_ce_style', type=float, default=0., help="lambda ce loss with stylized images")
    arg_parser.add_argument('--lambda_kd_style', type=float, default=0., help="lambda kd loss with stylized images")
    arg_parser.add_argument('--label_inpainting', type=str2bool, default=False, help='Whether to inpaint GT')
    arg_parser.add_argument('--label_inpainting_type', type=str, choices=['old','new'], default='old', help='How to inpaint the GT, ie using pseudo labels from new or old style imgs')
    arg_parser.add_argument('--label_inpainting_aux_loss', type=str2bool, default=False, help='Whether to use CE unb loss on pixels ignored by pseudo-labeling')
    arg_parser.add_argument('--ce_curr_style', type=str2bool, default=False, help='Whether to use CE on current style')
    arg_parser.add_argument('--lambda_ce_curr_style', type=float, default=None, help="Loss weight to current-style CE, leave to None to replace original CE with current-style one")

    # Misc
    arg_parser.add_argument('--gradient_clip', type=float, default=None, help="leave to None for disabling")
    arg_parser.add_argument('--gradient_check', type=str2bool, default=False, help="check if gradients go to nan")

    ### Setup related args ###
    arg_parser.add_argument('--class_set', type=str, default='city19', choices=['city19','mapillary','idd18','synthia16','shift'], help="class set")
    arg_parser.add_argument('--incr_class_split', type=str, default='CIL', choices=['standard','CIL','doubleCIL','reverseCIL'], help="type of class split")
    arg_parser.add_argument('--incr_data_split', type=str, default='Cityscapes-bdd100k-IDD', help="type of dataset split")
    arg_parser.add_argument('--single_data_split', type=str2str_none_num, default=None, help="to train only on single dataset partition of selected data split")
    arg_parser.add_argument('--use_joint_data', type=str2bool, default=False, help="to use all domains at once")
    arg_parser.add_argument('--incr_val_data', type=str, default='before', choices=['all','before','current'], help="which domains to validate on")
    arg_parser.add_argument('--encoder_weights_from_past', type=str2bool, default=True, help="if to load encoder weights from past step or not")

    ### Path related args ###
    arg_parser.add_argument('--data_root_path', type=str, default='datasets', help="the path to datasets")
    arg_parser.add_argument('--list_root_path', type=str, default='datasets', help="the path to split files")
    arg_parser.add_argument('--checkpoint_dir', default="./log/CBI_full", help="the path to ckpt file")

    ### Model related args ###
    arg_parser.add_argument('--model', default='deeplabv2-resnet101', help="model to use, '{}-{}'.format(segmenter,backbone) format to be specified")
    arg_parser.add_argument('--imagenet_pretrained', type=str2bool, default=True, help="whether to apply imagenet pretrained weights")
    arg_parser.add_argument('--freeze_encoder', type=str2bool, default=False, help="whether to freeze encoder from step>0")
    arg_parser.add_argument('--pretrained_ckpt_file', type=str, default=None, help="whether to apply pretrained checkpoint")
    arg_parser.add_argument('--show_num_images', type=int, default=1, help="show how many images during validate")
    arg_parser.add_argument('--loss_log_step_interval', type=int, default=100, help="interval for logging")

    ### Dataset related args ###
    # argparse also runs string defaults through `type`, so ',512' and '512,512' are parsed like user input.
    arg_parser.add_argument('--resize_to', default=',512', type=str2list(int), help='w,h of resized image')
    arg_parser.add_argument('--random_resize', default=None, type=str2str_none_num, help='interval for random rescaling factor')
    arg_parser.add_argument('--crop_to', default='512,512', type=str2list(int), help='crop size of image')
    arg_parser.add_argument('--val_at_orig_size', type=str2bool, default=True, help="whether or not to evaluate at original input size")
    arg_parser.add_argument('--data_loader_workers', default=16, type=int, help='num_workers of Dataloader')
    # NOTE(review): pin_memory is declared as an int defaulting to 2 - confirm how the consumer interprets it.
    arg_parser.add_argument('--pin_memory', default=2, type=int, help='pin_memory of Dataloader')
    arg_parser.add_argument('--random_mirror', default=True, type=str2bool, help='add random_mirror')
    arg_parser.add_argument('--gaussian_blur', default=True, type=str2bool, help='add gaussian_blur')
    arg_parser.add_argument('--random_translation', default=False, type=str2bool, help='add random_translation')
    arg_parser.add_argument('--numpy_transform', default=True, type=str2bool, help='image transform with numpy style')

    ### Training related args ###
    arg_parser.add_argument('--seed', default=12345, type=int, help='random seed')
    arg_parser.add_argument('--device', type=str, default='0', help="'cpu' if on cpu, else integer number to set gpu number")
    arg_parser.add_argument('--batch_size', default=1, type=int, help='input batch size')
    arg_parser.add_argument('--optim', default="SGD", type=str, help='optimizer')
    arg_parser.add_argument('--momentum', type=float, default=0.9)
    arg_parser.add_argument('--weight_decay', type=float, default=5e-4)
    arg_parser.add_argument('--lr', type=float, default=2.5e-4, help="init learning rate")
    arg_parser.add_argument('--lr_incremental', type=float, default=2.5e-4, help="init learning rate for incremental step > 0")
    arg_parser.add_argument('--poly_power', type=float, default=0.9, help="poly_power")
    arg_parser.add_argument('--max_iter_poly_decay', type=int, default=0, help="total number of steps of lr polynomial decay; if < 0, lr is kept fixed; if == 0 equal to total training steps")
    arg_parser.add_argument('--iter_max', type=int, default=None, help="the number of training iteration")
    arg_parser.add_argument('--epochs', type=int, default=None, help="the number of training epochs")
    arg_parser.add_argument('--epochs_incremental', type=str2list(int), default=None, help="the number of training epochs for incremental step > 0")
    arg_parser.add_argument('--val_interval', type=int, default=None, help="interval in EPOCHS between subsequent validations, set to None to validate at the end of each epoch")
    arg_parser.add_argument('--val_all_data_interval', type=int, default=None, help="interval in EPOCHS between subsequent validations on all selected datasets (see 'incr_val_data') within the incremental step (val at end always done)")
    arg_parser.add_argument('--dataset_dg_test', type=str, default=None, help="dataset to use for domain generalization testing")
    arg_parser.add_argument('--test_dg_data_interval', type=int, default=None, help="interval in EPOCHS between subsequent validations on dg dataset (see 'dataset_dg_test') within the incremental step (val at end always done)")

    return arg_parser

def init_args(args):
    """
    Prepare a run from the parsed arguments: make sure the checkpoint folder exists,
    configure the root logger level, export the wandb key, apply the debug overrides
    and seed the random number generators.

    :param args: input args (modified in place)
    :return: updated args and the logger
    :raises SystemExit: if the checkpoint folder is missing and cannot be created
                        because its parent folder does not exist (non-zero exit status)
    :raises ValueError: if ``args.use_wandb`` is set but ``args.wandb_key`` is None
    """
    # The run always starts from incremental step 0, whether or not the folder already exists.
    args.start_incr_step = 0
    if not os.path.exists(args.checkpoint_dir):
        try:
            # os.mkdir (not makedirs) on purpose: a missing parent is reported, not created.
            os.mkdir(args.checkpoint_dir)
        except FileNotFoundError:
            # SystemExit with a message terminates with a failure status, unlike a bare exit().
            raise SystemExit('Missing parent folder in path:  {}'.format(args.checkpoint_dir)) from None

    # logger configure
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if args.use_wandb:
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if args.wandb_key is None:
            raise ValueError('--wandb_key must be provided when --use_wandb is enabled')
        os.environ["WANDB_API_KEY"] = args.wandb_key

    if args.debug:
        # Debug runs are shortened to a single epoch per step.
        args.epochs = args.epochs_incremental = 1

    #set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = not args.deterministic  # False for reproducibility (if True, the CuDNN library will benchmark several algorithms and pick that which it found to be fastest)
    torch.backends.cudnn.deterministic = args.deterministic  # if True, will only allow those CuDNN algorithms that are (believed to be) deterministic

    return args, logger

def check_termination(args, logger):
    """
    Reset the checkpoint folder when training restarts from step 0, then attach
    file and console handlers to the logger.

    When ``args.start_incr_step == 0`` the checkpoint folder is deleted and recreated
    empty. In every case a ``train_log.txt`` file handler (inside the checkpoint folder)
    and a stream handler are added to ``logger``; calling this function more than once
    on the same logger therefore adds the handlers again.

    :param args: input args (reads ``start_incr_step`` and ``checkpoint_dir``)
    :param logger: logger object
    :raises SystemExit: if the checkpoint folder cannot be recreated because its
                        parent folder does not exist (non-zero exit status)
    """

    # If restart training from step 0, clear the ckpt folder from old files
    if args.start_incr_step == 0:
        # Removal is best-effort; leftovers are detected when mkdir finds the folder still there.
        shutil.rmtree(args.checkpoint_dir, ignore_errors=True)
        try:
            os.mkdir(args.checkpoint_dir)
        except FileExistsError:
            print('WARNING: Files {} have not been removed'.format(glob.glob('{}/*'.format(args.checkpoint_dir))))
        except FileNotFoundError:
            # SystemExit with a message terminates with a failure status, unlike a bare exit().
            raise SystemExit('Missing parent folder in path:  {}'.format(args.checkpoint_dir)) from None

    # Set logging params
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh = logging.FileHandler(os.path.join(args.checkpoint_dir, 'train_log.txt'))
    ch = logging.StreamHandler()
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)