import numpy as np
import processingIO as pio
import matplotlib.pyplot as plt
import pandas as pd

from pathlib import Path

## Functions that compute the lateral velocity h_dot(x, t) along the body and the resulting lateral power
def carangiform_power(x, t, Fy, st, length, waveNumber, amplitude):
    """Return the magnitude of the lateral power for a travelling-wave body.

    The lateral velocity at each station is

        h_dot = -omega * A(xi) * cos(omega*t - waveNumber*xi) * length

    where xi = x / length, omega = 2*pi*st / (2*amplitude[0]) and the
    amplitude envelope is

        A(xi) = amplitude[0] * (1 + (xi - 1)*amplitude[1] + (xi**2 - 1)*amplitude[2]).

    Parameters
    ----------
    x : ndarray
        Station positions; divided by ``length`` to get body-length fractions.
    t : float
        Time instant at which the velocity is evaluated.
    Fy : array_like
        Lateral force at each station (same length/order as ``x``).
    st : float
        Frequency parameter; presumably a Strouhal number based on the
        peak-to-peak amplitude 2*amplitude[0] at unit speed — confirm.
    length : float
        Body length used for normalisation and to dimensionalise h_dot.
    waveNumber : float
        Wave number multiplying xi in the phase.
    amplitude : sequence of float
        Envelope coefficients (a0, a1, a2) as used in A(xi) above.

    Returns
    -------
    float
        ``|sum(h_dot * Fy)|``.
    """
    a0, a1, a2 = amplitude[0], amplitude[1], amplitude[2]
    omega = 2*np.pi*st*1.0/(2.0*a0)

    xi = x / length
    envelope = a0 * (1 + (xi - 1) * a1 + (np.power(xi, 2) - 1) * a2)
    phase = t*omega - waveNumber*xi

    lateral_velocity = (-omega * envelope) * np.cos(phase) * length
    return np.abs(np.multiply(lateral_velocity, Fy).sum())

def ostraciiform_power(x, t, Fy, st, length, waveNumber, amplitude, pivot, maxAngle, phaseAngle):
    """Return the magnitude of the lateral power for a body with a pivoting tail.

    Kinematics as implemented (xi = x / length,
    omega = 2*pi*st / (2*amplitude[0])):

    * The head (xi = 0) and pivot (xi = pivot) velocities come from the
      quadratic envelope
      A(xi) = amplitude[0]*(1 + amplitude[1]*(xi-1) + amplitude[2]*(xi**2-1))
      as ``-omega * A(xi) * cos(omega*t - waveNumber*xi) * length``.
    * Stations with xi < pivot get a velocity linearly interpolated between
      the head and pivot velocities.
    * Stations with xi >= pivot get the pivot velocity plus
      ``xi * length * first_term * second_term``, where
      first_term * second_term = -d/dt[tan(theta)] for
      theta(t) = radians(maxAngle) * sin(omega*t - waveNumber*pivot - radians(phaseAngle)).

    Parameters
    ----------
    x : ndarray
        Station positions. Assumed sorted ascending, so that the concatenated
        body-then-tail velocities line up with ``Fy`` — TODO confirm for
        all callers.
    t : float
        Time instant.
    Fy : array_like
        Lateral force at each station.
    st : float
        Frequency parameter (see omega above).
    length : float
        Body length.
    waveNumber : float
        Wave number used in the head/pivot phase.
    amplitude : sequence of float
        Envelope coefficients (a0, a1, a2).
    pivot : float
        Pivot location as a fraction of ``length``; must be non-zero.
    maxAngle : float
        Tail angle amplitude in degrees.
    phaseAngle : float
        Tail phase offset in degrees.

    Returns
    -------
    float
        ``|sum(h_dot * Fy)|``.
    """
    omega = 2*np.pi*st*1.0/(2.0*amplitude[0])

    x_bl = x / length

    # Split stations into the part ahead of the pivot and the tail.
    is_tail = x_bl >= pivot
    x_tail = x_bl[is_tail]
    x_body = x_bl[~is_tail]

    # Envelope velocity at the head (xi = 0) and at the pivot (xi = pivot).
    h_dot_head = amplitude[0] * omega * (amplitude[1] + amplitude[2] - 1) * np.cos(omega*t) * length
    h_dot_foil = amplitude[0] * -omega * (amplitude[1]*(pivot-1) + amplitude[2]*(np.power(pivot, 2) - 1) + 1) * np.cos(omega * t - waveNumber * pivot) * length

    # Linear interpolation between the head (xi = 0) and pivot velocities.
    h_dot_body = np.multiply((h_dot_foil - h_dot_head)/(pivot - 0), x_body) + h_dot_head

    # Tail: pivot velocity plus the rotational contribution.
    theta_max = np.radians(maxAngle)
    phi = np.radians(phaseAngle)

    first_term = -theta_max * omega * np.cos(omega*t - waveNumber*pivot - phi)
    second_term = np.power(np.cos(theta_max * np.sin(omega*t - waveNumber*pivot - phi)), -2)

    # NOTE(review): the lever arm is x_tail (distance from xi = 0), not
    # (x_tail - pivot), so h_dot_tail at xi = pivot does not equal h_dot_foil
    # unless first_term is zero. Confirm this matches the imposed tail motion.
    h_dot_tail = h_dot_foil + np.multiply(x_tail, np.multiply(first_term, second_term)) * length

    h_dot = np.append(h_dot_body, h_dot_tail)

    return np.abs(np.sum(np.multiply(h_dot, Fy)))

def thunniform_power(x, t, Fy, st, length, waveNumber, amplitude, pivot, maxAngle, phaseAngle):
    """Return the magnitude of the lateral power for a travelling-wave body with a pivoting tail.

    With xi = x / length and omega = 2*pi*st / (2*amplitude[0]):

    * Stations with xi < pivot move with the travelling-wave velocity
      ``-omega * length * A(xi) * cos(waveNumber*xi - omega*t)``, where
      A(xi) = amplitude[0]*(1 + (xi-1)*amplitude[1] + (xi**2-1)*amplitude[2]).
    * Stations with xi >= pivot get the envelope velocity at the pivot plus
      ``xi * length * (-d/dt[tan(theta)])`` with
      theta(t) = radians(maxAngle) * sin(omega*t - waveNumber*pivot - radians(phaseAngle)).

    ``x`` is assumed sorted ascending so the body-then-tail ordering of the
    velocities matches ``Fy`` — TODO confirm for all callers.

    Returns ``|sum(h_dot * Fy)|``.
    """
    omega = 2*np.pi*st*1.0/(2.0*amplitude[0])
    xi = x/length

    tail_mask = xi >= pivot
    body_xi = xi[~tail_mask]
    tail_xi = xi[tail_mask]

    # Body: travelling wave under the quadratic amplitude envelope.
    envelope = amplitude[0] * (1 + (body_xi - 1) * amplitude[1] + (np.power(body_xi, 2) - 1) * amplitude[2])
    body_velocity = (-omega * length) * (envelope * np.cos(waveNumber*body_xi - omega*t))

    # Envelope velocity evaluated at the pivot point.
    pivot_velocity = amplitude[0] * -omega * (amplitude[1]*(pivot-1) + amplitude[2]*(np.power(pivot, 2) - 1) + 1) * np.cos(omega * t - waveNumber * pivot) * length

    # Tail: pivot velocity plus the rotational term.
    theta_max = np.radians(maxAngle)
    tail_phase = omega*t - waveNumber*pivot - np.radians(phaseAngle)
    rotation_rate = (-theta_max * omega * np.cos(tail_phase)) * np.power(np.cos(theta_max * np.sin(tail_phase)), -2)
    tail_velocity = pivot_velocity + (tail_xi * rotation_rate) * length

    velocity = np.concatenate([np.ravel(body_velocity), np.ravel(tail_velocity)])
    return np.abs(np.multiply(velocity, Fy).sum())

def uniform_sampling(dx, x, y):
    """Linearly interpolate the samples ``y`` (taken at ``x``) onto the points ``dx``.

    ``y`` is flattened to 1-D first, so a column vector is accepted.
    Values of ``dx`` outside the range of ``x`` take the end values (np.interp).
    """
    flat_y = y.reshape(-1)
    return np.interp(dx, x, flat_y)

#-------- Start of code --------------#
force_flag = 1
forceBin_flag = 1

# Time-averaged forces for the case with the airfoil not moving,
# keyed by the Reynolds-number tag taken from each file name.
no_move_location = Path(r'./no_movements/Output data')
no_move_force_paths = pio.get_files(no_move_location.joinpath('forces'), '*.csv')
no_move_force_paths.sort()
no_move_force_dfs = [pd.read_csv(path) for path in no_move_force_paths]

# Only samples with time >= cutoff enter the average (presumably to skip the
# start-up transient). Previously this cutoff was computed but never applied.
# A run shorter than the cutoff yields NaN means.
cutoff = 4.0
no_move_forces = ['total_x', 'pressure_x', 'viscous_x', 'CD', 'CDf', 'CDp']

no_move_dict = {}

for no_move_path, no_move_df in zip(no_move_force_paths, no_move_force_dfs):
    # Characters [2:6] of the file name are used as the Reynolds-number key
    # (assumes a fixed file-naming pattern — TODO confirm).
    Re_case = no_move_path.name[2:6]
    after_cutoff = no_move_df['time'].values >= cutoff
    no_move_dict[Re_case] = {
        force: np.mean(no_move_df[force].values[after_cutoff])
        for force in no_move_forces
    }

# Raw-data directory of every locomotion-mode case to process.
data_locations = [
    Path(r'./anguilliform_base/raw_data'),
    Path(r'./carangiform_base/raw_data'),
    Path(r'./ostraciiform_base/raw_data'),
    Path(r'./thunniform_base/raw_data'),
]

# ---- Constants shared by all simulations ----

# Location along the airfoil of each force-bin measurement in the .csv files
# (values lie in (0, 1); chord and length are both 1.0 below).
x_coords = np.array([
    0.0489, 0.0976, 0.1463, 0.1949, 0.2436,
    0.2923, 0.3409, 0.3896, 0.4383, 0.487,
    0.5356, 0.5843, 0.633, 0.6816, 0.7303,
    0.779, 0.8276, 0.8763, 0.925, 0.9736,
])

density = 1026.021   # fluid density (units not stated; presumably kg/m^3)
U = 1.0              # reference speed used in the power / CP calculation
chord = 1.0
waveNumber = 6.28
length = 1.0
pivot = 0.85         # tail pivot, compared against x / length
maxAngle = 16        # tail angle amplitude in degrees (np.radians applied later)
phaseAngle = -90     # tail phase offset in degrees (np.radians applied later)

def _resample_uniform(df, n_samples=9000):
    """Resample a time-history DataFrame onto a uniform time grid.

    The grid has ``n_samples`` points spanning the first to the last value of
    ``df['time']``. Every column after the first (the first is assumed to be
    'time') is linearly interpolated onto the grid with uniform_sampling.

    Returns a new DataFrame with a 'time' column (the uniform grid) followed
    by the interpolated columns in their original order.
    """
    time = df['time'].values
    time = np.reshape(time, time.size)
    grid = np.linspace(time[0], time[-1], n_samples)

    resampled = {'time': grid}
    for column in df.columns[1:]:
        resampled[column] = uniform_sampling(grid, time, df[column].values)
    return pd.DataFrame(data=resampled)


# For each locomotion mode (directory name without its '_base' suffix):
# the amplitude-envelope coefficients, the lateral-power function, and any
# extra trailing arguments that function needs. anguilliform reuses the
# travelling-wave velocity of carangiform_power with its own coefficients.
_LOCOMOTION_MODES = {
    'anguilliform': (np.array([0.1, 0.323, 0.310]), carangiform_power, ()),
    'carangiform': (np.array([0.1, -0.825, 1.625]), carangiform_power, ()),
    'ostraciiform': (np.array([0.1, -0.825, 1.625]), ostraciiform_power, (pivot, maxAngle, phaseAngle)),
    'thunniform': (np.array([0.1, -0.825, 1.625]), thunniform_power, (pivot, maxAngle, phaseAngle)),
}

# Normalisation for the power coefficient: CP = P / (0.5 * density * U**3 * chord).
cp_denominator = 0.5*density*U**3*chord

for data_location in data_locations:
    save_location = data_location.parent.joinpath('processed_data')
    save_location.mkdir(parents=True, exist_ok=True)

    # e.g. './carangiform_base/raw_data' -> 'carangiform'
    base_case = data_location.parent.name[:-len('_base')]
    if forceBin_flag == 1 and base_case not in _LOCOMOTION_MODES:
        raise ValueError(f"unknown locomotion mode {base_case!r} for {data_location}")

    if force_flag == 1:
        force_paths = sorted(pio.get_files(data_location.joinpath('forces'), '*.csv'))
        force_dfs = [pd.read_csv(path) for path in force_paths]

    if forceBin_flag == 1:
        forceBin_paths_top = sorted(pio.get_files(data_location.joinpath('force_bins', 'wing_top'), '*.csv'))
        forceBin_dfs_top = [pd.read_csv(path, skiprows=5) for path in forceBin_paths_top]

        forceBin_paths_bottom = sorted(pio.get_files(data_location.joinpath('force_bins', 'wing_bottom'), '*.csv'))
        # Read the bottom files (this previously re-read the top files by mistake).
        forceBin_dfs_bottom = [pd.read_csv(path, skiprows=5) for path in forceBin_paths_bottom]

    if force_flag == 1:
        for case_index, (force_path, force_df) in enumerate(zip(force_paths, force_dfs)):
            # The last four characters before '.csv' hold the St value with
            # '_' as decimal separator; [2:6] holds the Reynolds-number tag.
            st_raw = force_path.name[-8:-4]
            st = float(st_raw.replace('_', '.'))
            Re_case = force_path.name[2:6]

            out_df = _resample_uniform(force_df)

            if forceBin_flag == 1:
                amplitude, power_fn, extra_args = _LOCOMOTION_MODES[base_case]

                # Bin files are paired with force files by sorted position.
                out_bin_top_df = _resample_uniform(forceBin_dfs_top[case_index])
                out_bin_bottom_df = _resample_uniform(forceBin_dfs_bottom[case_index])

                n_rows = out_bin_top_df['time'].shape[0]
                plat = np.zeros((n_rows, 2))
                ptotal = np.zeros((n_rows, 2))

                # Column 0 = top surface, column 1 = bottom surface.
                # The total_x / total_y columns of each bin file are assumed
                # to match x_coords one-to-one — TODO confirm.
                for side, bin_df in enumerate((out_bin_top_df, out_bin_bottom_df)):
                    total_x = bin_df.filter(like='total_x').values
                    total_y = bin_df.filter(like='total_y').values
                    for k, t in enumerate(bin_df['time'].values):
                        plat[k, side] = power_fn(x_coords, t, total_y[k, :], st, length, waveNumber, amplitude, *extra_args)
                        ptotal[k, side] = np.sum(np.multiply(total_x[k], U)) + plat[k, side]

                CP = ptotal / cp_denominator

                out_df.insert(loc=len(out_df.columns), column='Plat', value=plat[:, 0] + plat[:, 1])
                out_df.insert(loc=len(out_df.columns), column='Ptotal', value=ptotal[:, 0] + ptotal[:, 1])
                out_df.insert(loc=len(out_df.columns), column='CP', value=CP[:, 0] + CP[:, 1])

            save_file = 'Re-' + Re_case + '_' + 'St-' + st_raw + '.csv'
            out_df.to_csv(save_location.joinpath(save_file), index=False)