"""
Create the level set function.

INPUT: segmented image
OUTPUT: mesh, level set function of the astrocyte
"""


import numpy as np
from dolfin import *
import time


def Image2levelset3D(final_image, eps=0.001):
    """Sample a 3D image onto a box mesh and return a smoothed CG1 copy.

    The mesh's x-extent is fixed to 1.0 and the y/z extents are scaled with
    the same voxel spacing, so the image's aspect ratio is preserved.  Every
    mesh vertex takes the value of its nearest voxel, which gives the
    piecewise-linear function ``image_f``.  That function is then smoothed
    by solving

        eps**2 * (grad u, grad v) + (u, v) = (image_f, v)   for all v,

    i.e. ``-eps**2 * Laplace(u) + u = image_f`` with natural (zero-flux)
    boundary conditions, since no Dirichlet condition is applied.  The
    linear system is solved with CG preconditioned by AMG.  The result is
    presumably used downstream as the level set function (see the module
    docstring).

    Parameters
    ----------
    final_image : array_like
        Segmented image with exactly three axes.  Every axis must contain
        at least two voxels.
    eps : float, optional
        Smoothing length in mesh units (the mesh's x-extent is 1.0).

    Returns
    -------
    tuple
        ``(mesh, image_f_smoothed, x_size, y_size, z_size,
        x_size_real, y_size_real, z_size_real, image_f)``:

        - ``x_size``/``y_size``/``z_size`` are the physical mesh extents.
        - ``*_size_real`` are the image dimensions in voxels.
        - ``image_f`` is the raw sampled image.
        - ``image_f_smoothed`` is the smoothed function.

    Raises
    ------
    ValueError
        If ``final_image`` is not 3D, or if any axis has fewer than two
        voxels.  The extents would then involve a division by zero or
        describe a degenerate box.
    """
    image = np.asarray(final_image)
    if image.ndim != 3:
        raise ValueError(
            "final_image must be a 3D array, got shape %s" % (image.shape,))
    if min(image.shape) < 2:
        raise ValueError(
            "every axis of final_image needs at least 2 voxels, got shape %s"
            % (image.shape,))

    # Image dimensions in voxels.
    x_size_real, y_size_real, z_size_real = image.shape

    # Physical extents: x is normalised to 1.0 and y/z use the same voxel
    # spacing x_size / (x_size_real - 1), so voxels stay isotropic.
    x_size = 1.0
    y_size = (y_size_real - 1) * x_size / (x_size_real - 1)
    z_size = (z_size_real - 1) * x_size / (x_size_real - 1)

    mesh = BoxMesh(0., 0., 0., x_size, y_size, z_size,
                   x_size_real, y_size_real, z_size_real)

    # Physical vertex coordinates -> fractional voxel indices.  All three
    # axes share the x scale factor because the spacing is isotropic.
    pixel = mesh.coordinates() * (x_size_real - 1) / x_size

    # Use the nearest voxel rather than int() truncation.  Truncation turns
    # floating-point results such as 4.9999999 into 4 and can pick the
    # wrong voxel (notably on the far boundary).  Clipping keeps every
    # index inside the image.
    idx = np.rint(pixel).astype(int)
    idx = np.clip(idx, 0, np.array(image.shape) - 1)
    image_values = image[idx[:, 0], idx[:, 1], idx[:, 2]]

    V = FunctionSpace(mesh, 'CG', 1)
    image_f = Function(V)
    # image_values is in vertex order; permute it into dof order before
    # writing it into the coefficient vector.
    image_f.vector()[:] = image_values[dof_to_vertex_map(V)]

    # Smoothing problem: eps^2 (grad u, grad v) + (u, v) = (image_f, v).
    u = TrialFunction(V)
    v = TestFunction(V)
    a = eps ** 2 * inner(grad(u), grad(v)) * dx + inner(u, v) * dx
    L = inner(image_f, v) * dx

    A = assemble(a)
    b = assemble(L)
    # The form is symmetric positive definite, so CG is applicable.  AMG is
    # the standard preconditioner choice for this elliptic operator.
    solver = KrylovSolver(A, 'cg', 'amg')
    image_f_smoothed = Function(V)
    solver.solve(image_f_smoothed.vector(), b)

    return (mesh, image_f_smoothed, x_size, y_size, z_size,
            x_size_real, y_size_real, z_size_real, image_f)
