from train import *
import os, sys
import argparse
from utils import add_train_args, init_args


class Test(Trainer):
    """
    Evaluation-only driver built on `Trainer`.

    For each incremental step it rebuilds the model, loads the checkpoint saved for that
    step (`{checkpoint_dir}/model_step{step}.pth`) and validates it on the configured
    test splits, collecting results in `self.result_dict[step][dataset_name]`.
    """

    def __init__(self, *args, **kwargs):
        # Forward every constructor argument unchanged to Trainer.
        super().__init__(*args, **kwargs)
        self.ckpt = None  # state_dict of the most recently loaded step checkpoint

    def model_init(self, step):
        """
        Prepare model, class bookkeeping and eval object for incremental step `step`,
        then load that step's checkpoint into the model.

        Extends `self.class_names` / `self.num_classes` with the classes introduced at
        `step`. When resuming at step > 0 with no class names yet, it first does the same
        for every previous step. It then builds a fresh model, updates (or creates)
        `self.Eval`, loads `{checkpoint_dir}/model_step{step}.pth` and ensures
        `self.result_dict[step]` exists.

        :param step: index of the incremental step to prepare
        :raises RuntimeError: if the eval object must be rebuilt for step > 0 but the
            class names of the previous steps could not be recovered
        """
        def new_classes(s):
            # Non-ignored class names of step `s` that are not already known
            # (an 'unknown' class may already be in the list).
            return [k for k, v in self.name2iid[s].items()
                    if v != self.args.ignore_label and k not in self.class_names]

        # Update class names and total number
        classes_per_old_step = []
        if step > 0 and not self.class_names:
            # Resuming from step > 0: recover class names and number from previous steps.
            for s in range(step):
                old_classes = new_classes(s)
                self.num_classes += len(old_classes)
                self.class_names += old_classes
                classes_per_old_step.append(old_classes)
        # Update for the current step
        class_list = new_classes(step)
        self.num_classes += len(class_list)
        self.class_names += class_list

        # Model init
        self.logger.info('  Getting model...')
        logger_fn = lambda input_str: self.logger.info(f'    {input_str}')
        self.model, params = get_model(self.args, self.num_classes, imagenet_pretrained=False, logger_fn=logger_fn, device=self.device)
        self.logger.info('  Loading model on gpu...')
        self.model.to(self.device)

        # Initialize or update eval obj
        self.logger.info('  Setting eval framework...')
        if step > 0 and not self.Eval:
            # Resuming from step > 0: create the eval object and register the classes of
            # every previous step, in order.
            if not classes_per_old_step:
                raise RuntimeError(
                    f'Cannot rebuild the eval object at step {step}: '
                    'class names of the previous steps are unknown')
            self.Eval = Eval(classes_per_old_step[0])
            for old_classes in classes_per_old_step[1:]:
                self.Eval.update(old_classes)
        if self.Eval:
            self.Eval.update(class_list)
        else:
            self.Eval = Eval(class_list)

        self.logger.info('  Loading local checkpoint...')
        # map_location lets a checkpoint saved on another device (e.g. GPU) load on the
        # device used for this run.
        self.ckpt = torch.load(f'{self.args.checkpoint_dir}/model_step{step}.pth',
                               map_location=self.device)['state_dict']

        self.model.load_network(self.ckpt, network='all')
        self.result_dict.setdefault(step, {})

    def test_init(self):
        """
        Run the evaluation loop over the incremental steps.

        Steps run from `args.start_incr_step` up to `num_incr_steps - 1`, stopping after
        `args.step_to_stop` when it is set. For each step the model and checkpoint are
        loaded via `model_init` and validated on the selected test splits. Results are
        stored in `self.result_dict[step][dataset_name]`. The summary writer is closed even
        if a step fails.
        """
        # Route log messages through print so they go to sys.stdout (which the script
        # entry point redirects to test_log.txt).
        self.logger.info = print

        # display args details
        self.logger.info("Global configuration as follows:")
        for key, val in vars(self.args).items():
            self.logger.info("{:22} {}".format(key, val))

        # setup writer
        self.writer = SummaryWriter(self.args.checkpoint_dir)
        try:
            class_split = [['{}:{}'.format(v, k) for k, v in d.items() if v != -1] for d in self.name2iid]
            for step in range(self.num_incr_steps):
                self.logger.info('Step{} ->'.format(step) + ' Data split: {}'.format(self.data_splits[step].split('/')[0] + '-' + self.data_splits[step].split('/')[-1]))
                self.logger.info('{:8}'.format('') + ' Class split: ' + ' '.join([' {}'] * len(class_split[step])).format(*class_split[step]))

            # choose cuda
            if self.cuda:
                current_device = torch.cuda.current_device()
                self.logger.info("This model will run on {}".format(torch.cuda.get_device_name(current_device)))
            else:
                self.logger.info("This model will run on CPU")

            for step in range(self.args.start_incr_step, self.num_incr_steps):

                # set seed at the beginning of each incremental step
                s = step * self.args.seed
                set_seed(s)

                if self.args.step_to_stop is not None and step > self.args.step_to_stop:
                    break

                self.logger.info('')
                self.logger.info('### INCREMENTAL STEP {} ###'.format(step))
                self.current_incr_step = step

                self.logger.info('Model initialization...')
                self.model_init(step)
                self.logger.info('Done')

                if self.args.incr_val_data in ['before', 'all']:
                    if self.args.incr_val_data == 'before':
                        # test sets of all steps seen so far
                        test_steps = list(range(step + 1))
                    else:
                        # test sets of every step in the split
                        test_steps = list(range(len(list_paths[self.args.incr_data_split])))
                    if self.args.use_joint_data:
                        test_steps = [0] + test_steps  # we want to validate to the first dataset as well
                    for vs in test_steps:
                        self.dataset_test_init(step=vs)
                        self.result_dict.setdefault(step, {})[self.val_dataloader.dataset.dataset_name] = self.validate(log_wandb_imgs=True)
                if self.args.dataset_dg_test is not None:
                    self.dataset_dg_test(step)
                    self.result_dict.setdefault(step, {})[self.val_dataloader.dataset.dataset_name] = self.validate(log_wandb_imgs=True)

                if self.args.incr_class_split not in ('standard',):
                    self.compute_and_print_global_metrics(step)

                if self.args.use_wandb:
                    self.wandb_logger.step_offset += self.current_iter
        finally:
            self.writer.close()

    def dataset_test_init(self, step):
        """
        Load the testing DATASET of the specified incremental step into `self.val_dataloader`.

        The CLASS SET is the one of the current incremental step (`self.current_incr_step`,
        or 0 when `self.single_class_step` is set).
        Also sets `self.data_gen`, `self.curr_data_split_val` and `self.validation_steps`.
        :param step: step index of incremental dataset for which to load the test dataset
        """
        self.data_kwargs['incr_class_step'] = self.current_incr_step if not self.single_class_step else 0
        if is_single_dataset_split(self.args.incr_data_split):
            data_split = f'{get_single_dataset_split(self.args.incr_data_split)}/{{}}'
        else:
            data_split = '{}/splits/standard'
        data_split = data_split.format(list_paths[self.args.incr_data_split][step])
        self.data_kwargs['list_path'] = os.path.join(self.args.list_root_path, data_split)

        # The first path component of the split names the dataset.
        dataset = fix_data_name(data_split.split('/')[0].lower())
        self.data_kwargs['data_path'] = os.path.join(self.args.data_root_path, data_paths[dataset])
        self.data_gen = self.available_datasets[dataset]
        self.data_kwargs['make_val_size_div'] = self.args.crop_to is not None and issubclass(self.data_gen, Mapillary_Dataset) and self.args.model in ('erfnet',)
        self.curr_data_split_val = self.data_splits[step].split('/')[0] + '-' + self.data_splits[step].split('/')[-1]

        test_data_gen = self.data_gen(self.args, split='val', training=False, resize_mask=not self.args.val_at_orig_size, **self.data_kwargs)
        self.val_dataloader = data.DataLoader(test_data_gen, shuffle=False, batch_size=1, **self.dataloader_kwargs)
        self.validation_steps = len(test_data_gen) + 1



if __name__ == '__main__':
    from contextlib import redirect_stdout

    # Run from the script's own directory so relative paths resolve consistently.
    file_os_dir = os.path.dirname(os.path.realpath(__file__))
    os.chdir(file_os_dir)

    arg_parser = argparse.ArgumentParser()
    arg_parser = add_train_args(arg_parser)

    args = arg_parser.parse_args()
    args, logger = init_args(args)

    filepath = f'{args.checkpoint_dir}/test_log.txt'

    # Everything printed during testing (Test.test_init rebinds logger.info to print)
    # goes to test_log.txt. The original stdout is restored and the file is closed even
    # if testing raises.
    with open(filepath, 'w') as f, redirect_stdout(f):
        agent = Test(args=args, logger=logger)
        agent.test_init()