from . import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Value passed as `affine=` to the BatchNorm2d layers built in this file
# (True gives each of those layers a learnable scale and shift).
affine_par = True

class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 (dilated) conv -> 1x1 conv.

    The stride is applied by the first 1x1 convolution; the 3x3 convolution
    always uses stride 1 and pads by ``dilation`` so it keeps the spatial size.
    The block outputs ``planes * expansion`` channels.

    Args:
        inplanes: Number of input channels.
        planes: Width of the two inner convolutions.
        stride: Stride of the first 1x1 convolution.
        dilation: Dilation (and padding) of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            residual branch; when ``None`` the input itself is the residual.
        bn_momentum: Momentum handed to every BatchNorm2d in the block. The
            default (0.1) is BatchNorm2d's own default.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1):
        super().__init__()
        out_planes = planes * self.expansion

        # Convolutions are created in the order conv1, conv2, conv3 so that
        # seeded initialisation stays reproducible.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par, momentum=bn_momentum)

        # padding == dilation keeps the 3x3 convolution size-preserving.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par, momentum=bn_momentum)

        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes, affine=affine_par, momentum=bn_momentum)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + residual)``."""
        residual = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # In-place add on the freshly computed branch output (never on x).
        out += residual
        return self.relu(out)

class Classifier_Module(nn.Module):
    """Classifier head made of parallel dilated 3x3 convolutions.

    One convolution (with bias) is built per ``(dilation, padding)`` pair;
    the head's output is the element-wise sum of all branch outputs.

    Args:
        inplanes: Number of input channels.
        dilation_series: Dilation of each branch.
        padding_series: Padding of each branch, paired with ``dilation_series``
            by ``zip`` (extra entries in the longer series are ignored).
        num_classes: Number of output channels of every branch.
    """

    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super(Classifier_Module, self).__init__()
        # Build every branch first, then re-draw the weights in a second pass.
        self.conv2d_list = nn.ModuleList(
            nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1,
                      padding=pad, dilation=rate, bias=True)
            for rate, pad in zip(dilation_series, padding_series)
        )

        for branch in self.conv2d_list:
            branch.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return the sum of all branch outputs for ``x``."""
        out = self.conv2d_list[0](x)
        # Remaining branches are accumulated in place onto the first output.
        for branch in list(self.conv2d_list)[1:]:
            out += branch(x)
        return out

class DeeplabResnet(nn.Module):
    """DeepLab-v2 style segmentation network on a dilated ResNet backbone.

    The stem (``conv1``/``bn1``) and ``layer1``-``layer4`` form the encoder;
    ``layer5`` is the classifier head built from ``Classifier_Module`` with
    dilations/paddings 6, 12, 18 and 24. ``forward`` returns the head's output
    bilinearly resized to the spatial size of the input.

    Args:
        block: Residual block class. It must expose ``expansion`` and accept
            ``(inplanes, planes, stride, dilation=..., downsample=...)``.
        layers: Number of blocks in each of the four residual stages.
        num_classes: Number of output channels of the classifier head.

    Note:
        The head is created with 2048 input channels, so ``block.expansion``
        is expected to be 4 (512 * 4).
    """

    def __init__(self, block, layers, num_classes):
        super().__init__()
        # Running channel count; read and updated by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        # The stem batchnorm's affine parameters are frozen.
        for param in self.bn1.parameters():
            param.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # The last two stages keep the resolution and use dilation instead.
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        self.layer5 = self._make_pred_layer(Classifier_Module, 2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)

        # Re-initialise every convolution and affine batchnorm in the network.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d) and m.affine:
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block receives ``stride`` and the projection shortcut;
        the remaining blocks use stride 1 and an identity shortcut.
        """
        out_planes = planes * block.expansion
        downsample = None
        # A projection shortcut is needed when the shape changes; the original
        # code also forces one for the dilated stages (dilation 2 or 4).
        if stride != 1 or self.inplanes != out_planes or dilation in (2, 4):
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes, affine=affine_par))
        stage = [block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*stage)

    def _make_pred_layer(self, block, inplanes, dilation_series, padding_series, num_classes):
        """Instantiate the classifier head ``block`` with the given settings."""
        return block(inplanes, dilation_series, padding_series, num_classes)

    def forward(self, x):
        """Return class scores resized to the spatial size of ``x``."""
        input_size = x.size()[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        return x

    def get_1x_lr_params_NOscale(self):
        """Yield the trainable encoder parameters, each exactly once.

        Covers ``conv1``, ``bn1`` and ``layer1``-``layer4`` (everything except
        the classifier head). Parameters whose ``requires_grad`` is False are
        skipped; within this class only the stem ``bn1`` is frozen.

        The previous implementation iterated ``modules()`` and then the
        recursive ``parameters()`` of every sub-module, which yielded the same
        parameter once per ancestor module and put duplicates into the
        optimizer's parameter group.
        """
        encoder = (self.conv1, self.bn1, self.layer1,
                   self.layer2, self.layer3, self.layer4)
        seen = set()
        for module in encoder:
            for param in module.parameters():
                if param.requires_grad and id(param) not in seen:
                    seen.add(id(param))
                    yield param

    def get_10x_lr_params(self):
        """Yield every parameter of the classifier head (``layer5``)."""
        yield from self.layer5.parameters()

    def optim_parameters(self, lr, logger_fn=None):
        """Return optimizer parameter groups: encoder at ``lr``, head at ``10 * lr``.

        Args:
            lr: Learning rate of the encoder group.
            logger_fn: Optional callable that receives a one-line summary.
        """
        if logger_fn:
            logger_fn(f'Backbone lr set to {lr}, Classifier lr set to {10 * lr}')
        return [{'params': self.get_1x_lr_params_NOscale(), 'lr': lr},
                {'params': self.get_10x_lr_params(), 'lr': 10 * lr}]

    def freeze(self, network):
        """Disable gradients for the ``'encoder'`` or ``'decoder'`` parameters.

        The method is marked as unverified by its author: after updating the
        ``requires_grad`` flags it always raises ``NotImplementedError``.

        Raises:
            AssertionError: If ``network`` is neither ``'encoder'`` nor ``'decoder'``.
            NotImplementedError: Always, once the flags have been updated.
        """
        n2n = {'encoder': ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4'], 'decoder': ['layer5']}
        assert network in n2n
        targets = n2n[network]
        for name, param in self.named_parameters():
            top_level = name.split('.')[0]
            if any(target in top_level for target in targets):
                param.requires_grad = False
        raise NotImplementedError('Method has been implemented but needs to be checked')

    def load_layers(self, checkpoint, include_layers=(), exclude_layers=(), network=''):
        """Copy selected checkpoint entries into this model's state dict.

        Args:
            checkpoint: Mapping from parameter name to value.
            include_layers: An entry is loaded when any dot-separated token of
                its name is listed here.
            exclude_layers: Full parameter names that are skipped silently.
            network: Label used only in the warning message.

        Entries that are neither excluded nor included are reported in a
        printed warning and left at the model's current values.
        """
        included = set(include_layers)
        excluded = set(exclude_layers)
        new_params = self.state_dict().copy()
        not_loaded = []
        for pname, pvalue in checkpoint.items():
            if pname in excluded:  # needs to match the full pname
                continue
            if included.intersection(pname.split('.')):  # one token of pname must match
                new_params[pname] = pvalue
            else:
                not_loaded.append(pname)
        if not_loaded:
            print(f'WARNING (loading {network}): {not_loaded} have not been loaded into current model')
        self.load_state_dict(new_params)

    def load_network(self, checkpoint, network, load_partial_decoder=True):
        """Load the ``'all'``, ``'encoder'`` or ``'decoder'`` part of ``checkpoint``.

        Args:
            checkpoint: Mapping from parameter name to value.
            network: Which part to load.
            load_partial_decoder: Only used with ``network='encoder'``; when
                true, the checkpoint's head weights are additionally copied
                through ``load_last_layer_weights``.

        Raises:
            ValueError: If ``network`` is not one of the three accepted names.
        """
        encoder, decoder = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4'], ['layer5']
        load_last_layer = False
        if network == 'all':
            layers = encoder + decoder
        elif network == 'encoder':
            layers = encoder
            load_last_layer = bool(load_partial_decoder)
        elif network == 'decoder':
            layers = decoder
        else:
            raise ValueError(f'Cannot load parameters from {network}')
        self.load_layers(checkpoint, include_layers=layers, network=network)
        if load_last_layer:
            self.load_last_layer_weights(checkpoint)

    def load_last_layer_weights(self, checkpoint):
        """Copy the checkpoint's head weights into the leading output channels.

        For every head convolution, the checkpoint weight and bias overwrite
        the first ``c`` output channels, where ``c`` is the checkpoint weight's
        output-channel count; any further channels keep their current values.
        """
        for i, conv in enumerate(self.layer5.conv2d_list):
            weight = checkpoint[f'layer5.conv2d_list.{i}.weight']
            bias = checkpoint[f'layer5.conv2d_list.{i}.bias']
            c = weight.size(0)
            # Write through .data so the copy is not recorded by autograd.
            conv.weight.data[:c].copy_(weight)
            conv.bias.data[:c].copy_(bias)


class DeeplabVGG(nn.Module):
    """DeepLab-v2 style segmentation network on a dilated VGG-16 backbone.

    ``features`` holds the VGG-16 feature layers with two entries removed,
    followed by two dilated 3x3 convolutions (``fc6``/``fc7``), each with a
    ReLU. ``classifier`` is a ``Classifier_Module`` head. ``forward`` returns
    the head's output without resizing it.

    Args:
        num_classes: Number of output channels of the classifier head.
        restore_from: Path of a VGG-16 state dict; required when
            ``pretrained`` is true.
        pretrained: Whether to load ``restore_from`` into the VGG-16 layers.
        device: ``map_location`` used when loading ``restore_from``.

    Raises:
        ValueError: If ``pretrained`` is true but ``restore_from`` is ``None``.
    """

    def __init__(self, num_classes, restore_from=None, pretrained=False, device='cuda'):
        super().__init__()
        vgg = models.vgg16()
        if pretrained:
            if restore_from is None:
                raise ValueError('restore_from must be provided when pretrained=True')
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            vgg.load_state_dict(torch.load(restore_from, map_location=device))

        source_layers = list(vgg.features.children())

        # Keep entries 0-22 and 24-29, i.e. drop entries 23 and 30
        # (pool4/pool5 according to the original comment).
        kept = [source_layers[i] for i in list(range(23)) + list(range(24, 30))]

        # After the removal these positions are expected to hold the last
        # three VGG convolutions; they are switched to dilation 2.
        for idx in (23, 25, 27):
            kept[idx].dilation = (2, 2)
            kept[idx].padding = (2, 2)

        fc6 = nn.Conv2d(512, 1024, kernel_size=3, padding=4, dilation=4)
        fc7 = nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4)

        self.features = nn.Sequential(*kept, fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True))

        self.classifier = Classifier_Module(1024, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)

    def forward(self, x):
        """Return the classifier output computed on the backbone features."""
        x = self.features(x)
        x = self.classifier(x)
        return x

    def get_1x_lr_params_NOscale(self):
        """Yield the trainable parameters of ``features``, each exactly once.

        Parameters whose ``requires_grad`` is False are skipped.

        The previous implementation iterated ``modules()`` and then the
        recursive ``parameters()`` of every sub-module, so every parameter was
        yielded twice (once for the Sequential container, once for its layer)
        and ended up duplicated in the optimizer's parameter group.
        """
        seen = set()
        for param in self.features.parameters():
            if param.requires_grad and id(param) not in seen:
                seen.add(id(param))
                yield param

    def get_10x_lr_params(self):
        """Yield every parameter of the classifier head."""
        yield from self.classifier.parameters()

    def optim_parameters(self, lr, logger_fn=None):
        """Return optimizer parameter groups for the backbone and the head.

        Both groups use the same learning rate ``lr`` in this class.

        Args:
            lr: Learning rate of both groups.
            logger_fn: Optional callable that receives a one-line summary.
        """
        if logger_fn:
            logger_fn(f'Backbone lr set to {lr}, Classifier lr set to {lr}')
        return [{'params': self.get_1x_lr_params_NOscale(), 'lr': lr},
                {'params': self.get_10x_lr_params(), 'lr': lr}]

    def load_network(self, checkpoint, network, load_partial_decoder=True):
        """Not implemented for the VGG backbone."""
        raise NotImplementedError

    def load_last_layer_weights(self, checkpoint):
        """Not implemented for the VGG backbone."""
        raise NotImplementedError


class DeeplabResnetFeat(DeeplabResnet):
    """``DeeplabResnet`` variant that also returns the ``layer4`` feature map."""

    def forward(self, x):
        """Run the network and return both features and class scores.

        Returns:
            OrderedDict with two entries, in this order:
            ``"features"`` - the output of ``layer4``;
            ``"out"`` - the ``layer5`` output bilinearly resized to the
            spatial size of ``x``.
        """
        # Imported here because this file has no explicit OrderedDict import;
        # it would otherwise resolve only if the package's wildcard import
        # happens to export the name.
        from collections import OrderedDict

        input_size = x.size()[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        feat = self.layer4(x)

        logits = self.layer5(feat)  # classifier module
        logits = F.interpolate(logits, size=input_size, mode='bilinear', align_corners=True)

        result = OrderedDict()
        result["features"] = feat
        result["out"] = logits

        return result

class DeeplabVGGFeat(DeeplabVGG):
    """``DeeplabVGG`` variant that also returns the backbone feature map."""

    def forward(self, x):
        """Run the network and return both features and class scores.

        Returns:
            OrderedDict with two entries, in this order:
            ``"features"`` - the output of ``self.features``;
            ``"out"`` - the classifier output bilinearly resized to the
            spatial size of ``x``.
        """
        # Imported here because this file has no explicit OrderedDict import;
        # it would otherwise resolve only if the package's wildcard import
        # happens to export the name.
        from collections import OrderedDict

        input_size = x.size()[2:]
        feat = self.features(x)
        logits = self.classifier(feat)
        logits = F.interpolate(logits, size=input_size, mode='bilinear', align_corners=True)

        result = OrderedDict()
        result["features"] = feat
        result["out"] = logits

        return result


def deeplabv2(num_classes, backbone, pretrained=True, logger_fn=None, device='cuda'):
    """Build a DeepLab-v2 model that returns features and class scores.

    Args:
        num_classes: Number of output classes.
        backbone: ``'resnet101'`` or ``'vgg16'`` (compared case-insensitively).
        pretrained: Whether to load the backbone weights from the hard-coded
            files under ``./pretrained_model/``.
        logger_fn: Optional callable receiving progress messages (only the
            ResNet branch emits any).
        device: ``map_location`` used when loading the checkpoint.

    Returns:
        A ``DeeplabResnetFeat`` or ``DeeplabVGGFeat`` instance.
    """

    def log(message):
        # Messages are emitted only when a truthy logger was supplied.
        if logger_fn:
            logger_fn(message)

    backbones = ('resnet101', 'vgg16')
    backbone = backbone.lower()
    assert backbone in backbones, f'{backbone} not among available backbones {backbones}'

    if backbone == 'vgg16':
        return DeeplabVGGFeat(num_classes, restore_from='./pretrained_model/vgg16-397923af.pth',
                              pretrained=pretrained, device=device)

    if backbone != 'resnet101':
        # Only reachable when the assertion above has been stripped.
        raise Exception

    log('Creating DeeplabV2-ResNet101 model...')
    model = DeeplabResnetFeat(Bottleneck, [3, 4, 23, 3], num_classes)
    if not pretrained:
        return model

    log('Loading pretrained ckpt...')
    saved_state_dict = torch.load('./pretrained_model/DeepLab_resnet_pretrained_init-f81d91e8.pth',
                                  map_location=device)
    log('Loading selected weights from ckpt into model...')
    merged = model.state_dict().copy()
    for key, value in saved_state_dict.items():
        # Checkpoint names carry one leading token that is dropped here;
        # entries of the checkpoint's own 'layer5' head are not loaded.
        _, *tail = key.split('.')
        if tail[0] != 'layer5':
            merged['.'.join(tail)] = value
    model.load_state_dict(merged)

    return model