from . import *

from datasets.dataloader import Dataset
from datasets.dataset_utils import *
from datasets.dataset_config import DatasetConfig

# NOTE(review): `ImageFile` is presumably PIL's ImageFile module, brought in by
# one of the star imports above. With this flag set, Pillow loads truncated
# image files instead of raising. It is a module-level flag set at import time,
# so it presumably applies to every image opened in this process.
ImageFile.LOAD_TRUNCATED_IMAGES = True

class BDD100k_Dataset(Dataset):
    """BDD100k semantic-segmentation dataset.

    Items come from a split list file ``<list_path>/<split>.txt``. Each
    non-blank line holds two space-separated paths, relative to ``data_path``:
    the RGB image and its label mask. If the list file does not exist, it is
    generated from the directory layout with :meth:`create_split_file`.
    """

    Label = namedtuple('Label', [

        'name',  # The identifier of this label, e.g. 'car', 'person', ... .
        # We use them to uniquely name a class

        'id',  # An integer ID that is associated with this label.
        # The IDs are used to represent the label in ground truth images
        # An ID of -1 means that this label does not have an ID and thus
        # is ignored when creating ground truth images (e.g. license plate).
        # Do not modify these IDs, since exactly these IDs are expected by the
        # evaluation server.

        'trainId',
        # Feel free to modify these IDs as suitable for your method. Then create
        # ground truth images with train IDs, using the tools provided in the
        # 'preparation' folder. However, make sure to validate or submit results
        # to our evaluation server using the regular IDs above!
        # For trainIds, multiple labels might have the same ID. Then, these labels
        # are mapped to the same class in the ground truth images. For the inverse
        # mapping, we use the label that is defined first in the list below.
        # For example, mapping all void-type classes to the same ID in training,
        # might make sense for some approaches.
        # Max value is 255!

        'category',  # The name of the category that this label belongs to

        'categoryId',
        # The ID of this category. Used to create ground truth images
        # on category level.

        'hasInstances',
        # Whether this label distinguishes between single instances or not

        'ignoreInEval',
        # Whether pixels having this class as ground truth label are ignored
        # during evaluations or not

        'color',  # The color of this label
    ])

    original_labels = [
        #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
        Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'dynamic'              ,  1 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
        Label(  'ego vehicle'          ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'ground'               ,  3 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
        Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        Label(  'parking'              ,  5 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
        Label(  'rail track'           ,  6 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
        Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
        Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
        Label(  'bridge'               ,  9 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
        Label(  'building'             , 10 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
        Label(  'fence'                , 11 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
        Label(  'garage'               , 12 ,      255 , 'construction'    , 2       , False        , True         , (180,100,180) ),
        Label(  'guard rail'           , 13 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
        Label(  'tunnel'               , 14 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
        Label(  'wall'                 , 15 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
        Label(  'banner'               , 16 ,      255 , 'object'          , 3       , False        , True         , (250,170,100) ),
        Label(  'billboard'            , 17 ,      255 , 'object'          , 3       , False        , True         , (220,220,250) ),
        Label(  'lane divider'         , 18 ,      255 , 'object'          , 3       , False        , True         , (255, 165, 0) ),
        Label(  'parking sign'         , 19 ,      255 , 'object'          , 3       , False        , False        , (220, 20, 60) ),
        Label(  'pole'                 , 20 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
        Label(  'polegroup'            , 21 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
        Label(  'street light'         , 22 ,      255 , 'object'          , 3       , False        , True         , (220,220,100) ),
        Label(  'traffic cone'         , 23 ,      255 , 'object'          , 3       , False        , True         , (255, 70,  0) ),
        Label(  'traffic device'       , 24 ,      255 , 'object'          , 3       , False        , True         , (220,220,220) ),
        Label(  'traffic light'        , 25 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
        Label(  'traffic sign'         , 26 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
        Label(  'traffic sign frame'   , 27 ,      255 , 'object'          , 3       , False        , True         , (250,170,250) ),
        Label(  'terrain'              , 28 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
        Label(  'vegetation'           , 29 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
        Label(  'sky'                  , 30 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
        Label(  'person'               , 31 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
        Label(  'rider'                , 32 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
        Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
        Label(  'bus'                  , 34 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
        Label(  'car'                  , 35 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
        Label(  'caravan'              , 36 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
        Label(  'motorcycle'           , 37 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
        Label(  'trailer'              , 38 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
        Label(  'train'                , 39 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
        Label(  'truck'                , 40 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    ]

    # Original label id -> label name, built from the table above.
    id2name_original = {lb.id: lb.name for lb in original_labels}

    def __init__(self,
                 args,
                 data_path='./datasets/bdd100k',
                 list_path='./datasets/bdd100k/standard_split',
                 split='train',  # train->training, test->testing or val->validation
                 incr_class_split = 'standard',
                 incr_class_step = 0,
                 crop_to=None,
                 resize_to=None,
                 make_val_size_div=False,
                 random_resize=None,
                 training=True,
                 resize_mask=True,
                 class_set='city19'):
        """Load the item list for *split* and set up the label mapping.

        Every constructor argument except ``self`` is forwarded by name to the
        base ``Dataset`` constructor. *incr_class_split*, *class_set*,
        *incr_class_step* and *training* are also passed to
        ``DatasetConfig.get_id2iid``, whose result is stored as
        ``self.id2trainid``. ``args.debug`` enables debug prints. When it is
        falsy and ``split == 'train'``, ``self._remove_unlabeled_images()`` is
        called.

        Raises:
            FileNotFoundError: if the split list file is missing and could not
                be generated from the image directory tree.
        """
        # Forward this constructor's arguments by name to the base class.
        # inspect.getargspec was removed in Python 3.11; getfullargspec returns
        # the same ``args`` list (including 'self', which is excluded here).
        init_arg_names = set(inspect.getfullargspec(self.__init__).args) - {'self'}
        kwargs = {k: v for k, v in locals().items() if k in init_arg_names}
        self.dataset_name = 'bdd100k'
        super().__init__(**kwargs)

        item_list_filepath = os.path.join(self.list_path, self.split + ".txt")
        if not os.path.isfile(item_list_filepath):
            print(f'WARNING: split file {item_list_filepath} not found, creating it...')
            # Generate the list under the same root / list dir / split that are
            # read below and in __getitem__, so the written paths resolve.
            images_root = os.path.join(self.data_path, 'images/10k')
            self.create_split_file(root=self.data_path,
                                   path_to_imgs=images_root,
                                   path_to_labels=os.path.join(self.data_path, 'labels/sem_seg/masks'),
                                   valid_splits=(self.split,),
                                   list_path=self.list_path)
            if not os.path.isfile(item_list_filepath):
                raise FileNotFoundError(
                    f'split file {item_list_filepath} could not be created: '
                    f'no "{self.split}" directory found under {images_root}')

        # Keep raw lines (as __getitem__ expects), but drop blank ones, which
        # would otherwise fail to unpack into (image, label) in __getitem__.
        with open(item_list_filepath) as list_file:
            self.items = [line for line in list_file if line.strip()]

        self.id2trainid = DatasetConfig.get_id2iid(dataset=self.dataset_name, incr_class_split=incr_class_split, class_set=class_set, incr_class_step=incr_class_step, training=training)
        self.image_mean = np.array(DatasetConfig.mean_bgr[self.dataset_name], dtype=np.float32)

        # One-shot flag: __getitem__ prints debug info for the first item only.
        self.print_for_debug = bool(args.debug)

        if args.debug:
            first_item = self.items[0] if self.items else None
            print(f'DEBUG: {self.dataset_name} {self.split} -> item_list_filepath: {item_list_filepath} , first item: {first_item}')
            print(f'{len(self.items)} num images in {self.dataset_name} {self.split} set have been loaded')

        if not args.debug and split == 'train':
            self._remove_unlabeled_images()

    def __getitem__(self, item):
        """Return ``(image, gt_image, image_id, (w, h))`` for index *item*.

        ``(w, h)`` is the size of the image as loaded, before any transform.
        ``image_id`` is the image file name without directory or extension.
        The train sync transform is applied when ``split == 'train'`` and
        ``self.training`` is set; otherwise the val transform is applied. When
        ``self.doing_label_check`` is set, only the untransformed label image
        is returned.
        """
        # split() without arguments also tolerates '\r\n' endings and
        # repeated separators, unlike strip('\n').split(' ').
        id_img, id_gt = self.items[item].split()

        # image
        image_path = os.path.join(self.data_path, id_img)
        image = Image.open(image_path).convert("RGB")
        w, h = image.size
        # gt
        gt_image_path = os.path.join(self.data_path, id_gt)
        gt_image = Image.open(gt_image_path)
        if self.doing_label_check:
            return gt_image

        if self.split == "train" and self.training:
            image, gt_image = self._train_sync_transform(image, gt_image)
        else:
            image, gt_image = self._val_sync_transform(image, gt_image, resize_mask=self.resize_mask)

        if self.print_for_debug:
            self.print_for_debug = False
            print(f'DEBUG: {self.dataset_name} {self.split} -> item: {item}, image_path: {image_path}')
            print(f'DEBUG: {self.dataset_name} {self.split} -> item: {item}, gt_path: {gt_image_path}')
            # NOTE(review): assumes the transforms return a torch tensor mask.
            print(f'DEBUG: GT values in {gt_image[gt_image>=0].unique().sort()[0].tolist()}')
        # basename handles both '/' and the OS separator that
        # create_split_file writes (os.path.relpath uses os.sep).
        image_id = os.path.splitext(os.path.basename(id_img))[0]
        return image, gt_image, image_id, (w, h)

    @staticmethod
    def create_split_file(root, path_to_imgs, path_to_labels, valid_splits=('train', 'val'), list_path=None):
        """Write ``<split>.txt`` list files by walking the image directory tree.

        For every directory under *path_to_imgs* whose name is in
        *valid_splits*, writes ``<list_path>/<dir name>.txt``. Each line is
        ``<image path> <label path>``, both relative to *root*. The label path
        mirrors the image's position under *path_to_labels*, with the file
        extension replaced by ``.png``. Lines are sorted by file name.

        Args:
            root: directory that the written paths are made relative to.
            path_to_imgs: root of the image tree, holding one folder per split.
            path_to_labels: root of the label tree, mirroring *path_to_imgs*.
            valid_splits: names of the split folders to list.
            list_path: output directory. Defaults to ``<root>/splits`` and is
                created if missing.
        """
        if list_path is None:
            list_path = os.path.join(root, 'splits')

        for dir_path, _subdirs, files in os.walk(path_to_imgs):
            split = os.path.basename(dir_path)
            if split not in valid_splits:
                continue

            os.makedirs(list_path, exist_ok=True)

            with open(os.path.join(list_path, f'{split}.txt'), 'w') as f:
                # os.walk yields files in arbitrary order; sort for reproducibility.
                for name in sorted(files):
                    img_full = os.path.join(dir_path, name)
                    # Swap only the extension. A blanket str.replace('jpg', 'png')
                    # would also rewrite 'jpg' inside directory or file names.
                    label_rel = os.path.splitext(os.path.relpath(img_full, path_to_imgs))[0] + '.png'
                    label_full = os.path.join(path_to_labels, label_rel)
                    f.write(f'{os.path.relpath(img_full, root)} {os.path.relpath(label_full, root)}\n')