from file_tree import FileTree
from fsl_pipe import Pipeline, In, Out, Ref

# Load some libraries that help us to run FSL tools
from fsl import wrappers
from subprocess import run
from os import getenv

# Describe the data directory with a file-tree. Globbing for existing T1w
# images fills in the "group" and "subject" placeholders; listing them as a
# linked pair makes them vary together rather than as independent dimensions.
tree = FileTree.read("data.tree")
tree = tree.update_glob("T1w", link=[("group", "subject")])

# Empty pipeline object; the decorated recipes below register themselves on it.
pipe = Pipeline()


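# Add recipes to the pipeline.
# Each decorated function below is one user-defined processing step; its
# In/Out/Ref annotations name the file-tree entries it reads and writes.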
@pipe(submit=dict(jobtime=3))
def brain_extract(T1w: In, T1w_brain_first_attempt: Out, T1w_cut: Out, T1w_brain: Out):
    """Skull-strip the T1-weighted image in two BET passes.

    1. BET on ``T1w`` gives a first-attempt brain image.
    2. ``standard_space_roi`` uses that first attempt together with the
       MNI152 2mm brain (``-ssref``) and, via ``-altinput``, writes a
       cut-down version of the original ``T1w`` to ``T1w_cut``.
    3. BET is rerun on ``T1w_cut`` with a fractional intensity threshold of
       0.4 to produce the final ``T1w_brain``.

    Raises:
        RuntimeError: if the ``FSLDIR`` environment variable is not set.
        subprocess.CalledProcessError: if ``standard_space_roi`` exits with a
            non-zero status (previously such failures were silently ignored).
    """
    fsldir = getenv("FSLDIR")
    if fsldir is None:
        # Without this guard the reference path silently becomes "None/data/...".
        raise RuntimeError("FSLDIR environment variable is not set")

    wrappers.bet(T1w, T1w_brain_first_attempt)
    # NOTE(review): -roiNONE / -altinput follow FSL's fslvbm_1_bet "-N" recipe;
    # -altinput makes the cut apply to the original T1w rather than to the
    # first-attempt brain — see the standard_space_roi usage text.
    run(
        [
            "standard_space_roi", T1w_brain_first_attempt, T1w_cut,
            "-roiNONE",
            "-altinput", T1w,
            "-ssref", f"{fsldir}/data/standard/MNI152_T1_2mm_brain",
        ],
        check=True,  # fail the job instead of continuing with a missing T1w_cut
    )
    wrappers.bet(T1w_cut, T1w_brain, fracintensity=0.4)

@pipe(submit=dict(jobtime=10))
def segmentation(T1w_brain: In, basename: Ref("segment/basename"), gm_pve: Out("segment/gm_pve")):
    """Run FSL FAST tissue segmentation on the brain-extracted T1w image.

    FAST is given ``basename`` as its output prefix (``out=``), so all of its
    output files are named after it.

    ``gm_pve`` is not used in the body: the ``Out("segment/gm_pve")``
    annotation only declares to the pipeline that this recipe produces that
    file, which later recipes read via ``In("segment/gm_pve")``.
    NOTE(review): presumably "segment/gm_pve" resolves to FAST's grey-matter
    partial-volume map in data.tree — confirm against the tree file.
    """
    wrappers.fast(T1w_brain, out=basename)

@pipe(submit=dict(jobtime=3))
def linear_registration(gm_pve: In("segment/gm_pve"), linear_reg: Out):
    """Linearly register the grey-matter PVE map to the standard GM prior.

    Runs FLIRT from ``gm_pve`` to
    ``$FSLDIR/data/standard/tissuepriors/avg152T1_gray`` and writes only the
    estimated transformation matrix (``-omat``) to ``linear_reg``. That matrix
    initialises the FNIRT registrations in the later recipes (``--aff`` /
    ``aff=``).

    Raises:
        RuntimeError: if the ``FSLDIR`` environment variable is not set.
        subprocess.CalledProcessError: if ``flirt`` exits with a non-zero
            status (previously such failures were silently ignored).
    """
    fsldir = getenv("FSLDIR")
    if fsldir is None:
        # Without this guard the reference path silently becomes "None/data/...".
        raise RuntimeError("FSLDIR environment variable is not set")

    run(
        [
            "flirt",
            "-in", gm_pve,
            "-ref", f"{fsldir}/data/standard/tissuepriors/avg152T1_gray",
            "-omat", linear_reg,
        ],
        check=True,  # fail the job instead of leaving linear_reg missing
    )

@pipe(submit=dict(jobtime=10))
def nonlinear_registration(gm_pve: In("segment/gm_pve"), linear_reg: In, nonlinear_reg: Out, gm_pve_in_standard: Out):
    """Nonlinearly register the grey-matter PVE map to the standard GM prior.

    FNIRT warps ``gm_pve`` onto
    ``$FSLDIR/data/standard/tissuepriors/avg152T1_gray``, starting from the
    affine ``linear_reg`` (``--aff``) and using the
    ``GM_2_MNI152GM_2mm.cnf`` configuration. It writes the warp coefficients
    to ``nonlinear_reg`` (``--cout``) and the warped GM map to
    ``gm_pve_in_standard`` (``--iout``), which ``create_template`` averages
    across subjects.

    Raises:
        RuntimeError: if the ``FSLDIR`` environment variable is not set.
        subprocess.CalledProcessError: if ``fnirt`` exits with a non-zero
            status (previously such failures were silently ignored).
    """
    fsldir = getenv("FSLDIR")
    if fsldir is None:
        # Without this guard the reference path silently becomes "None/data/...".
        raise RuntimeError("FSLDIR environment variable is not set")

    run(
        [
            "fnirt",
            f"--in={gm_pve}",
            f"--ref={fsldir}/data/standard/tissuepriors/avg152T1_gray",
            f"--aff={linear_reg}",
            f"--cout={nonlinear_reg}",
            # Bare config name: fnirt resolves it itself (presumably from
            # $FSLDIR/etc/flirtsch — confirm).
            "--config=GM_2_MNI152GM_2mm.cnf",
            f"--iout={gm_pve_in_standard}",
        ],
        check=True,  # fail the job instead of leaving outputs missing
    )

@pipe(no_iter=["group", "subject"], submit=dict(jobtime=3))
def create_template(gm_pve_in_standard: In, gm_pve_template: Out):
    """Build a left-right symmetric, study-specific grey-matter template.

    Because ``group`` and ``subject`` are listed in ``no_iter``, this recipe
    runs once, and ``gm_pve_in_standard.data`` holds the standard-space GM PVE
    filenames of all subjects. The images are averaged voxel-wise. The
    average is then averaged with its mirror image along the first voxel axis
    and saved to ``gm_pve_template`` with the first image's affine and header.

    Raises:
        ValueError: if there are no input images, or if the images do not all
            share the same voxel grid shape.
    """
    import nibabel as nib
    import numpy as np

    # Flatten so indexing/iteration yields filenames even if the inputs are
    # laid out along more than one dimension.
    filenames = np.ravel(gm_pve_in_standard.data)
    if filenames.size == 0:
        raise ValueError("create_template needs at least one gm_pve_in_standard image")

    ref_img = nib.load(filenames[0])

    # Running sum keeps one subject's data in memory at a time instead of
    # stacking every subject (same sequential summation as np.mean(axis=0)).
    total = np.zeros(ref_img.shape, dtype=np.float64)
    for fn in filenames:
        data = nib.load(fn).get_fdata()
        if data.shape != total.shape:
            # Explicit check: `+=` could otherwise broadcast a mismatched image silently.
            raise ValueError(f"{fn} has shape {data.shape}, expected {total.shape}")
        total += data
    average_gm_pve = total / filenames.size

    # Average with the flipped image to enforce left-right symmetry.
    # NOTE(review): assumes voxel axis 0 is left-right and the grid is
    # symmetric about the midline (presumably true for the avg152T1_gray grid
    # used upstream) — confirm if the reference image changes.
    template = (average_gm_pve + average_gm_pve[::-1, :, :]) / 2
    nib.save(nib.Nifti1Image(template, ref_img.affine, ref_img.header), gm_pve_template)

@pipe(submit=dict(jobtime=10))
def register_to_template(gm_pve: In("segment/gm_pve"), gm_pve_template: In, linear_reg: In, gm_pve_in_template: Out, gm_pve_jac: Out, gm_pve_mod: Out):
    """Register a subject's GM PVE map to the study template and modulate it.

    FNIRT warps ``gm_pve`` onto ``gm_pve_template``, starting from the affine
    ``linear_reg``. It writes the warped image to ``gm_pve_in_template`` and
    the Jacobian of the warp to ``gm_pve_jac``. The warped image is then
    multiplied voxel-wise by that Jacobian (``fslmaths -mul``) and written to
    ``gm_pve_mod``.
    """
    # Pass the reference by keyword: fslpy's fnirt wrapper forwards keyword
    # arguments as --name=value options (here --ref), whereas a second
    # positional argument is not guaranteed to be accepted. This also mirrors
    # the explicit --ref used by nonlinear_registration.
    wrappers.fnirt(
        gm_pve,
        ref=gm_pve_template,
        aff=linear_reg,
        iout=gm_pve_in_template,
        jout=gm_pve_jac,
        config="GM_2_MNI152GM_2mm.cnf",
    )
    wrappers.fslmaths(gm_pve_in_template).mul(gm_pve_jac).run(gm_pve_mod)


# Entry point: when run as a script, hand control to the pipeline's
# command-line interface, passing it the file-tree loaded at the top of the file.
if __name__ == "__main__":
    pipe.cli(tree)
