from . import *

class StyleAugment:
    """Style augmentation that swaps the centered (low-frequency) amplitude
    spectrum of an image with amplitude windows extracted from other images,
    while keeping the image's own phase spectrum.

    Styles are stored per ``step`` in ``self.styles`` as
    ``{'avg': mean_style, 'single': [per-image styles]}``.
    """

    def __init__(self, n_images_per_style, L, avg_style=False, fixed_size=None):
        """
        Args:
            n_images_per_style: number of images sampled per step to extract
                styles from; a value <= 0 disables style augmentation.
            L: fraction of min(H, W) giving the half-size of the centered
                amplitude window that is extracted/replaced (see ``compute_size``).
            avg_style: if True, apply the mean style of the sampled images;
                otherwise apply one randomly chosen single style.
            fixed_size: optional (W, H) the image is resized to (bicubic)
                before processing; the result is resized back afterwards.
        """
        self.styles = {}
        self.n_images_per_style = n_images_per_style
        self.L = L
        self.avg_style = avg_style
        self.fixed_size = fixed_size
        # (h1, h2, w1, w2) of the amplitude window; computed lazily from the
        # first spectrum seen and reused for every later image.
        self.window_size = None

    def preprocess(self, x):
        """Convert a PIL image into a float32 channel-first array.

        Optionally resizes to ``fixed_size`` first, then reverses the channel
        order of the last axis and transposes to (C, H, W). Assumes *x* converts
        to a 3-D (H, W, C) array -- TODO confirm inputs are always multi-channel.
        """
        if self.fixed_size is not None:
            x = x.resize(self.fixed_size, Image.BICUBIC)
        x = np.asarray(x, np.float32)
        x = x[:, :, ::-1]
        x = x.transpose((2, 0, 1))
        # Copy to obtain a contiguous array instead of a strided view.
        return x.copy()

    def deprocess(self, x, size):  # size: W,H
        """Inverse of ``preprocess``: (C, H, W) array in [0, 255] -> PIL image.

        When ``fixed_size`` is set, the image is resized back to *size* (W, H).
        """
        x = Image.fromarray(np.uint8(x).transpose((1, 2, 0))[:, :, ::-1])
        if self.fixed_size is not None:
            x = x.resize(size, Image.BICUBIC)
        return x

    def add_style(self, dataset, step, save_or_load_path=None, dom_unrel_styles=False):
        """Extract styles from randomly sampled images of *dataset* for *step*.

        Samples ``min(n_images_per_style, len(dataset))`` distinct indices,
        extracts the amplitude window of each image and stores
        ``{'avg': mean_style, 'single': styles}`` in ``self.styles[step]``
        (``styles`` is empty when ``avg_style`` is True).

        Args:
            dataset: sized, indexable object whose items unpack into four
                values, the first being the image. Its ``fda_style_extraction``
                attribute is set to True while extracting and reset afterwards.
            step: key under which the styles are stored.
            save_or_load_path: optional path for caching the average style.
                With ``avg_style`` an existing cache is loaded instead of
                extracting; after extraction the average style is saved there.
            dom_unrel_styles: if True, the new styles also overwrite those of
                every step in ``range(step)``.
        """
        if self.n_images_per_style <= 0:
            # Disabled; consistent with apply_style (avoids dividing by 0 images).
            return
        if save_or_load_path is not None and self.avg_style:
            if self._load_avg_style(save_or_load_path, step):
                # NOTE(review): a cached load does not propagate to earlier
                # steps when dom_unrel_styles=True, unlike extraction below.
                return

        n_images = min(self.n_images_per_style, len(dataset))
        # Distinct images: clamping to len(dataset) implies no replacement.
        indices = np.random.choice(len(dataset), size=n_images, replace=False)

        dataset.fda_style_extraction = True
        try:
            avg_style, styles = self._accumulate_styles(
                self._iter_preprocessed(dataset, indices, n_images))
        finally:
            # Restore the dataset flag even if extraction fails midway.
            dataset.fda_style_extraction = False

        self.styles[step] = {'avg': avg_style, 'single': styles}
        if dom_unrel_styles:  # images are randomly taken from a new dataset, so no need to have memory of past
            for s in range(step):
                self.styles[s] = {'avg': avg_style, 'single': styles}

        if save_or_load_path is not None:
            self._save_avg_style(save_or_load_path, step)

    def _iter_preprocessed(self, dataset, indices, n_images):
        """Yield ``preprocess(image)`` for ``dataset[i]`` at each index, with a progress bar."""
        for i in tqdm(indices, mininterval=None, miniters=n_images // 5, maxinterval=60000., desc=f"Extracting styles, L:{self.L}"):
            image, _, _, _ = dataset[i]
            yield self.preprocess(image)

    def _accumulate_styles(self, images):
        """Extract the style of each preprocessed image in *images*.

        Returns:
            (avg_style, singles): the element-wise mean style and the list of
            per-image styles (left empty when ``avg_style`` is True).

        Raises:
            ValueError: if *images* yields no image.
        """
        style_sum, singles, count = None, [], 0
        for image in images:
            single_style = self._extract_style(image)
            if style_sum is None:
                # Copy so the in-place ``+=`` below never mutates a style that
                # is also kept in ``singles``.
                style_sum = single_style.copy()
            else:
                style_sum += single_style
            if not self.avg_style:
                singles.append(single_style)
            count += 1
        if count == 0:
            raise ValueError('no images available to extract styles from')
        return style_sum / count, singles

    def _extract_style(self, img_np):
        """Return the centered window of the fftshift-ed amplitude spectrum of *img_np*.

        The FFT is taken over the last two axes; the window is sliced from
        axes 1 and 2, so *img_np* is expected to be 3-D (C, H, W).
        """
        fft_np = np.fft.fft2(img_np, axes=(-2, -1))
        amp_shift = np.fft.fftshift(np.abs(fft_np), axes=(-2, -1))
        if self.window_size is None:
            self.window_size = self.compute_size(amp_shift)
        h1, h2, w1, w2 = self.window_size
        # Copy so the stored style does not keep the whole spectrum alive.
        return amp_shift[:, h1:h2, w1:w2].copy()

    def _load_avg_style(self, path, step):
        """Load a cached average style into ``self.styles[step]``.

        Also looks for ``path + '.npy'``, since ``np.save`` appends that suffix
        when it is missing. Returns True if a file was loaded, else False.
        """
        path = os.fspath(path)
        candidates = [path] if path.endswith('.npy') else [path, path + '.npy']
        for candidate in candidates:
            if os.path.isfile(candidate):
                print(f'Loading avg style from {candidate}')
                self.styles[step] = {'avg': np.load(candidate), 'single': []}
                return True
        return False

    def _save_avg_style(self, path, step):
        """Save ``self.styles[step]['avg']`` with ``np.save``, creating parent directories.

        ``np.save`` appends ``.npy`` to *path* if it lacks that suffix.
        """
        parent_dir = os.path.dirname(path)
        # dirname is '' for a bare filename, and os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(f'Saving avg style to {path}')
        np.save(path, self.styles[step]['avg'])

    def compute_size(self, amp_shift):
        """Return (h1, h2, w1, w2) of a square window centered in *amp_shift*.

        The half-size is ``floor(min(H, W) * L)``, so the window spans
        ``2 * b + 1`` rows and columns around (H // 2, W // 2). No clipping is
        done: for large ``L`` the bounds can fall outside the array.
        """
        _, h, w = amp_shift.shape
        b = int(np.floor(min(h, w) * self.L))
        c_h, c_w = h // 2, w // 2
        return c_h - b, c_h + b + 1, c_w - b, c_w + b + 1

    def apply_style(self, img, step=None):
        """Return *img* together with style-augmented versions of it.

        Returns *img* itself (not a list) when augmentation is disabled
        (``n_images_per_style <= 0``). Otherwise returns ``[img, styled]`` for
        the given *step*, or ``[img, styled_s0, styled_s1, ...]`` over every
        key of ``self.styles`` (insertion order) when *step* is None.
        """
        if self.n_images_per_style <= 0:
            return img
        if step is not None:
            return [img, self._apply_style(img, step)]
        return [img, *[self._apply_style(img, s) for s in self.styles]]

    def _apply_style(self, img, step):
        """Return a PIL image whose centered amplitude window is replaced by a stored style.

        Uses the average style of *step* when ``avg_style`` is set, otherwise a
        uniformly random single style. The phase spectrum of *img* is kept.
        """
        styles = self.styles[step]
        if self.avg_style:
            style = styles['avg']
        else:
            style = styles['single'][random.randint(0, len(styles['single']) - 1)]

        W, H = img.size
        img_np = self.preprocess(img)

        fft_np = np.fft.fft2(img_np, axes=(-2, -1))
        amp, pha = np.abs(fft_np), np.angle(fft_np)
        amp_shift = np.fft.fftshift(amp, axes=(-2, -1))
        if self.window_size is None:
            self.window_size = self.compute_size(amp_shift)
        h1, h2, w1, w2 = self.window_size
        # NOTE(review): the window is computed once and reused, so all images
        # presumably share one size (e.g. via fixed_size) -- confirm.
        amp_shift[:, h1:h2, w1:w2] = style
        amp_ = np.fft.ifftshift(amp_shift, axes=(-2, -1))

        # Recombine the swapped amplitude with the original phase and invert.
        img_np_ = np.real(np.fft.ifft2(amp_ * np.exp(1j * pha), axes=(-2, -1)))
        img_np_ = np.clip(np.round(img_np_), 0., 255.)
        return self.deprocess(img_np_, (W, H))

    @staticmethod
    def compute_pseudo_labels(images_with_style, old_model):
        """Per-pixel pseudo labels taken from the most confident style variant.

        Args:
            images_with_style: sequence of S input batches, each fed to *old_model*.
            old_model: callable returning a dict whose ``'out'`` entry is a
                B x C x H x W score tensor.

        Returns:
            B x H x W tensor of class indices: for every pixel, the argmax over
            classes of the variant whose maximum class score is highest there.
        """
        with torch.no_grad():
            outputs = [old_model(imgs)['out'] for imgs in images_with_style]
        scores = torch.stack(outputs, dim=0)  # S x B x C x H x W
        num_classes = scores.size(2)
        # Most confident variant per pixel: argmax over S of the max class score.
        best_style = scores.max(2)[0].max(0)[1]  # B x H x W
        index = best_style.unsqueeze(1).repeat(1, num_classes, 1, 1).unsqueeze(0)  # 1 x B x C x H x W
        refined = torch.gather(scores, 0, index).squeeze(0)  # B x C x H x W
        return refined.max(1)[1]
