import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from tqdm import tqdm

from model import *
from utils.utils import *


def _load_acquired_sample(scene_dir, aperture_num):
    """Load one acquired scene as a single network input array.

    Reads ``image_frame.png`` and the ``aperture_num - 1`` files
    ``event_stack_00000.npy``, ``event_stack_00001.npy``, ... from
    *scene_dir* and concatenates them along axis 1, image frame last.

    Args:
        scene_dir: Directory holding the files of one acquired scene.
        aperture_num: Total number of input channels; ``aperture_num - 1``
            event stacks are read.

    Returns:
        float32 array of shape ``(1, aperture_num, H, W)``, where ``H, W``
        is the size of the image frame.
    """
    # Use ``with`` so the file handle PIL opens lazily is always closed.
    with Image.open(os.path.join(scene_dir, "image_frame.png")) as img:
        image_frame = np.array(img).astype(np.float32) / 255.0
    # Two-value unpack: the frame is expected to be single-channel.
    H, W = image_frame.shape
    image_frame = image_frame.reshape(1, 1, H, W)

    event_stacks = np.zeros((1, aperture_num - 1, H, W), dtype=np.float32)
    for t in range(aperture_num - 1):
        stack = np.load(os.path.join(scene_dir, "event_stack_%05d.npy" % t))
        # Division by 8 is kept from the original pipeline; presumably it
        # rescales event counts to the range seen in training -- confirm.
        event_stacks[0, t] = stack.reshape(H, W).astype(np.float32) / 8

    return np.concatenate([event_stacks, image_frame], axis=1)


def main(args):
    """Reconstruct a light field for every acquired scene and save the views.

    Every sub-directory of ``args.path_for_acquired_data`` is treated as one
    scene. Its inputs are loaded with ``_load_acquired_sample`` and passed to
    ``test``, which writes the views under
    ``<result_dir>/from_acquired_data/<scene>``.

    Args:
        args: Parsed options. This function reads ``device``, ``angRes``,
            ``aperture_num``, ``use_pre_ckpt_for_test``, ``epoch`` and
            ``path_for_acquired_data``, and passes ``args`` on to
            ``create_dir`` and ``test``.
    """
    # Create output directories.
    _, model_dir, _, result_dir = create_dir(args)
    result_dir = result_dir.joinpath('from_acquired_data')
    result_dir.mkdir(exist_ok=True)

    # Select CPU or CUDA. ``torch.cuda.set_device`` needs an explicit index,
    # so a bare "cuda" keeps the current default device instead of raising.
    device = torch.device(args.device)
    if device.type == 'cuda' and device.index is not None:
        torch.cuda.set_device(device)

    # Build the reconstruction network.
    view_point_num = args.angRes * args.angRes
    reconstruct = Reconstruct(view_point_num, args.aperture_num, 64).to(device)

    # Load pre-trained weights.
    if not args.use_pre_ckpt_for_test:
        # Unimplemented: the network keeps its freshly initialised weights.
        print('Warning: no pre-trained checkpoint loaded; '
              'testing with untrained weights.')
    else:
        ckpt_path_reconstruct = model_dir.joinpath(
            'epoch%03d' % (args.epoch) + '_reconstruct.pth')
        # map_location lets a checkpoint saved on GPU load on a CPU-only run.
        state_dict = torch.load(ckpt_path_reconstruct, map_location=device)
        reconstruct.load_state_dict(state_dict)
        print('Use pretrain model!')

    # Inference only: switch train/eval-dependent layers (e.g. dropout,
    # batch norm), if the model has any, to their evaluation behaviour.
    reconstruct.eval()

    cudnn.benchmark = True

    # Print parameters.
    print('PARAMETER ...')
    print(args)

    # Run the test on every acquired scene.
    print('\nStart test...')
    data_root = args.path_for_acquired_data
    with torch.no_grad():
        print('\nLoad Test Dataset ...')
        # Sorted for a deterministic processing order across runs.
        scene_names = sorted(
            name for name in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, name))
        )
        for scene_name in tqdm(scene_names):
            save_dir = result_dir.joinpath(scene_name)
            save_dir.mkdir(exist_ok=True)

            sample = _load_acquired_sample(
                os.path.join(data_root, scene_name), args.aperture_num)
            test_input = torch.from_numpy(sample).to(device)

            test(test_input, device, reconstruct, save_dir, args)


def test(test_input, device, reconstruct, save_dir, args, save_image=True):

    view_point_num = args.angRes*args.angRes
    _, _, H, W = test_input.shape

    with torch.no_grad():
        re_data = np.zeros([view_point_num, H, W, 1], np.float32)

        tmp_light_field = reconstruct(test_input)

        re_data[..., 0] = tmp_light_field[0].to("cpu").numpy()
        re_data[re_data < 0.0]  = 0.0
        re_data[re_data > 1.0]  = 1.0

        ''' Save Reconstructed LF '''
        if save_image is True:
            re_light_field = np.zeros([view_point_num, H, W, 1], np.uint8)
            for t in range(view_point_num):
                re_light_field[t] = np.uint8(re_data[t]*255.0)
                img_save_re = Image.fromarray(re_light_field[t,...,0])
                u = int(t%np.sqrt(view_point_num))
                v = int(t/np.sqrt(view_point_num))
                img_save_re.save(str(save_dir) + "/%02d_%02d.png" % (v, u))

    return 


if __name__ == '__main__':
    # Import the options only when executed as a script, so importing this
    # module has no side effects. ``option.args`` presumably holds the parsed
    # command-line arguments -- confirm in option.py.
    from option import args as parsed_args

    main(parsed_args)
