# Code taken from https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py and slightly modified
from . import *
import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision.models import mobilenetv3
from torchvision.models import resnet
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.segmentation._utils import _SimpleSegmentationModel
from torchvision.models.segmentation.segmentation import _load_weights
from torchvision.models.segmentation.fcn import FCNHead


# Public API exported by ``from <this module> import *``. Kept as a literal list so
# static tools (type checkers, doc generators) can read it. The building blocks
# (DeepLabHead, ASPP, ...) are intentionally not listed here.
__all__ = [
    "DeepLabV3",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "deeplabv3",
]


# Download locations of the COCO-pretrained segmentation checkpoints (trained on the
# Pascal VOC class set). Keys are the ``arch`` names the builder functions look up.
model_urls = dict(
    deeplabv3_resnet50_coco="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
    deeplabv3_resnet101_coco="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
    deeplabv3_mobilenet_v3_large_coco="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
)


class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    Compared with the torchvision original, ``forward`` also returns the backbone's
    ``"out"`` feature map under the ``"features"`` key. The class adds helpers for
    building optimizer parameter groups, freezing sub-networks, and (partially)
    loading checkpoints.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run backbone and heads, upsampling predictions to the input's spatial size.

        Returns:
            OrderedDict with ``"out"`` (main prediction, bilinearly resized to the
            input's last two dims), ``"features"`` (the backbone's ``"out"`` map,
            not resized) and, if an auxiliary classifier exists, ``"aux"``
            (resized like ``"out"``).
        """
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x
        result["features"] = features["out"]

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result

    def optim_parameters(self, lr, logger_fn=None):
        """Return optimizer parameter groups.

        The backbone uses ``lr``. The classifier head, and the auxiliary head when
        present, use ``10 * lr``.

        Args:
            lr: base learning rate.
            logger_fn: optional callable that receives a description of the rates.
        """
        if logger_fn:
            logger_fn(f'Backbone lr set to {lr}, Classifier lr set to {10 * lr}')
        groups = [{'params': self.backbone.parameters(), 'lr': lr},
                  {'params': self.classifier.parameters(), 'lr': 10 * lr}]
        # Without this group the auxiliary head is never updated, even though
        # forward() produces its "aux" output for the training loss.
        if self.aux_classifier is not None:
            groups.append({'params': self.aux_classifier.parameters(), 'lr': 10 * lr})
        return groups

    def freeze(self, network):
        """Set ``requires_grad=False`` on the parameters of one sub-network.

        Args:
            network: ``'encoder'`` (``backbone``) or ``'decoder'`` (``classifier``).

        Raises:
            ValueError: if ``network`` is not one of the values above.
            NotImplementedError: always, after freezing. The author flagged the method
                as still needing verification.
        """
        n2n = {'encoder': 'backbone', 'decoder': 'classifier'}
        if network not in n2n:
            raise ValueError(f'Cannot freeze {network}; expected one of {list(n2n)}')
        # Match the leading module name only. A substring test would also match
        # "aux_classifier.*" when freezing the decoder, which load_network never treats
        # as part of the decoder.
        prefix = n2n[network] + '.'
        for name, param in self.named_parameters():
            if name.startswith(prefix):
                param.requires_grad = False
        raise NotImplementedError('Method has been implemented but needs to be checked')

    def load_layers(self, checkpoint, include_layers=(), exclude_layers=(), network=''):
        """Load selected entries of ``checkpoint`` into this model.

        An entry is loaded when its full name is not in ``exclude_layers`` and at
        least one of its dot-separated tokens is in ``include_layers``. For example,
        ``'backbone'`` matches ``'backbone.layer1.0.conv1.weight'``. Every other entry
        is skipped and listed in a printed warning. Model entries that are not loaded
        keep their current values.

        Args:
            checkpoint: mapping of parameter name -> tensor (a state dict).
            include_layers: name tokens selecting entries to load.
            exclude_layers: full entry names that must never be loaded.
            network: label used only in the warning message.
        """
        exclude = set(exclude_layers)
        include = set(include_layers)
        new_params = self.state_dict().copy()
        not_loaded = []
        for pname, pvalue in checkpoint.items():
            if pname not in exclude and not include.isdisjoint(pname.split('.')):
                new_params[pname] = pvalue
            else:
                not_loaded.append(pname)
        if not_loaded:
            print(f'WARNING (loading {network}): {not_loaded} have not been loaded into current model')
        self.load_state_dict(new_params)

    def load_network(self, checkpoint, network, load_partial_decoder=True):
        """Load the sub-network(s) selected by ``network`` from ``checkpoint``.

        ``aux_classifier.*`` entries are never loaded, because their first token is
        ``aux_classifier``.

        Args:
            checkpoint: mapping of parameter name -> tensor (a state dict).
            network: ``'all'`` (backbone + classifier), ``'encoder'`` or ``'decoder'``.
            load_partial_decoder: only used with ``network='encoder'``. When set, the
                classifier is loaded too, except for its final conv (``classifier.4``).
                That conv is copied channel-wise by ``load_last_layer_weights``, so a
                checkpoint with a different number of classes can still be used.

        Raises:
            ValueError: if ``network`` is not one of the values above.
        """
        encoder = ['backbone']
        decoder = ['classifier']
        load_last_layer, exclude_layers = False, []
        if network == 'all':
            layers = encoder + decoder
        elif network == 'encoder':
            layers = encoder
            if load_partial_decoder:
                layers = encoder + decoder
                exclude_layers = ['classifier.4.weight', 'classifier.4.bias']
                load_last_layer = True
        elif network == 'decoder':
            layers = decoder
        else:
            raise ValueError(f'Cannot load parameters from {network}')
        self.load_layers(checkpoint, include_layers=layers, exclude_layers=exclude_layers, network=network)
        if load_last_layer:
            self.load_last_layer_weights(checkpoint)

    def load_last_layer_weights(self, checkpoint):
        """Copy the leading output channels of the checkpoint's ``classifier.4``.

        ``classifier[4]`` is the final 1x1 conv of ``DeepLabHead``. For its weight and
        bias, the first ``min(n_checkpoint, n_model)`` output channels are overwritten
        from the checkpoint. Any remaining channels keep their current values.

        Args:
            checkpoint: mapping of parameter name -> tensor (a state dict).
        """
        last_layer = {'classifier.4.weight': self.classifier[4].weight,
                      'classifier.4.bias': self.classifier[4].bias}
        # Modify the parameters in place without recording the op in autograd.
        with torch.no_grad():
            for pname, pvalue in checkpoint.items():
                if pname not in last_layer:
                    continue
                param = last_layer[pname]
                # Clamp to the smaller class count so a checkpoint with more classes
                # than this model does not cause a shape mismatch in copy_().
                c = min(pvalue.size(0), param.size(0))
                param[:c].copy_(pvalue[:c])



class DeepLabHead(nn.Sequential):
    """DeepLabV3 head: ASPP, then a 3x3 conv-BN-ReLU, then a 1x1 conv to class logits.

    Module indices are 0=ASPP, 1=conv, 2=BN, 3=ReLU, 4=final 1x1 conv.

    Args:
        in_channels: channels of the incoming backbone feature map.
        num_classes: number of output channels of the final conv.
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        width = 256  # matches ASPP's default out_channels
        layers = [
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(width, width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, num_classes, 1),
        ]
        super().__init__(*layers)


class ASPPConv(nn.Sequential):
    """One atrous branch of ASPP: dilated 3x3 conv (no bias), BatchNorm, ReLU.

    Padding equals the dilation, so a stride-1 3x3 kernel keeps the spatial size.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        super().__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """Image-level branch of ASPP.

    Applies global average pooling and a 1x1 conv-BN-ReLU, then bilinearly resizes
    the result back to the input's spatial size.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        target_hw = x.shape[-2:]
        pooled = x
        for layer in self:
            pooled = layer(pooled)
        return F.interpolate(pooled, size=target_hw, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs several branches in parallel on the same input:
      - a 1x1 conv-BN-ReLU branch,
      - one ``ASPPConv`` branch per entry of ``atrous_rates``,
      - an ``ASPPPooling`` branch.
    Their outputs are concatenated along dim 1 and fused by a 1x1 conv-BN-ReLU,
    followed by Dropout(0.5).

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: dilation rates of the 3x3 atrous branches.
        out_channels: channels produced by each branch and by the projection.
    """

    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
        super().__init__()
        pointwise = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        atrous = [ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates)]
        image_pool = ASPPPooling(in_channels, out_channels)
        self.convs = nn.ModuleList([pointwise, *atrous, image_pool])

        n_branches = len(self.convs)
        self.project = nn.Sequential(
            nn.Conv2d(n_branches * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))


def _deeplabv3_resnet(
        backbone: resnet.ResNet,
        num_classes: int,
        aux: Optional[bool],
) -> DeepLabV3:
    """Wrap a ResNet into a DeepLabV3 model.

    ``layer4`` feeds a ``DeepLabHead`` with 2048 input channels. When ``aux`` is
    truthy, ``layer3`` also feeds an ``FCNHead`` with 1024 input channels.
    """
    wanted = {"layer4": "out", "layer3": "aux"} if aux else {"layer4": "out"}
    feature_extractor = create_feature_extractor(backbone, wanted)

    aux_head = FCNHead(1024, num_classes) if aux else None
    main_head = DeepLabHead(2048, num_classes)
    return DeepLabV3(feature_extractor, main_head, aux_head)


def _deeplabv3_mobilenetv3(
        backbone: mobilenetv3.MobileNetV3,
        num_classes: int,
        aux: Optional[bool],
) -> DeepLabV3:
    """Wrap a MobileNetV3 into a DeepLabV3 model.

    The last feature block feeds a ``DeepLabHead``. When ``aux`` is truthy, the
    fourth-from-last stage block also feeds an ``FCNHead``.
    """
    features = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    out_inplanes = features[out_pos].out_channels
    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    aux_inplanes = features[aux_pos].out_channels

    if aux:
        wanted = {str(out_pos): "out", str(aux_pos): "aux"}
    else:
        wanted = {str(out_pos): "out"}
    feature_extractor = create_feature_extractor(features, wanted)

    aux_head = FCNHead(aux_inplanes, num_classes) if aux else None
    main_head = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(feature_extractor, main_head, aux_head)


def deeplabv3_resnet50(
        pretrained: bool = False,
        progress: bool = True,
        num_classes: int = 21,
        aux_loss: Optional[bool] = None,
        pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # COCO weights are loaded for the whole model below. Skip the separate
        # backbone download and enable the aux head (the checkpoint presumably
        # includes it).
        aux_loss, pretrained_backbone = True, False

    # layer3 and layer4 use dilation instead of stride.
    resnet_backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(resnet_backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_resnet50_coco"
    _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model


def deeplabv3_resnet101(
        pretrained: bool = False,
        progress: bool = True,
        num_classes: int = 21,
        aux_loss: Optional[bool] = None,
        pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): The number of classes
        aux_loss (bool, optional): If True, include an auxiliary classifier
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # COCO weights are loaded for the whole model below. Skip the separate
        # backbone download and enable the aux head (the checkpoint presumably
        # includes it).
        aux_loss, pretrained_backbone = True, False

    # layer3 and layer4 use dilation instead of stride.
    resnet_backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(resnet_backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_resnet101_coco"
    _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model


def deeplabv3_mobilenet_v3_large(
        pretrained: bool = False,
        progress: bool = True,
        num_classes: int = 21,
        aux_loss: Optional[bool] = None,
        pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # COCO weights are loaded for the whole model below. Skip the separate
        # backbone download and enable the aux head (the checkpoint presumably
        # includes it).
        aux_loss, pretrained_backbone = True, False

    mobilenet_backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(mobilenet_backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_mobilenet_v3_large_coco"
    _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model


def deeplabv3(num_classes, backbone, pretrained=True, logger_fn=None, **kwargs):
    """Build a DeepLabV3 model from a backbone name.

    Args:
        num_classes: number of output classes.
        backbone: ``'mobilenetv3'``, ``'resnet50'`` or ``'resnet101'`` (case-insensitive).
        pretrained: forwarded as ``pretrained_backbone``. It controls only the backbone
            weights; the COCO segmentation checkpoints are not loaded by this function.
        logger_fn: optional callable that receives a progress message.
        **kwargs: accepted for call-site compatibility but currently ignored.

    Raises:
        AssertionError: if ``backbone`` is not a known name.
    """
    builders = {'mobilenetv3': deeplabv3_mobilenet_v3_large,
                'resnet50': deeplabv3_resnet50,
                'resnet101': deeplabv3_resnet101,}
    name = backbone.lower()
    assert name in builders, f'{name} not among available backbones {builders.keys()}'
    if logger_fn:
        logger_fn(f'Loading DeeplabV3-{name}')

    build = builders[name]
    return build(pretrained_backbone=pretrained, num_classes=num_classes)