"""
Solving reaction-diffusion metabolic model in a Rectangular domain
Reference Farina et al. DOI: https://doi.org/10.1101/2022.07.21.500921

The setting is shown in Figure 3 and Figure 4

10 reaction sites are uniformly distributed for each of the four metabolic pathways:
glycolysis (HXK and PYRK), lactate dehydrogenase (LDH) and mitochondrial activity (Mito).


"""


from matplotlib import pyplot as plt
from mshr import *
from dolfin import *
import numpy as np
from timeit import default_timer as timer
import sys

from parameters_experiment_2 import parameters_adp_atp_sum_dep

# Model parameters for this run, returned by the project-local helper for the
# argument 3.2 (what that argument means is defined in parameters_experiment_2
# and cannot be seen from this file).
# Further down, b_0 and c_0 are used as initial concentrations, the value_*
# entries as reaction-rate magnitudes, f_glc to scale the GLC influx and
# lac_degr to scale the LAC degradation.
b_0, c_0, value_hxk, value_pyrk, value_ldh, value_Mito, value_act, f_glc, lac_degr = parameters_adp_atp_sum_dep(3.2)

# Extents of the rectangular domain: it spans [0, l] x [0, L]
# (the mesh is built from these two values further down).
l = 4.
L=140


def uniform_point_circle(N, distribution, x_0 = l/2, y_0 = L/2):
    """Draw N random points strictly inside the rectangle (0, l) x (0, L).

    Parameters
    ----------
    N : int
        Number of points to draw.
    distribution : str
        Sampling law; only ``'uniform'`` is supported.
    x_0, y_0 : float, optional
        Accepted for backward compatibility; they are not used by the
        sampling.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2, N)``: row 0 holds the x coordinates and row 1
        the y coordinates.

    Raises
    ------
    ValueError
        If ``distribution`` is not ``'uniform'``.
    """
    # Fail early with a clear error instead of printing and then crashing on
    # an undefined variable.
    if distribution != 'uniform':
        raise ValueError(
            "Only acceptable distribution is 'uniform', got %r" % (distribution,)
        )

    while True:
        # x is always drawn before y so the random stream is consumed in a
        # fixed order.
        x_coordinate = np.random.uniform(low=0, high=l, size=N)
        y_coordinate = np.random.uniform(low=0, high=L, size=N)
        coordinat = np.array([x_coordinate, y_coordinate])

        # Keep only points strictly inside the rectangle; a sample that falls
        # exactly on the boundary is discarded.
        inside = (
            (coordinat[0] > 0) & (coordinat[0] < l)
            & (coordinat[1] > 0) & (coordinat[1] < L)
        )
        new_coord = coordinat[:, inside]

        # If any point was discarded, redraw the whole batch so that exactly
        # N points are returned.
        if new_coord.shape[1] >= N:
            return new_coord

def read_or_sort_enzyme_location(distribution):
    """Sample the reaction-site coordinates for the four reactions.

    For each of hxk, pyrk, ldh and mito (in that order) ``M`` points are
    drawn with ``uniform_point_circle``; ``M`` is the module-level number of
    sites per reaction.

    Parameters
    ----------
    distribution : str
        Sampling law; only ``'uniform'`` is supported.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(8, M)``. Rows ``2*i`` and ``2*i + 1`` hold the x and
        y coordinates of reaction ``i`` in the order hxk, pyrk, ldh, mito.

    Raises
    ------
    ValueError
        If ``distribution`` is not ``'uniform'``.
    """
    # Fail early with a clear error instead of printing and then crashing on
    # an undefined variable.
    if distribution != 'uniform':
        raise ValueError(
            "only acceptable distribution is 'uniform', got %r" % (distribution,)
        )

    # One (2, M) coordinate array per reaction. The draw order
    # hxk -> pyrk -> ldh -> mito fixes both the random stream and the row
    # layout of the result.
    coordinates_per_reaction = [uniform_point_circle(M, 'uniform') for _ in range(4)]

    return np.vstack(coordinates_per_reaction)


# Start the wall-clock timer; the elapsed time is printed at the end of the script
startime = timer()

# Time discretisation: num_step uniform steps of size dt covering [0, T]
T = 500 # final time
num_step = 2800 # number of time steps
dt = T / num_step


# Create an unstructured mesh of the rectangle [0, l] x [0, L] with mshr;
# 380 is the value passed as the second (resolution) argument of generate_mesh
channel = Rectangle(Point(0, 0), Point(l,L))
mesh = generate_mesh(channel, 380)
# Exact area of the rectangle, used below to convert integrals into domain averages
area = l * L

# Gaussian bump used to localise a single reaction site at (x0, y0)
def Gauss_ufl(x0, y0, sigma = 1.0, mesh=mesh):
    """Return a symbolic 2D Gaussian centred at (x0, y0).

    The expression is built on the spatial coordinate of ``mesh`` and has the
    form ``1 / (2 * pi * sigma**2) * exp(-r**2 / (2 * sigma**2))``, where ``r``
    is the distance to (x0, y0). ``mesh`` defaults to the mesh that exists at
    the time this function is defined.
    """
    position = SpatialCoordinate(mesh)
    squared_distance = (position[0] - x0) ** 2 + (position[1] - y0) ** 2
    prefactor = 1./(pi * 2 * sigma**2)
    return prefactor * exp(-squared_distance/(2 * sigma**2))

# Adaptive normalisation of the spatial reaction-site profile
def eta_gauss(gaussian):
    """Rescale ``gaussian`` so that its integral over the mesh equals one.

    The integral is evaluated numerically with ``assemble`` on the
    module-level ``mesh``, and the input expression is divided by it.
    """
    integral_over_domain = assemble(gaussian * dx(mesh))
    return gaussian / integral_over_domain

# Finite element space for the concentrations: six degree-1 'P' elements on
# triangles, one per unknown, combined into a single mixed space
P1 = FiniteElement('P', triangle, 1)
element = MixedElement([P1,P1,P1,P1,P1,P1])
V = FunctionSpace(mesh,element)

# Define test functions (v_i tests the equation of the i-th unknown)
v_1, v_2, v_3, v_4, v_5, v_6 = TestFunctions(V)

# The unknown is a Function rather than a TrialFunction because the problem is nonlinear
u = Function (V)

# Define the initial condition of the concentrations:
# a, d, e and f start at zero; b_0 and c_0 come from the parameter file
a_0 = 0.
d_0 = 0.
e_0 = 0.
f_0 = 0.

# Spatially constant initial state, projected onto V; u_n holds the solution
# of the previous time step
u_0 = Expression(('a_0', 'b_0', 'c_0','d_0', 'e_0', 'f_0'), a_0=a_0, b_0=b_0, c_0=c_0, d_0=d_0, e_0=e_0, f_0=f_0, degree=1)
u_n = project(u_0, V)

# Symbolic components of the current (a..f) and previous (a_n..f_n) solution
a, b, c, d, e, f = split(u)
a_n, b_n, c_n, d_n, e_n, f_n = split(u_n)

# Current simulation time, kept in a one-element list and updated in the time loop
t = [0.0]

# Time step as a dolfin Constant, used in the weak form
k = Constant(dt)

# Define Reaction Rate

M = 10 # Number of reaction sites per reaction

# Draw the coordinates (x, y) of the M reaction sites of each reaction.
# coord_enz has 8 rows: an (x, y) pair of rows per reaction, in the order
# hxk, pyrk, ldh, mito.

coord_enz = read_or_sort_enzyme_location('uniform')

# Per-reaction views of shape (2, M): row 0 = x coordinates, row 1 = y coordinates
coordinate_enzymes_hxk = coord_enz[0:2]
coordinate_enzymes_pyrk = coord_enz[2:4]
coordinate_enzymes_ldh = coord_enz[4:6]
coordinate_enzymes_mito = coord_enz[6:8]


#######################################
def _localised_rate(label, site_coordinates, rate_value):
    """Build the spatially localised rate field of one reaction.

    A Gaussian bump is placed at every site in ``site_coordinates`` (row 0 =
    x, row 1 = y). The sum is normalised to unit integral and then multiplied
    by the rate constant and by the domain area. The number of sites and the
    domain average of the resulting field are printed.

    Returns the tuple (rate constant, number of sites, raw Gaussian sum,
    normalised Gaussian sum, rate field).
    """
    rate_constant = Constant(rate_value)

    site_count = len(site_coordinates[0])
    print('number of ' + label, site_count)

    bump_sum = Constant(0)
    for site in range(site_count):
        bump_sum += Gauss_ufl(site_coordinates[0, site], site_coordinates[1, site])

    bump_sum_normalised = eta_gauss(bump_sum)
    rate_field = bump_sum_normalised * rate_constant * Constant(area)
    print(assemble(rate_field * dx(mesh))/area)

    return rate_constant, site_count, bump_sum, bump_sum_normalised, rate_field


# Hexokinase (hxk)
(k_hxk, Number_enzymes_hxk, Gaussian_hxk,
 Gaussian_hxk_normalized, K_hxk) = _localised_rate(
    'hxk', coordinate_enzymes_hxk, value_hxk)

# Pyruvate kinase (pyrk)
(k_pyrk, Number_enzymes_pyrk, Gaussian_pyrk,
 Gaussian_pyrk_normalized, K_pyrk) = _localised_rate(
    'pyrk', coordinate_enzymes_pyrk, value_pyrk)

# Lactate dehydrogenase (ldh)
(k_ldh, Number_enzymes_ldh, Gaussian_ldh,
 Gaussian_ldh_normalized, K_ldh) = _localised_rate(
    'ldh', coordinate_enzymes_ldh, value_ldh)

# Mitochondrial activity (mito)
(k_mito, Number_enzymes_mito, Gaussian_mito,
 Gaussian_mito_normalized, K_mito) = _localised_rate(
    'mito', coordinate_enzymes_mito, value_Mito)

# The K_act rate is spatially homogeneous: a plain constant, no reaction sites
K_act = Constant(value_act)

print(assemble(K_act* dx(mesh))/area)
#########################################

# Draw the sampled reaction sites together with the outline of the domain.
# The figure is written to disk at the end of the script.

Rect = plt.Rectangle(
    (0, 0), l, L,
    color='g', clip_on=False, fill=False,
)

fig, ax = plt.subplots()

# One scatter series per reaction, each with its own marker
site_series = (
    (coordinate_enzymes_hxk, '*', 'hxk'),
    (coordinate_enzymes_pyrk, 'v', 'pyrk'),
    (coordinate_enzymes_ldh, 's', 'ldh'),
    (coordinate_enzymes_mito, 'o', 'mito'),
)
for site_xy, site_marker, site_label in site_series:
    ax.scatter(site_xy[0], site_xy[1], marker=site_marker, label=site_label)

ax.add_patch(Rect)
ax.set_xlim(-2, 15)
fig.legend()
fig.tight_layout()


# Diffusion constants [\mu m^2/s], one per unknown a..f
# (D_x multiplies the gradient term of unknown x in the weak form)

D_a = Constant(0.6E3)
D_b = Constant(0.15E3)
D_c = Constant(0.15E3)
D_d = Constant(0.51E3)
D_e = Constant(0.64E3)
D_f = Constant(0.64E3)

# Define GLC influx

# Indicator of the points closer than radius_influx to the corner (0, 0) of the domain
radius_influx = 1.0
subdomain = Expression('(pow(x[0],2)+pow(x[1],2)) < (r * r) ? 1. : 0', r=radius_influx, degree=1)
# Numerically integrated area of that corner region
subdomain_area = assemble(subdomain * dx(mesh))

print('subdomain area', subdomain_area)
# Define the influx of GLC in that subdomain of the rectangle: f_glc is
# rescaled by (domain area / subdomain area)
influx = f_glc * area /subdomain_area

# Source term: equal to `influx` inside the corner region and zero elsewhere
f_1 = Expression('(pow(x[0],2)+pow(x[1],2)) < (r * r) ? influx : 0', influx=influx, r=radius_influx, degree=1)


# Degradation of LAC

# NOTE: from this line on the name `dx` is rebound to a measure with a fixed
# quadrature degree; everything assembled above used the default `dx`.
q_degree = 3
dx = dx(metadata={'quadrature_degree': q_degree})

# Indicator of the points closer than radius_outflux to the corner (l, L),
# i.e. the corner opposite to the GLC influx region
radius_outflux = radius_influx
subdomain_outflux = Expression('(pow(x[0]- l,2)+pow(x[1]- L,2)) < (r * r) ? 1. : 0', r=radius_outflux,l=l,  L=L, degree=1)
subdomain_outflux_area = assemble(subdomain_outflux * dx(mesh))

# Rescale the degradation rate by (domain area / subdomain area); this
# overwrites the lac_degr value read from the parameter file
lac_degr = lac_degr * area / subdomain_outflux_area

# Degradation coefficient: equal to lac_degr inside the corner region and zero elsewhere
eta_f = Expression('(pow(x[0] - l,2)+pow(x[1]- L,2)) < (r * r) ? outflux : 0', outflux = lac_degr, r=radius_outflux,l=l,  L=L, degree=1)

# Weak form
#
# Residual F(u; v) = 0 of the six coupled reaction-diffusion equations. Each
# unknown x in (a, b, c, d, e, f) contributes, tested with its own v_i:
#   (x - x_n) / k            implicit time difference (u_n is the previous step)
#   D_x * grad(x).grad(v_i)  diffusion; no boundary integrals are added
#   reaction terms           all evaluated at the unknown u, so the time
#                            discretisation is fully implicit (backward Euler)
# Since this is a residual, a reaction term with a '+' sign removes the
# species and a term with a '-' sign produces it.
#
# Reaction terms and where they appear:
#   K_hxk  * a * b**2   sink for a and b (x2); source for c (x2) and d (x2)
#   K_pyrk * d * c**2   sink for c (x2) and d; source for b (x2) and e
#   K_mito * e * c**28  sink for c (x28) and e; source for b (x28)
#   K_ldh  * e          sink for e; source for f
#   K_act  * b          sink for b; source for c
#   eta_f  * f          localised degradation of f (LAC)
#   f_1                 localised influx into a (GLC)
# Presumably b and c are ATP and ADP, given the name of the parameter
# function — confirm against the reference.

F = ((a - a_n) / k) * v_1 * dx \
    + D_a * dot(grad(a), grad(v_1)) * dx + K_hxk * a * b**2 * v_1 * dx \
    + ((b - b_n) / k) * v_2 * dx  \
    + D_b * dot(grad(b), grad(v_2)) * dx + 2 * K_hxk * a * b**2 * v_2 * dx - 2 * K_pyrk * d *  c**2 * v_2 * dx - 28 * K_mito * e * c**28 * v_2 * dx + K_act * b * v_2 * dx\
    + ((c - c_n) / k)*v_3 * dx \
    + D_c * dot(grad(c), grad(v_3)) * dx - 2 * K_hxk * a * b**2 * v_3 * dx  + 2 * K_pyrk * d * c**2 * v_3 * dx - K_act * b * v_3 * dx + 28 * K_mito * e * c**28 * v_3 * dx\
    + ((d - d_n) / k)*v_4 * dx\
    + D_d * dot(grad(d),grad(v_4)) * dx - 2 * K_hxk * a * b**2 * v_4 * dx + K_pyrk * d * c**2 * v_4 * dx\
    + ((e - e_n) / k)*v_5 * dx\
    + D_e * dot(grad(e),grad(v_5)) * dx  - K_pyrk  * d * c**2 * v_5 * dx + K_ldh * e * v_5 * dx + K_mito * e * c**28 * v_5 * dx\
    + ((f - f_n) / k)*v_6 * dx\
    + D_f * dot(grad(f),grad(v_6)) * dx - K_ldh * e * v_6 * dx + eta_f * f * v_6 * dx\
    - f_1 * v_1 * dx

# Time histories of the domain-averaged concentrations (integral over the
# mesh divided by the domain area), seeded with the initial state

list_a = [assemble(a_n * dx)/area]
list_b = [assemble(b_n * dx)/area]
list_c = [assemble(c_n * dx)/area]
list_d = [assemble(d_n * dx)/area]
list_e = [assemble(e_n * dx)/area]
list_f = [assemble(f_n * dx)/area]

# Simulation times matching the entries of the histories above
time_list = [t[0]]

# Jacobian of the residual with respect to the unknown.
# NOTE: J, Nmax and toll_a are not passed to the solve call in the time loop.
J = derivative(F, u)

Nmax = 50
toll_a = 1.e-10

# Options for the Newton solver, identical for every time step
newton_options = {'newton_solver': {'maximum_iterations': 200}}

for n in range(num_step):
    print(n)

    # Solve the nonlinear variational problem for the current time step
    solve(F == 0, u,  solver_parameters=newton_options)

    # Individual components of the mixed solution (nothing is written to file here)
    _a, _b, _c, _d, _e, _f = u.split()

    # The solution just computed becomes the "previous" state of the next step
    u_n.assign(u)

    # Advance and record the time. The source term f_1 has no time parameter,
    # so it needs no update.
    t[0] = t[0] + dt
    time_list.append(t[0])

    # Record the domain-averaged concentration of every species
    species_histories = (list_a, list_b, list_c, list_d, list_e, list_f)
    species_fields = (_a, _b, _c, _d, _e, _f)
    for history, field in zip(species_histories, species_fields):
        history.append(assemble(field * dx)/area)

import os

# Realisation index, read from the first command-line argument.
# NOTE(review): this is parsed only after the whole simulation has run, so a
# missing or non-numeric argument is detected at the very end — consider
# validating it at the top of the script.
number_sample = float(sys.argv[1])

# Name of the output folder and file stem
folder = 'uniform'
file_name = 'Realisation'+str(number_sample)

# Domain-averaged concentrations at the final time only (last entry of each history)
list_of_list = [list_a[-1], list_b[-1], list_c[-1], list_d[-1], list_e[-1], list_f[-1]]

# Stop the timer and report the total wall-clock time
aftersolve = timer()
tottime = aftersolve-startime
print('final time', tottime)

# Make sure the output folder exists, so the results of the simulation are
# not lost to a missing-directory error when saving.
os.makedirs(folder, exist_ok=True)

# Outputs: the figure of the reaction-site configuration, the final averaged
# concentrations, and the sampled site coordinates
plt.savefig(folder + '/' + file_name + '.png', dpi=150)
np.save(folder + '/' + file_name + '.npy', np.asarray(list_of_list))
np.save(folder + '/' + file_name + 'enzymes'+'.npy', coord_enz)