from skimage.metrics import structural_similarity, peak_signal_noise_ratio, mean_squared_error
from pathlib import Path
import os
import logging
import xlwt
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from option import args


class ExcelFile():
    """Collect per-scene PSNR/SSIM results into a single-sheet xlwt workbook.

    Row 0 holds the column headers. ``self.sum`` is the index of the next row
    to be written; it starts at 1, just below the headers.
    """

    # (header text, column width in characters), left to right.
    _COLUMNS = (('Datasets', 16), ('Scenes', 22), ('PSNR', 10), ('SSIM', 10))

    def __init__(self):
        self.xlsx_file = xlwt.Workbook()
        self.worksheet = self.xlsx_file.add_sheet(r'sheet1', cell_overwrite_ok=True)
        for col, (header, n_chars) in enumerate(self._COLUMNS):
            self.worksheet.write(0, col, header)
            # xlwt widths are expressed in 1/256ths of a character.
            self.worksheet.col(col).width = 256 * n_chars
        self.sum = 1

    def write_sheet(self, test_name, LF_name, psnr_iter_test, ssim_iter_test):
        """Write one row per scene, then an 'average' row, then skip a row.

        ``LF_name``, ``psnr_iter_test`` and ``ssim_iter_test`` are indexed in
        parallel; the average row holds the mean of each metric list.
        """
        for idx, psnr_value in enumerate(psnr_iter_test):
            self.add_sheet(test_name, LF_name[idx], psnr_value, ssim_iter_test[idx])

        mean_psnr = float(np.array(psnr_iter_test).mean())
        mean_ssim = float(np.array(ssim_iter_test).mean())
        self.add_sheet(test_name, 'average', mean_psnr, mean_ssim)
        # Leave an empty row between consecutive datasets.
        self.sum += 1

    def add_sheet(self, test_name, LF_name, psnr_iter_test, ssim_iter_test):
        """Write a single result row at ``self.sum`` and advance to the next row.

        Metrics are stored as strings formatted with six decimal places.
        """
        row = self.sum
        self.worksheet.write(row, 0, test_name)
        self.worksheet.write(row, 1, LF_name)
        for col, metric in ((2, psnr_iter_test), (3, ssim_iter_test)):
            self.worksheet.write(row, col, '%.6f' % metric)
        self.sum = row + 1


def get_logger(log_dir, args):
    """Attach an INFO-level file handler for ``<log_dir>/log.txt`` to the root logger.

    Calling this repeatedly for the same directory is safe: an existing
    ``FileHandler`` for that file is reused instead of adding a duplicate,
    which would otherwise write every record to the file more than once.

    Args:
        log_dir: directory that must already exist; ``log.txt`` is created or
            appended to inside it.
        args: unused; kept so existing call sites keep working.

    Returns:
        The root ``logging.Logger`` (level set to INFO).
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # FileHandler stores abspath(filename) as baseFilename, so compare the same form.
    log_path = os.path.abspath(os.path.join(str(log_dir), 'log.txt'))
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path:
            return logger

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger


def create_dir(args):
    """Create the run's output directory tree and return its paths.

    ``args.path_save`` is created together with any missing parent
    directories, then ``model``, ``mask`` and ``test`` sub-directories are
    created inside it. Existing directories are left untouched.

    Args:
        args: object with a ``path_save`` attribute (str or path-like).

    Returns:
        Tuple ``(save_dir, model_dir, mask_dir, val_dir)`` of ``Path`` objects.
    """
    save_dir = Path(args.path_save)
    # parents=True so a nested path_save (e.g. 'log/exp1') works on a fresh checkout.
    save_dir.mkdir(parents=True, exist_ok=True)

    model_dir = save_dir / 'model'
    model_dir.mkdir(exist_ok=True)

    mask_dir = save_dir / 'mask'
    mask_dir.mkdir(exist_ok=True)

    val_dir = save_dir / 'test'
    val_dir.mkdir(exist_ok=True)

    return save_dir, model_dir, mask_dir, val_dir


class Logger():
    """Mirror messages both to the logger returned by ``get_logger`` and to stdout."""

    def __init__(self, log_dir, args):
        self.logger = get_logger(log_dir, args)

    def log_string(self, str):
        """Log ``str`` at INFO level, then print it."""
        # The parameter name shadows the builtin; it is kept for keyword callers.
        message = str
        self.logger.info(message)
        print(message)


def cal_metrics(gt_LF, reconstructed_LF, data_range=None):
    """Compute per-view PSNR, SSIM and MSE between ground-truth and reconstructed views.

    Args:
        gt_LF: ground-truth views stacked along axis 0, shaped ``(views, H, W)``
            or ``(views, H, W, C)``.
        reconstructed_LF: reconstructed views with the same shape as ``gt_LF``.
        data_range: optional intensity range forwarded to skimage's PSNR and
            SSIM. When ``None`` (the default) skimage infers it; recent skimage
            releases refuse to infer it for floating-point SSIM inputs, so pass
            it explicitly for float data.

    Returns:
        Tuple ``(psnr, ssim, mse)`` of float32 arrays with one entry per view.

    Raises:
        ValueError: if ``gt_LF`` is neither 3-D nor 4-D.
    """
    if gt_LF.ndim not in (3, 4):
        raise ValueError(
            'expected gt_LF of shape (views, H, W) or (views, H, W, C), got %r' % (gt_LF.shape,))

    view_point_num = gt_LF.shape[0]
    # Only forward data_range when given, so the default matches skimage's own inference.
    metric_kwargs = {} if data_range is None else {'data_range': data_range}

    psnr = np.zeros(view_point_num, np.float32)
    ssim = np.zeros(view_point_num, np.float32)
    mse = np.zeros(view_point_num, np.float32)

    for t in range(view_point_num):
        gt_view = gt_LF[t]
        rec_view = reconstructed_LF[t]

        # skimage's signature is (image_true, image_test); the inferred data_range is
        # derived from the first argument, so ground truth must come first.
        psnr[t] = peak_signal_noise_ratio(gt_view, rec_view, **metric_kwargs)
        mse[t] = mean_squared_error(gt_view, rec_view)

        if gt_view.ndim == 2:
            # (views, H, W) input: each view is already a 2-D single-channel image.
            ssim[t] = structural_similarity(gt_view, rec_view, **metric_kwargs)
        else:
            # Mean of per-channel SSIM: the quantity skimage computes in its
            # multichannel mode, written out so it works whether the installed
            # skimage expects the deprecated ``multichannel`` keyword (pre-0.19)
            # or ``channel_axis`` (0.19+). For C == 1 this is plain 2-D SSIM.
            per_channel = [
                structural_similarity(gt_view[..., c], rec_view[..., c], **metric_kwargs)
                for c in range(gt_view.shape[-1])
            ]
            ssim[t] = np.mean(per_channel)

    return psnr, ssim, mse

def create_heatmap(filename, psnr_data, vmax=45.0, vmin=25.0, cbar=False, annot=False):
    """Save a square heatmap of per-view values (e.g. PSNR) to ``filename``.

    Side effect: resets the global matplotlib style to 'default' and applies
    seaborn's default theme with the 'whitegrid' style.

    Args:
        filename: output path passed to ``Figure.savefig``.
        psnr_data: N*N values, either flat (length N*N) or already shaped (N, N).
        vmax, vmin: colour-scale limits.
        cbar: draw a colour bar when true.
        annot: write each cell's value (format '.6f') when true.

    Raises:
        ValueError: if ``psnr_data`` is empty or its size is not a perfect square.
    """
    values = np.asarray(psnr_data)
    ag_size = int(round(np.sqrt(values.size)))
    if values.size == 0 or ag_size * ag_size != values.size:
        raise ValueError(
            'psnr_data must hold a non-zero perfect-square number of values, got %d' % values.size)
    grid = np.reshape(values, (ag_size, ag_size))

    plt.style.use('default')
    sns.set()
    sns.set_style('whitegrid')

    # Draw on a dedicated figure (created after the style change so it picks it
    # up) so we never paint onto, and then close, a figure the caller has open.
    fig, ax = plt.subplots()
    try:
        sns.heatmap(pd.DataFrame(grid), vmax=vmax, vmin=vmin, cmap="jet", fmt='.6f',
                    xticklabels=False, yticklabels=False, cbar=cbar, square=True,
                    annot=annot, ax=ax)
        fig.savefig(filename, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def rgb2ycbcr(x):
    """Convert an (H, W, >=3) RGB array to YCbCr, scaled by 1/255.

    The luma weights are 219 * (0.299, 0.587, 0.114) plus an offset of 16,
    so RGB presumably in [0, 1] maps Y to [16/255, 235/255] (assumed input
    range; not checked). Any channels beyond the third are returned as zeros.
    The result is always float64.
    """
    # One row per output channel: weights for R, G, B, then the additive offset.
    coeffs = (
        (65.481, 128.553, 24.966, 16.0),
        (-37.797, -74.203, 112.000, 128.0),
        (112.000, -93.786, -18.214, 128.0),
    )
    out = np.zeros(x.shape, dtype='double')
    for ch, (w_r, w_g, w_b, bias) in enumerate(coeffs):
        out[:, :, ch] = w_r * x[:, :, 0] + w_g * x[:, :, 1] + w_b * x[:, :, 2] + bias
    return out / 255.0


def ycbcr2rgb(x):
    """Convert an (H, W, >=3) YCbCr array (scaled by 1/255) back to RGB.

    Uses the inverse of the same coefficient matrix and offsets (16, 128, 128)
    as ``rgb2ycbcr``, so it undoes that conversion. Channels beyond the third
    are returned as zeros; the result is always float64.
    """
    forward = np.array(
        [[65.481, 128.553, 24.966],
         [-37.797, -74.203, 112.0],
         [112.0, -93.786, -18.214]])
    inverse = np.linalg.inv(forward)
    # The offset must come from the *unscaled* inverse, so compute it before
    # folding the 255 input scaling into the matrix.
    shift = inverse @ np.array([16, 128, 128])
    inverse = inverse * 255

    rgb = np.zeros(x.shape, dtype='double')
    for ch in range(3):
        row = inverse[ch]
        rgb[:, :, ch] = row[0] * x[:, :, 0] + row[1] * x[:, :, 1] + row[2] * x[:, :, 2] - shift[ch]
    return rgb


def vis_event(event_frame, event_max_count, cbar=False):
    """Render an (H, W, C) event frame as a borderless ``bwr`` image and return the figure.

    Values are normalised symmetrically to [-event_max_count, event_max_count],
    so negative values render blue and positive values red (presumably signed
    per-pixel event counts — confirm with callers).

    Args:
        event_frame: array whose shape unpacks as (H, W, C).
        event_max_count: magnitude used for both colour-scale limits.
        cbar: add a vertical colour bar when true.

    Returns:
        The ``matplotlib.figure.Figure``, already detached from pyplot's figure
        manager but still usable (e.g. ``fig.savefig``).
    """
    H, W, _ = event_frame.shape
    dpi = 100

    # Pin the figure dpi so (figsize in inches) * dpi is exactly W x H pixels,
    # regardless of rcParams['figure.dpi'].
    fig = plt.figure(figsize=(W / dpi, H / dpi), dpi=dpi)
    ax = fig.add_subplot(111)
    ax.axis("off")
    img = ax.imshow(event_frame, cmap="bwr",
                    norm=Normalize(vmin=-event_max_count, vmax=event_max_count))
    if cbar:
        fig.colorbar(img, ax=ax, orientation="vertical")
    # Let the image fill the whole canvas (no margins).
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return fig