# From https://github.com/Eromera/erfnet_pytorch/blob/master/train/erfnet.py
#######################

import torch
import torch.nn as nn
import torch.nn.functional as F

class DownsamplerBlock(nn.Module):
    """Downsampling block that concatenates a strided conv with a max-pool.

    The 3x3 stride-2 convolution produces ``noutput - ninput`` channels and
    the 2x2 stride-2 max-pool passes the ``ninput`` input channels through,
    so the concatenation has ``noutput`` channels. Both branches yield the
    same spatial size only when the input height and width are even.

    Args:
        ninput: number of input channels.
        noutput: number of output channels (must exceed ``ninput``).
    """

    def __init__(self, ninput, noutput):
        super().__init__()

        # The pooled branch already contributes `ninput` channels; the
        # convolution only has to supply the remainder.
        conv_channels = noutput - ninput
        self.conv = nn.Conv2d(
            ninput,
            conv_channels,
            kernel_size=(3, 3),
            stride=2,
            padding=1,
            bias=True,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        """Return ``relu(bn(cat(conv(input), pool(input))))``."""
        conv_branch = self.conv(input)
        pool_branch = self.pool(input)
        # Stack the two branches along the channel axis.
        merged = torch.cat((conv_branch, pool_branch), dim=1)
        return F.relu(self.bn(merged))


class non_bottleneck_1d(nn.Module):
    """Residual block made of two factorized (3x1 followed by 1x3) conv pairs.

    The first pair is undilated; the second pair uses dilation ``dilated``
    with matching padding, so the spatial size of the input is preserved and
    the result can be added back to the input.

    Args:
        chann: number of input (and output) channels.
        dropprob: probability passed to ``nn.Dropout2d``; dropout is skipped
            in ``forward`` when it is 0.
        dilated: dilation rate of the second convolution pair.
    """

    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        # First factorized pair (no dilation).
        self.conv3x1_1 = nn.Conv2d(
            chann, chann, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=True
        )
        self.conv1x3_1 = nn.Conv2d(
            chann, chann, kernel_size=(1, 3), stride=1, padding=(0, 1), bias=True
        )
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # Second factorized pair: padding equals the dilation so that each
        # 3-tap kernel keeps the feature map size unchanged.
        self.conv3x1_2 = nn.Conv2d(
            chann,
            chann,
            kernel_size=(3, 1),
            stride=1,
            padding=(dilated, 0),
            dilation=(dilated, 1),
            bias=True,
        )
        self.conv1x3_2 = nn.Conv2d(
            chann,
            chann,
            kernel_size=(1, 3),
            stride=1,
            padding=(0, dilated),
            dilation=(1, dilated),
            bias=True,
        )
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        """Apply both conv pairs, optional dropout, then the residual sum."""
        branch = F.relu(self.conv3x1_1(input))
        branch = F.relu(self.bn1(self.conv1x3_1(branch)))

        branch = F.relu(self.conv3x1_2(branch))
        branch = self.bn2(self.conv1x3_2(branch))

        # Skip the dropout call entirely when it was configured with p == 0.
        if self.dropout.p != 0:
            branch = self.dropout(branch)

        # Identity shortcut (residual connection).
        return F.relu(branch + input)


class Encoder(nn.Module):
    """ERFNet encoder: three downsampling stages with residual 1D blocks.

    The input goes through ``initial_block`` (3 -> 16 channels), then a
    64-channel stage with five undilated residual blocks, then a 128-channel
    stage with two rounds of residual blocks dilated by 2, 4, 8 and 16.

    Args:
        num_classes: number of output channels of ``output_conv``, the 1x1
            convolution used only when ``forward`` is called with
            ``predict=True``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.initial_block = DownsamplerBlock(3, 16)

        # 64-channel stage: one downsampler followed by 5 residual blocks.
        stages = [DownsamplerBlock(16, 64)]
        stages.extend(non_bottleneck_1d(64, 0.03, 1) for _ in range(5))

        # 128-channel stage: one downsampler, then the dilation schedule
        # (2, 4, 8, 16) repeated twice.
        stages.append(DownsamplerBlock(64, 128))
        for _ in range(2):
            stages.extend(non_bottleneck_1d(128, 0.3, rate) for rate in (2, 4, 8, 16))

        self.layers = nn.ModuleList(stages)

        # Only used in encoder-only (predict) mode.
        self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)

    def forward(self, input, predict=False):
        """Return encoder features, or per-class maps when ``predict`` is set."""
        output = self.initial_block(input)
        for layer in self.layers:
            output = layer(output)
        return self.output_conv(output) if predict else output


class UpsamplerBlock(nn.Module):
    """Transposed-convolution upsampling block followed by BN and ReLU.

    With kernel 3, stride 2, padding 1 and output_padding 1 the transposed
    convolution doubles the spatial size of its input.

    Args:
        ninput: number of input channels.
        noutput: number of output channels.
    """

    def __init__(self, ninput, noutput):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            ninput,
            noutput,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        """Return ``relu(bn(conv(input)))``."""
        upsampled = self.conv(input)
        return F.relu(self.bn(upsampled))

class Decoder(nn.Module):
    """ERFNet decoder: two upsampling stages plus a final transposed conv.

    Each stage is an ``UpsamplerBlock`` followed by two undilated,
    dropout-free residual blocks (128 -> 64, then 64 -> 16 channels).
    ``output_conv`` is a 2x2 stride-2 transposed convolution mapping the 16
    channels to ``num_classes``.

    Args:
        num_classes: number of output channels of ``output_conv``.
    """

    def __init__(self, num_classes):
        super().__init__()

        stages = []
        for in_channels, out_channels in ((128, 64), (64, 16)):
            stages.append(UpsamplerBlock(in_channels, out_channels))
            stages.append(non_bottleneck_1d(out_channels, 0, 1))
            stages.append(non_bottleneck_1d(out_channels, 0, 1))
        self.layers = nn.ModuleList(stages)

        self.output_conv = nn.ConvTranspose2d(
            16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True
        )

    def forward(self, input):
        """Run every stage in order and apply ``output_conv`` to the result."""
        hidden = input
        for stage in self.layers:
            hidden = stage(hidden)
        return self.output_conv(hidden)

#ERFNet
class ERFNet(nn.Module):
    """Encoder/decoder segmentation network with checkpoint-loading helpers.

    Args:
        num_classes: passed to ``Decoder`` (and to ``Encoder`` when one is
            built here).
        encoder: optional already-built (e.g. pretrained) encoder module; a
            fresh ``Encoder(num_classes)`` is created when it is ``None``.
    """

    def __init__(self, num_classes, encoder=None):
        super().__init__()
        self.encoder = Encoder(num_classes) if encoder is None else encoder
        self.decoder = Decoder(num_classes)

    def __str__(self):
        return 'ERFNet'

    def forward(self, x):
        """Return a dict with the encoder ``'features'`` and decoder ``'out'``."""
        features = self.encoder(x)
        return {'features': features, 'out': self.decoder(features)}

    def optim_parameters(self, lr, logger_fn=None):
        """Return one optimizer parameter group per sub-network, both at ``lr``."""
        if logger_fn:
            logger_fn(f'Backbone lr set to {lr}, Classifier lr set to {lr}')
        return [
            {'params': subnet.parameters(), 'lr': lr}
            for subnet in (self.encoder, self.decoder)
        ]

    def freeze(self, network):
        """Disable gradients for every parameter whose name contains ``network``.

        ``network`` must be ``'encoder'`` or ``'decoder'``.
        """
        assert network in ('encoder', 'decoder')
        selected = (param for name, param in self.named_parameters() if network in name)
        for param in selected:
            param.requires_grad_(False)

    def load_layers(self, checkpoint, include_layers=(), exclude_layers=(), network=''):
        """Load selected checkpoint entries on top of the current state dict.

        An entry is loaded when its full name is not in ``exclude_layers`` and
        at least one dot-separated token of its name is in ``include_layers``.
        Every other entry is reported in a printed warning. ``network`` is
        only used in that warning message.
        """
        merged = self.state_dict().copy()
        wanted_tokens = set(include_layers)
        excluded_names = set(exclude_layers)
        skipped = []
        for key, tensor in checkpoint.items():
            # Exclusion matches the full key; inclusion matches any one token.
            if key not in excluded_names and wanted_tokens.intersection(key.split('.')):
                merged[key] = tensor
            else:
                skipped.append(key)
        if skipped:
            print(f'WARNING (loading {network}): {skipped} have not been loaded into current model')
        self.load_state_dict(merged)

    def load_network(self, checkpoint, network, load_partial_decoder=True):
        """Load ``'all'``, ``'encoder'`` or ``'decoder'`` weights from ``checkpoint``.

        For ``'encoder'`` with ``load_partial_decoder`` set, the decoder is
        loaded as well, except that both ``output_conv`` layers are excluded
        from the regular load and the decoder's one is then filled in by
        ``load_last_layer_weights``.

        Raises:
            ValueError: if ``network`` is not one of the three accepted names.
        """
        exclude_layers = []
        load_last_layer = False
        if network == 'all':
            layers = ['encoder', 'decoder']
        elif network == 'encoder':
            if load_partial_decoder:
                layers = ['encoder', 'decoder']
                exclude_layers = [
                    'decoder.output_conv.weight',
                    'decoder.output_conv.bias',
                    'encoder.output_conv.weight',
                    'encoder.output_conv.bias',
                ]
                load_last_layer = True
            else:
                layers = ['encoder']
        elif network == 'decoder':
            layers = ['decoder']
        else:
            raise ValueError(f'Cannot load parameters from {network}')

        self.load_layers(checkpoint, include_layers=layers, exclude_layers=exclude_layers, network=network)
        if load_last_layer:
            self.load_last_layer_weights(checkpoint)

    def load_last_layer_weights(self, checkpoint):
        """Copy the checkpoint's decoder ``output_conv`` into a leading slice.

        1-D tensors (bias) fill the first ``c`` entries; other tensors
        (weight) fill the first ``c`` entries along dimension 1, where ``c``
        is the checkpoint tensor's size along that dimension.
        """
        head = self.decoder.output_conv
        targets = {
            'decoder.output_conv.weight': head.weight,
            'decoder.output_conv.bias': head.bias,
        }
        for key, source in checkpoint.items():
            if key not in targets:
                continue
            target = targets[key]
            if source.dim() == 1:  # bias
                target[:source.size(0)].data.copy_(source)
            else:  # weight
                target[:, :source.size(1)].data.copy_(source)


def erfnet(num_classes=19, pretrained=False, logger_fn=None, device='cuda', **kwargs):
    """Build an ``ERFNet`` and optionally load the pretrained encoder weights.

    Args:
        num_classes: number of classes passed to ``ERFNet``.
        pretrained: when true, fetch the encoder checkpoint (unless it is
            already cached on disk) and load it non-strictly into the model.
        logger_fn: optional callable receiving a single log message.
        device: ``map_location`` used when reading the checkpoint. The model
            itself is not moved to this device here.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        The constructed ``ERFNet`` instance.
    """
    if logger_fn:
        logger_fn('Loading ERFNet')
    model = ERFNet(num_classes)
    if not pretrained:
        return model

    import os
    import urllib.request

    url = 'https://github.com/Eromera/erfnet_pytorch/raw/master/trained_models/erfnet_encoder_pretrained.pth.tar'
    path = './pretrained_model/erfnet_encoder.pth.tar'
    if not os.path.isfile(path):
        # urlretrieve does not create missing parent directories, so a first
        # run without ./pretrained_model would fail with FileNotFoundError.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        urllib.request.urlretrieve(url, path)

    # NOTE(review): torch.load unpickles a file fetched over the network;
    # only keep this if the source URL is trusted.
    state_dict = torch.load(path, map_location=device)['state_dict']
    # Drop the wrapper prefix so the keys line up with this model's names.
    ckpt = {k.replace('module.features.', ''): v for k, v in state_dict.items()}

    check = model.load_state_dict(ckpt, strict=False)
    # Decoder weights are not expected from this checkpoint, so they are not
    # reported as missing.
    missing = [m for m in check.missing_keys if 'decoder' not in m]
    unexpected = check.unexpected_keys

    # Only show the banner when there is actually something to report.
    if missing or unexpected:
        print('!!! WARNING !!!')
        if missing:
            print(f'Missing keys: {missing}')
        if unexpected:
            print(f'Unexpected keys: {unexpected}')
    return model