from graphs import deeplabv2, deeplabv3, erfnet

def get_model(args, num_classes, imagenet_pretrained=False, logger_fn=None, device='cuda'):
    """Build a segmentation model and its optimizer parameter groups.

    ``args.model`` is either ``"<model>"`` or ``"<model>-<backbone>"``. The
    model part is looked up case-insensitively among ``deeplabv2``,
    ``deeplabv3`` and ``erfnet``. The backbone part is lowercased and
    forwarded to the constructor; it is ``''`` when absent.

    Args:
        args: Namespace-like object providing ``model`` and ``lr``.
            NOTE: it is mutated; ``args.numpy_transform`` is set to ``True``
            after the model is built.
        num_classes: Number of output classes, forwarded to the constructor.
        imagenet_pretrained: Forwarded to the constructor as ``pretrained``.
        logger_fn: Optional logging callable, forwarded to the constructor
            and to ``model.optim_parameters``.
        device: Forwarded to the model constructor.

    Returns:
        A ``(model, params)`` tuple, where ``params`` is whatever
        ``model.optim_parameters(lr=args.lr, logger_fn=logger_fn)`` returns.

    Raises:
        ValueError: If ``args.model`` contains more than one ``'-'`` or
            names a model that is not in the registry.
    """
    parts = args.model.split('-')
    # str.split always yields at least one element, so the only malformed
    # case is more than one '-' separator.
    if len(parts) == 1:
        model_name, backbone = parts[0], ''
    elif len(parts) == 2:
        model_name, backbone = parts
    else:
        raise ValueError(f'{parts} model not supported')

    model_fns = {'deeplabv2': deeplabv2, 'deeplabv3': deeplabv3, 'erfnet': erfnet}
    try:
        model_fn = model_fns[model_name.lower()]
    except KeyError:
        # Report unknown names the same way as malformed ones (ValueError)
        # instead of leaking a bare KeyError, and list the valid choices.
        raise ValueError(
            f'{model_name} model not supported; expected one of {sorted(model_fns)}'
        ) from None

    model = model_fn(num_classes=num_classes, backbone=backbone.lower(),
                     pretrained=imagenet_pretrained, logger_fn=logger_fn, device=device)

    params = model.optim_parameters(lr=args.lr, logger_fn=logger_fn)
    args.numpy_transform = True
    return model, params