import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from PIL import Image

from model import *
from utils.utils import *
from utils.utils_datasets import MultiTestSetDataLoader


def main(args):
    """Evaluate the aperture-mask / reconstruction models on every test set.

    Builds the models, loads the checkpoints for ``args.epoch``, runs
    :func:`test` on each dataset returned by ``MultiTestSetDataLoader``, prints
    per-dataset and overall PSNR/SSIM, and writes them to ``evaluation.xls`` in
    ``<result_dir>/from_LF/test_threshold_<args.test_threshold>``.

    Args:
        args: Parsed command-line options. Fields read here include
            ``device``, ``angRes``, ``aperture_num``, ``test_threshold``,
            ``test_threshold_noise``, ``eps``, ``epoch`` and
            ``use_pre_ckpt_for_test``; ``args`` is also forwarded to
            ``create_dir``, ``MultiTestSetDataLoader`` and :func:`test`.
    """
    # Create the output directory for this run.
    _, model_dir, _, result_dir = create_dir(args)
    result_dir = result_dir.joinpath('from_LF', f'test_threshold_{args.test_threshold}')
    result_dir.mkdir(parents=True, exist_ok=True)

    # Select CPU or CUDA.
    device = torch.device(args.device)
    if 'cuda' in args.device:
        torch.cuda.set_device(device)

    # Test data loading.
    print('\nLoad Test Dataset ...')
    test_Names, test_Loaders, length_of_tests = MultiTestSetDataLoader(args)
    print("The number of test data is: %d" % length_of_tests)

    # Model construction.
    view_point_num = args.angRes * args.angRes
    aperture_mask = ApertureMask(view_point_num, args.aperture_num).to(device)
    reconstruct = Reconstruct(view_point_num, args.aperture_num, 64).to(device)
    esim = DiffEventCam(args.test_threshold, args.test_threshold,
                        args.test_threshold_noise, args.eps, device).to(device)

    total_trainable_params = sum(p.numel() for p in reconstruct.parameters() if p.requires_grad)
    print("Training parameters: %d" % total_trainable_params)

    # Load the pre-trained checkpoints.
    if not args.use_pre_ckpt_for_test:
        # Testing without checkpoints is not implemented: the networks keep
        # their freshly initialised weights, so say so instead of staying silent.
        print('Warning: use_pre_ckpt_for_test is False; testing with untrained weights!')
    else:
        ckpt_prefix = 'epoch%03d' % args.epoch
        ckpt_path_mask = model_dir.joinpath(ckpt_prefix + '_mask.pth')
        ckpt_path_reconstruct = model_dir.joinpath(ckpt_prefix + '_reconstruct.pth')
        # map_location lets checkpoints saved on a GPU load on a CPU-only run.
        aperture_mask.load_state_dict(torch.load(ckpt_path_mask, map_location=device))
        reconstruct.load_state_dict(torch.load(ckpt_path_reconstruct, map_location=device))
        print('Use pretrain model!')

    # Inference only: disable any training-time module behaviour
    # (e.g. dropout, batch-norm statistic updates).
    aperture_mask.eval()
    reconstruct.eval()
    esim.eval()

    cudnn.benchmark = True

    # Print parameters.
    print('PARAMETER ...')
    print(args)

    # Test on every dataset.
    print('\nStart test...')
    with torch.no_grad():
        # One sheet per dataset plus an 'ALL' summary sheet.
        excel_file = ExcelFile()
        psnr_testset = []
        ssim_testset = []
        idx_epoch = args.epoch

        for test_name, test_loader in zip(test_Names, test_Loaders):
            save_dir = result_dir.joinpath(test_name)
            save_dir.mkdir(exist_ok=True)

            psnr_iter_test, ssim_iter_test, LF_name = test(
                test_loader, device, aperture_mask, reconstruct, esim,
                save_dir, args, idx_epoch)
            excel_file.write_sheet(test_name, LF_name, psnr_iter_test, ssim_iter_test)

            psnr_epoch_test = float(np.mean(psnr_iter_test))
            ssim_epoch_test = float(np.mean(ssim_iter_test))
            psnr_testset.append(psnr_epoch_test)
            ssim_testset.append(ssim_epoch_test)
            print('Test on %s, psnr/ssim is %.2f/%.3f' % (test_name, psnr_epoch_test, ssim_epoch_test))

        # Overall score is the unweighted mean of the per-dataset means.
        psnr_mean_test = float(np.mean(psnr_testset))
        ssim_mean_test = float(np.mean(ssim_testset))

        excel_file.add_sheet('ALL', 'Average', psnr_mean_test, ssim_mean_test)
        print('The mean psnr on testsets is %.5f, mean ssim is %.5f' % (psnr_mean_test, ssim_mean_test))
        excel_file.xlsx_file.save(str(result_dir.joinpath('evaluation.xls')))


# Event stacks are divided by this factor before entering the network and
# multiplied back by it before being saved as int8.
_EVENT_STACK_SCALE = 8.0


def _simulate_acquisition(data, aperture_mask, esim, args, idx_epoch, device):
    """Simulate the coded-aperture capture of one light field.

    Args:
        data: Light field reshaped to ``(B, V, H, W, 1)``; only ``B == 1`` is
            supported because the accumulators below have batch dimension 1.

    Returns:
        ``(event_stacks, noise_image_frame)`` with shapes
        ``(1, aperture_num - 1, H, W)`` and ``(1, 1, H, W)``.
    """
    view_point_num = data.shape[1]
    H, W = data.shape[2], data.shape[3]
    temporal_images = aperture_mask(data[..., 0], idx_epoch) / view_point_num

    # Image frame: sum of all aperture-pattern exposures, normalised, plus noise.
    image_frame = torch.zeros(1, 1, H, W).to(device)
    for t in range(args.aperture_num):
        image_frame += temporal_images[:, t:t + 1, :, :]
    image_frame = image_frame / (args.aperture_num / 2)
    noise = args.image_noise * torch.randn(image_frame[0].shape, dtype=torch.float32).to(device)
    noise_image_frame = image_frame + noise.float()

    # Event stacks: one esim call per pair of consecutive patterns; the same
    # threshold_rand draw is passed to every call.
    event_stacks = torch.zeros(1, args.aperture_num - 1, H, W).to(device)
    threshold_rand = torch.rand(1, 1, 1).to(device)
    for t in range(args.aperture_num - 1):
        event_stacks[:, t, :, :] = esim(temporal_images[:, t, :, :], temporal_images[:, t + 1, :, :], threshold_rand)
    event_stacks = event_stacks / _EVENT_STACK_SCALE
    return event_stacks, noise_image_frame


def _save_views(gt_light_field, re_light_field, save_dir_):
    """Write each GT / reconstructed uint8 view as ``<v>_<u>.png`` under save_dir_."""
    view_point_num = gt_light_field.shape[0]
    # Views are laid out row-major on a square ang_res x ang_res grid.
    ang_res = int(round(np.sqrt(view_point_num)))
    re_views_dir = save_dir_.joinpath('re_views')
    re_views_dir.mkdir(exist_ok=True)
    gt_views_dir = save_dir_.joinpath('gt_views')
    gt_views_dir.mkdir(exist_ok=True)
    for t in range(view_point_num):
        v, u = divmod(t, ang_res)
        view_file = "%02d_%02d.png" % (v, u)
        Image.fromarray(gt_light_field[t, ..., 0]).save(str(gt_views_dir.joinpath(view_file)))
        Image.fromarray(re_light_field[t, ..., 0]).save(str(re_views_dir.joinpath(view_file)))


def _save_acquired_data(ac_event_stacks, ac_image_frame, acquired_img_dir):
    """Save the simulated event stacks (visualisation + .npy) and image frame."""
    for t, stack in enumerate(ac_event_stacks):
        counts = (stack * _EVENT_STACK_SCALE).astype(np.int8)
        # NOTE(review): if vis_event returns a matplotlib Figure it is never
        # closed here, so figures accumulate across samples — confirm and close.
        event_stacks_save = vis_event(counts, event_max_count=4, cbar=False)
        event_stacks_save.savefig(str(acquired_img_dir) + "/event_stack_vis_%05d.png" % (t), bbox_inches='tight', pad_inches=0)
        np.save(str(acquired_img_dir) + "/event_stack_%05d" % (t), counts)

    frame = np.clip(ac_image_frame[0, ..., 0], 0.0, 1.0)
    Image.fromarray(np.uint8(frame * 255.0)).save(str(acquired_img_dir) + "/image_frame.png")


def test(test_loader, device, aperture_mask, reconstruct, esim, save_dir, args, idx_epoch, save_image=True):
    """Reconstruct every light field in ``test_loader`` and score it.

    For each sample the capture is simulated (noisy image frame + event
    stacks), ``reconstruct`` predicts the light field, both light fields are
    quantised to uint8, and PSNR/SSIM are computed with ``cal_metrics``.

    Args:
        test_loader: Iterable of ``(data, LF_name)`` with ``data`` of shape
            ``(B, V, H, W)``. Only batch size 1 is supported: the simulation
            buffers are allocated with batch dimension 1.
        device: Torch device the models live on.
        aperture_mask, reconstruct, esim: Mask, reconstruction network and
            event-camera simulator modules.
        save_dir: ``pathlib.Path`` under which per-sample outputs are written.
        args: Options; ``aperture_num`` and ``image_noise`` are read here.
        idx_epoch: Forwarded to ``aperture_mask``.
        save_image: Whether to write views, heatmaps and acquired data.

    Returns:
        ``(psnr_list, ssim_list, name_list)``, one entry per sample.
    """
    LF_iter_test = []
    psnr_iter_test = []
    ssim_iter_test = []

    for data, LF_name in tqdm(test_loader, total=len(test_loader), ncols=70):
        lf_name = LF_name[0]
        data = data.to(device)
        batch, view_point_num, H, W = data.shape
        data = data.reshape(batch, view_point_num, H, W, 1)
        ac_event_stacks = np.zeros([args.aperture_num - 1, H, W, 1], np.float32)
        ac_image_frame = np.zeros([1, H, W, 1], np.float32)
        re_data = np.zeros([view_point_num, H, W, 1], np.float32)

        with torch.no_grad():
            event_stacks, noise_image_frame = _simulate_acquisition(
                data, aperture_mask, esim, args, idx_epoch, device)
            net_input = torch.cat([event_stacks, noise_image_frame], dim=1)
            tmp_light_field = reconstruct(net_input)

            # Only the first (and only) sample of the batch is kept.
            ac_event_stacks[..., 0] = event_stacks[0].to("cpu").numpy()
            ac_image_frame[..., 0] = noise_image_frame[0].to("cpu").numpy()
            re_data[..., 0] = tmp_light_field[0].to("cpu").numpy()
            np.clip(re_data, 0.0, 1.0, out=re_data)

        # Quantise both light fields to 8 bit (truncating). GT is clipped too so
        # out-of-range values cannot wrap around in the uint8 cast.
        gt_data = data[0].to("cpu").numpy()
        gt_light_field = np.uint8(np.clip(gt_data, 0.0, 1.0) * 255.0)
        re_light_field = np.uint8(re_data * 255.0)

        if save_image:
            save_dir_ = save_dir.joinpath(lf_name)
            save_dir_.mkdir(exist_ok=True)
            _save_views(gt_light_field, re_light_field, save_dir_)

        # PSNR is derived from the averaged MSE rather than by averaging the
        # psnr values returned by cal_metrics.
        psnr, ssim, mse = cal_metrics(gt_light_field, re_light_field)
        ave_ssim = np.average(ssim)
        ave_mse = np.average(mse)
        ave_psnr = 10 * np.log10(255.0 ** 2 / ave_mse)
        print("\n Data %s \tpsnr : %f \tssim : %f" % (lf_name, ave_psnr, ave_ssim))

        psnr_iter_test.append(ave_psnr)
        ssim_iter_test.append(ave_ssim)
        LF_iter_test.append(lf_name)

        if save_image:
            create_heatmap(str(save_dir_.joinpath("%s_heatmap" % lf_name)), psnr, vmax=38, vmin=23)

            acquired_img_dir = save_dir_.joinpath('acquired_data')
            acquired_img_dir.mkdir(exist_ok=True)
            _save_acquired_data(ac_event_stacks, ac_image_frame, acquired_img_dir)

    return psnr_iter_test, ssim_iter_test, LF_iter_test


if __name__ == '__main__':
    # Command-line options are parsed by the project's `option` module on import.
    import option

    main(option.args)
