import torch
import torch.nn.functional as F
import torch.nn as nn
from math import sqrt


class ApertureMask(nn.Module):
    """Learnable coded-aperture mask that projects a light field onto coded images.

    The mask is parameterized by ``aperture_num // 2`` columns of logits per
    viewpoint (``self.param``, shape ``(view_point_num, aperture_num // 2)``).
    Each column ``t`` yields a complementary pair of exposures: exposure ``2t``
    uses transmittance ``W[:, t]`` and exposure ``2t + 1`` uses ``1 - W[:, t]``,
    where ``W = sigmoid(param * scale ** epoch)``. Because ``scale > 1``, the
    sigmoid slope grows with ``epoch`` and ``W`` saturates toward 0/1.

    If ``aperture_num`` is odd, the unpaired last exposure is dropped: the
    module produces ``2 * (aperture_num // 2)`` coded images.

    Args:
        view_point_num: number of light-field viewpoints (size of LF dim 1).
        aperture_num: requested number of coded exposures (must be >= 2).

    Raises:
        ValueError: if ``aperture_num < 2`` (no mask pair could be formed).
    """

    def __init__(self, view_point_num, aperture_num):
        super().__init__()
        if aperture_num < 2:
            raise ValueError(
                f"aperture_num must be >= 2 to form at least one mask pair, got {aperture_num}"
            )
        self.view_point_num = view_point_num
        self.aperture_num = aperture_num
        pair_num = int(self.aperture_num // 2)

        # Uniform(-1/sqrt(V), 1/sqrt(V)) init; the same bound is used for both draws.
        bound = sqrt(1 / self.view_point_num)
        self.param = nn.Parameter(
            torch.empty(self.view_point_num, pair_num).uniform_(-bound, bound)
        )
        # Placeholder only: forward() overwrites W with sigmoid(param * ...).
        # It is still drawn (after param) so the global RNG stream is unchanged.
        self.W = torch.empty(self.view_point_num, pair_num).uniform_(-bound, bound)
        # Per-epoch growth factor of the sigmoid slope.
        self.scale = 1.02

    def forward(self, LF, epoch):
        """Apply the current mask to a light field.

        Args:
            LF: 4-D tensor whose dim 1 has size ``view_point_num``
                (presumably ``(B, V, H, W)``).
            epoch: training epoch; the sigmoid slope is ``scale ** epoch``.

        Returns:
            Tensor of shape ``(LF.shape[0], 2 * (aperture_num // 2), LF.shape[2],
            LF.shape[3])``. Channel ``2t`` is ``sum_v W[v, t] * LF[:, v]`` and
            channel ``2t + 1`` is ``sum_v (1 - W[v, t]) * LF[:, v]``.
        """
        # Stored on self (graph attached) so get_mask()/clip_mask() can read
        # the most recent mask.
        self.W = torch.sigmoid(self.param * (self.scale ** epoch))
        # Interleave W and 1 - W column-wise in one shot:
        # (V, K) -> (V, K, 2) -> (V, 2K) == [W0, 1-W0, W1, 1-W1, ...].
        A = torch.stack((self.W, 1 - self.W), dim=2).flatten(start_dim=1)
        # (B, V, H, W) -> (B, H, W, V) @ (V, 2K) -> (B, H, W, 2K) -> (B, 2K, H, W)
        coded = torch.matmul(LF.permute(0, 2, 3, 1), A).permute(0, 3, 1, 2)
        return coded

    def get_mask(self, num):
        """Return the detached transmittance pattern of exposure ``num``.

        Even ``num`` -> ``W[:, num // 2]``; odd ``num`` -> ``1 - W[:, num // 2]``,
        matching the channel order produced by forward(). Before the first
        forward() call, ``W`` is the random placeholder set in __init__.
        """
        pair = int(num / 2)
        if num % 2 == 0:
            mask = self.W[:, pair]
        else:
            mask = 1 - self.W[:, pair]
        return mask.detach()

    def clip_mask(self):
        """Clamp the stored mask ``W`` to [0, 1] in place, bypassing autograd.

        After forward(), ``W`` is a sigmoid output and already lies in [0, 1],
        so this only changes values while ``W`` is still the __init__ placeholder.
        """
        self.W.data.clamp_(0.0, 1.0)


class RoundNoGradient(torch.autograd.Function):
    """Truncate toward zero in the forward pass; pass gradients straight through.

    Despite the name, the forward is truncation (C ``int`` cast semantics), not
    round-to-nearest: 1.7 -> 1, -1.7 -> -1. The result is always float32. The
    backward is the identity (straight-through estimator).
    """

    @staticmethod
    def forward(ctx, x):
        # torch.trunc matches the former int64 round-trip for every finite,
        # in-range float, but does not overflow for |x| >= 2**63 and keeps
        # NaN/inf as-is instead of turning them into an arbitrary integer.
        if x.is_floating_point():
            x = torch.trunc(x)
        return x.to(torch.float32)

    @staticmethod
    def backward(ctx, g):
        # Straight-through: treat the truncation as identity for gradients.
        return g


class DiffEventCam(torch.nn.Module):
    """Convert a pair of intensity frames into signed, truncated event counts.

    forward() computes ``trunc((log(im2 + eps) - log(im1 + eps)) / threshold)``
    through ``RoundNoGradient`` (straight-through gradient). The threshold is
    randomized twice:

    * batch level: ``th_mid + (th_range / 2) * (2 * noise_rand - 1)``, which
      spans ``[threshold_low, threshold_high]`` when ``noise_rand`` is in [0, 1]
      (presumably uniform samples supplied by the caller - TODO confirm);
    * pixel level: additive Gaussian noise with std ``pixel_noise``.

    The combined threshold is floored at 0.01.

    Args:
        threshold_low: lower end of the batch-level threshold range.
        threshold_high: upper end of the batch-level threshold range.
        pixel_noise: standard deviation of the per-pixel threshold noise.
        eps: offset added to intensities before taking the log.
        device: device the per-pixel noise tensor is moved to.
    """

    def __init__(self, threshold_low, threshold_high, pixel_noise, eps, device):
        super().__init__()
        self.th_mid = (threshold_high + threshold_low) / 2.0
        self.th_range = threshold_high - threshold_low
        self.pixel_noise = pixel_noise
        self.e = eps
        self.device = device

    def forward(self, im1, im2, noise_rand):
        # Batch-level threshold randomization.
        half_range = self.th_range / 2.0
        batch_offset = half_range * (2 * noise_rand - 1)
        # Pixel-level threshold noise (sampled on the default generator, then moved).
        gaussian = torch.randn(im1.shape).to(self.device)
        pixel_offset = self.pixel_noise * gaussian

        threshold = (self.th_mid + batch_offset + pixel_offset).clamp(min=0.01)

        log_change = torch.log(im2 + self.e) - torch.log(im1 + self.e)
        return RoundNoGradient.apply(log_change / threshold)


class Reconstruct(torch.nn.Module):
    """Map an ``aperture_num``-channel input to ``view_point_num`` output channels.

    Presumably reconstructs light-field views from coded images.

    * Front end: three 5x5 convolutions (aperture_num -> 64 -> 64 ->
      view_point_num), applied back to back with no activation between them.
    * Body: 20 3x3 convolutions (view_point_num -> ch_mid -> ... -> ch_mid ->
      view_point_num). A ReLU follows each of the first 19. The body output is
      added residually to the front-end output.

    Submodules are registered as conv1..conv3 then l1..l20, in that order, so
    state_dict keys and parameter-initialization order are those of an explicit
    layer-by-layer definition.
    """

    # Number of 3x3 convolutions in the residual body (attributes l1..l20).
    _BODY_DEPTH = 20

    def __init__(self, view_point_num, aperture_num, ch_mid):
        super().__init__()
        self.view_point_num = view_point_num
        self.input_num = aperture_num

        # Front end: 5x5 convs, aperture_num -> 64 -> 64 -> view_point_num.
        self.conv1 = nn.Conv2d(self.input_num, 64, 5, padding=2)
        self.conv2 = nn.Conv2d(64, 64, 5, padding=2)
        self.conv3 = nn.Conv2d(64, self.view_point_num, 5, padding=2)

        self.ch_in = view_point_num
        self.ch_mid = ch_mid
        self.ch_out = view_point_num

        # Body: the first layer widens to ch_mid and the last narrows back to
        # ch_out. setattr keeps the names l1..l20 and their registration order.
        depth = self._BODY_DEPTH
        for idx in range(1, depth + 1):
            in_ch = self.ch_in if idx == 1 else self.ch_mid
            out_ch = self.ch_out if idx == depth else self.ch_mid
            setattr(
                self,
                f"l{idx}",
                torch.nn.Conv2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
            )

    def forward(self, x):
        # Front end: no nonlinearity between conv1, conv2 and conv3.
        base = self.conv3(self.conv2(self.conv1(x)))

        hidden = base
        for idx in range(1, self._BODY_DEPTH):
            hidden = F.relu(getattr(self, f"l{idx}")(hidden))

        # The final body layer is linear; its output is added to the front end.
        final_layer = getattr(self, f"l{self._BODY_DEPTH}")
        return final_layer(hidden) + base
        