
Integrate EvanBeal's contrast agnostic registration model into SCT's registration scripts #3760

Closed · joshuacwnewton opened this issue Apr 12, 2022 · 24 comments · Fixed by #3807
@joshuacwnewton
Member

This issue is for planning the integration of @EvanBeal's contrast agnostic registration model project into SCT's registration scripts.

Proposed changes (copied from 2022-04-05 meeting minutes)

  • Include preprocessing steps
  • Load the model weights and build the architecture (using VoxelMorph functions)
  • Use the models to produce the warping field
  • Transform the warping field so it can be used with sct_apply_transfo
  • Call this function from the register function of sct_register_to_template
    • And/or sct_register_multimodal? (TODO: Needs clarification)

Remaining work

@EvanBeal
Contributor

EvanBeal commented Apr 18, 2022

Registration model conversion to ONNX format

Converting the H5 models to ONNX is feasible.
This can be done by first converting the H5 model to the TensorFlow SavedModel format in a Python script:

import tensorflow as tf
import voxelmorph as vxm
model = vxm.networks.VxmDense.load('model/model_velres3264_2000.h5', input_model=None)
tf.saved_model.save(model, "model_tf_savedmodel_format")

The model can then be converted from the TensorFlow SavedModel format to ONNX with tf2onnx, using the following command in the terminal:

python3 -m tf2onnx.convert --saved-model model_tf_savedmodel_format --output "model.onnx"

The ONNX registration model takes a pair of images as input and returns a dense deformation field, as done with the H5 registration model.

Limitation of the ONNX format

The drawback of the conversion to the ONNX format is that the size of the input images needs to be defined prior to the conversion, which means that the models are limited to registering images of a specific size (in the above example, the ONNX model can only register images of size 160 x 160 x 192, which is the default size that the model model_velres3264_2000 was trained on).

In the multimodal registration project, images of any size could be registered. The approach was to load the weights of the model trained for a specific size, create a new model for the size of interest, and then set the weights of the trained model on this new model, allowing images of any size to be registered. This is no longer possible once the model is converted to ONNX and the VoxelMorph and TensorFlow dependencies are removed.
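For reference, the fixed input shape baked into a converted model can be inspected with onnxruntime (a quick check, assuming the model.onnx produced above):

import onnxruntime as rt

# Print the input shapes that were frozen into the ONNX graph at conversion time
session = rt.InferenceSession("model.onnx", None)
for inp in session.get_inputs():
    print(inp.name, inp.shape)  # e.g. [1, 160, 160, 192, 1] for the default model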

Potential solutions

Two possibilities exist to enable registration of images of any size with this new constraint.

  1. Use the sub-volumes approach that has been developed in the multimodal registration project. This approach uses an algorithm to split the image, performs registration on chunks of the image, smooths out discontinuities between the warping fields of two adjacent chunks, and concatenates all sub-warping fields in space. Registration accuracy is slightly decreased compared to when registration is performed directly on the whole input images. E.g. the registration of a pair of images of size 220 x 220 x 192 by a model that takes images of size 160 x 160 x 192 as inputs will be done by creating 4 sub-volume pairs (of size: 0-160 x 0-160 x 0-192, 60-220 x 0-160 x 0-192, 0-160 x 60-220 x 0-192 and 60-220 x 60-220 x 0-192), registering the 4 pairs and concatenating the 4 sub-warping fields to obtain the warping field of size 220 x 220 x 192.
  2. Create a set of models with different input shapes (using the weights from the same trained model) and then choose the most appropriate model for the size of the input images during inference. The model with the closest (larger) input shape should be chosen and the input images padded to this size. The drawback of this approach is that it is unknown how many models should be present in the set to cover the majority of cases of interest. However, it may be slightly more precise than the sub-volumes approach, as there will not be any smoothing caused by overlapping sub-volumes. E.g. for the registration of a pair of images of size 220 x 220 x 192 with a set of models taking input images of size 128 x 128 x 128, 160 x 160 x 160, 192 x 192 x 192 and 224 x 224 x 224, the registration will be done by the model taking images of size 224 x 224 x 224 as inputs, after padding the pair of images of size 220 x 220 x 192.

These two possibilities could also be combined by prioritising the second option and falling back to the first one if no model in the set is close enough to the input image size (see the sketch below).
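As an illustration of option 2, the model-selection and padding logic could look roughly like this (a sketch only; the set of shapes and the helper names are hypothetical):

import numpy as np

# Hypothetical set of available fixed-shape ONNX models
MODEL_SHAPES = [(128, 128, 128), (160, 160, 160), (192, 192, 192), (224, 224, 224)]

def pick_model_shape(img_shape):
    """Return the smallest model shape that covers img_shape, or None if none does."""
    candidates = [s for s in MODEL_SHAPES if all(m >= i for m, i in zip(s, img_shape))]
    return min(candidates, key=np.prod) if candidates else None

def pad_to_shape(data, target_shape):
    """Zero-pad a 3D array at the end of each axis up to target_shape."""
    return np.pad(data, [(0, t - s) for s, t in zip(data.shape, target_shape)])

# A 220 x 220 x 192 image would be handled by the 224 x 224 x 224 model after padding;
# if pick_model_shape() returns None, fall back to the sub-volumes approach (option 1).
print(pick_model_shape((220, 220, 192)))  # (224, 224, 224)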

Code for model conversion and application of ONNX registration model

Conversion of H5 model to ONNX using a different input size than the one used during the model training (e.g. 128 x 128 x 128):
import tensorflow as tf
import voxelmorph as vxm

size_x = 128
size_y = 128
size_z = 128
reg_args = dict(
    inshape=[size_x, size_y, size_z],
    int_steps=5,
    int_resolution=2,
    svf_resolution=2,
    nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
)

trained_model = vxm.networks.VxmDense.load('model/model_velres3264_2000.h5', input_model=None)
model = vxm.networks.VxmDense(**reg_args)
model.set_weights(trained_model.get_weights())

tf.saved_model.save(model, "tmp_model_128128128")

Use of an ONNX registration model to produce a dense deformation field compatible with sct_apply_transfo:
import onnxruntime as rt
import numpy as np
import nibabel as nib
from scipy import ndimage

# Load preprocessed data (scaled between 0 and 1 and with the moving data in the space of the fixed one)
fixed = nib.load("data_processed_time_analysis/data1/sub-vallHebron06/anat/sub-vallHebron06_T1w.nii.gz")
moving = nib.load("data_processed_time_analysis/data1/sub-vallHebron06/anat/sub-vallHebron06_T2w.nii.gz")

# N.B.
# These data are in my local computer but any data could be used to perform the same analysis.
# It only needs to be scaled and set in a common space (e.g. using sct_register_multimodal with -identity 1)

session = rt.InferenceSession("model.onnx", None)
input_name_moving = session.get_inputs()[0].name
input_name_fixed = session.get_inputs()[1].name
output_name = session.get_outputs()[0].name

# Crop the data to the input size that was specified when the registration model was converted to the ONNX format
# (Here 160x160x192)
data_moving = np.expand_dims(moving.get_fdata()[:160, :160, :192].squeeze(), axis=(0, -1)).astype(np.float32)
data_fixed = np.expand_dims(fixed.get_fdata()[:160, :160, :192].squeeze(), axis=(0, -1)).astype(np.float32)

result = session.run([output_name], {input_name_moving: data_moving, input_name_fixed: data_fixed})

warp_data = result[0][0]

# Warping field
# Modify the warp data so it can be used with sct_apply_transfo()
# (upsample, add a time dimension, change the sign of some axes and set the intent code to vector)
warp_data_zoom = ndimage.zoom(warp_data, zoom=(2, 2, 2, 1))
warp_data = 2 * warp_data_zoom
# Change the sign of the vectors and the order of the axes components to be correctly used with sct_apply_transfo
# and to get the same results with sct_apply_transfo() and when using model.predict() or vxm.networks.Transform()
orientation_conv = "LPS"
fx_im_orientation = list(nib.aff2axcodes(fixed.affine))
opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
perm = [0, 1, 2]
inversion = [1, 1, 1]
for i, character in enumerate(orientation_conv):
    try:
        perm[i] = fx_im_orientation.index(character)
    except ValueError:
        perm[i] = fx_im_orientation.index(opposite_character[character])
        inversion[i] = -1
warp_data_exp = np.expand_dims(warp_data, axis=3)
warp_data_exp_copy = np.copy(warp_data_exp)
warp_data_exp[..., 0] = inversion[0] * warp_data_exp_copy[..., perm[0]]
warp_data_exp[..., 1] = inversion[1] * warp_data_exp_copy[..., perm[1]]
warp_data_exp[..., 2] = inversion[2] * warp_data_exp_copy[..., perm[2]]
warp = nib.Nifti1Image(warp_data_exp, fixed.affine)
warp.header['intent_code'] = 1007

# Save the warping field that can be later used with sct_apply_transfo
nib.save(warp, f'warp_field.nii.gz')
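For completeness, the saved field could then be applied along these lines (paths are placeholders):

sct_apply_transfo -i moving.nii.gz -d fixed.nii.gz -w warp_field.nii.gz -o moving_reg.nii.gz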

@jcohenadad
Member

The drawback of the conversion to the ONNX format is that the size of the input images needs to be defined prior to the conversion, which means that the models are limited to registering images of a specific size

Ah! I didn't think of that when we switched to ONNX-only support with SCT. That is quite a significant problem, and I foresee it being a problem for other models as well (for segmentation). Thank you for proposing the potential solutions, @EvanBeal, but they are both quite limiting/heavy, as you also pointed out (1: requires extra code; 2: requires an unnecessarily high number of models to be stored, each with a fixed resolution, so there is always a risk that a user will input a volume that is not well covered by our set of resolutions).

@joshuacwnewton how 'problematic' would it be to support H5 models again? IIRC the issue was mostly at SCT installation?

@joshuacwnewton
Member Author

joshuacwnewton commented Apr 19, 2022

I agree with @jcohenadad that the workarounds for fixed-input ONNX models probably wouldn't be sufficient, and that this represents a significant problem.


@joshuacwnewton how 'problematic' would it be to support H5 models again? IIRC the issue was mostly at SCT installation?

It depends on which version of Tensorflow @EvanBeal used in his development. (This is relevant for #3367.)

@EvanBeal, could you do a quick pip freeze to share the versions of the packages you're currently working with?


Regardless of the specific version of TF, the main issue I have is that this would place an extra burden on SCT devs in the long term, because we would be the ones maintaining a dependency on Tensorflow just to support this feature.

It's definitely possible to depend on Tensorflow (we've been doing it for years), but I would personally prefer to work towards being a PyTorch-only project, and explore other options first, if we can?

(Admittedly, I am a bit biased though, because I had set my hopes on #3738 for making SCT dev's lives easier in the future.)


Aside...

I think this situation is a good indicator that there could be better communication in the future between SCT devs and NP research students.

Say, for example, there was a quick meeting early on (to consult about the dev side of the project). We might have been able to catch that this would be a Tensorflow-based project much earlier, before @EvanBeal put in the work that he has.

@joshuacwnewton
Member Author

I would personally prefer to explore other options first, if we can?

I have some ideas off the top of my head that I want to quickly look into:

  1. Variable-input ONNX models (I assume this is not possible based on @EvanBeal's investigations, but I want to quickly check this)
  2. Converting the Tensorflow models to PyTorch models somehow
    • I mention this because I notice the VoxelMorph repo has some torch scripts -- however, they seem to be much more limited, and don't contain SynthMorph training scripts, so I assume this isn't viable?

@EvanBeal
Contributor

It depends on which version of Tensorflow @EvanBeal used in his development. (This is relevant for #3367.)

I am using Tensorflow and Keras 2.7.0

@EvanBeal, could you do a quick pip freeze to share the versions of the packages you're currently working with?

pip freeze
absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1634676905105/work
aiohttp @ file:///Users/runner/miniforge3/conda-bld/aiohttp_1636085381581/work
antspyx==0.3.1
astor @ file:///home/conda/feedstock_root/build_artifacts/astor_1593610464257/work
astunparse @ file:///home/conda/feedstock_root/build_artifacts/astunparse_1610696312422/work
async-timeout==3.0.1
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1620387926260/work
blinker==1.4
brotlipy @ file:///Users/runner/miniforge3/conda-bld/brotlipy_1636012303255/work
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
cachetools @ file:///home/conda/feedstock_root/build_artifacts/cachetools_1633010882559/work
certifi==2021.10.8
cffi @ file:///Users/runner/miniforge3/conda-bld/cffi_1625835329396/work
chardet @ file:///Users/runner/miniforge3/conda-bld/chardet_1635814945668/work
chart-studio==1.1.0
click @ file:///Users/runner/miniforge3/conda-bld/click_1635822684275/work
cmake==3.22.0
contrib==0.3.0
cryptography @ file:///Users/runner/miniforge3/conda-bld/cryptography_1636041000388/work
cycler==0.11.0
dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
et-xmlfile==1.1.0
fire==0.4.0
flatbuffers @ file:///home/conda/feedstock_root/build_artifacts/python-flatbuffers_1617723079010/work
gast @ file:///home/conda/feedstock_root/build_artifacts/gast_1596839682936/work
google-auth @ file:///home/conda/feedstock_root/build_artifacts/google-auth_1635863830555/work
google-auth-oauthlib @ file:///home/conda/feedstock_root/build_artifacts/google-auth-oauthlib_1630497468950/work
google-pasta==0.2.0
grpcio @ file:///Users/runner/miniforge3/conda-bld/grpcio_1619796056251/work
h5py==3.5.0
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1593328102638/work
imageio==2.10.3
importlib-metadata @ file:///Users/runner/miniforge3/conda-bld/importlib-metadata_1636431750260/work
joblib==1.1.0
jsonschema==4.2.1
keras==2.7.0
Keras-Applications==1.0.8
Keras-Preprocessing @ file:///home/conda/feedstock_root/build_artifacts/keras-preprocessing_1610713559828/work
keras2onnx==1.7.0
kiwisolver==1.3.2
libclang==12.0.0
littleutils==0.2.2
Markdown @ file:///home/conda/feedstock_root/build_artifacts/markdown_1614595805172/work
matplotlib==3.4.3
multidict @ file:///Users/runner/miniforge3/conda-bld/multidict_1636019292225/work
networkx==2.6.3
neurite @ file:///Users/evan/Desktop/multimodal-registration/neurite
nibabel==3.2.1
nilearn==0.8.1
numpy==1.22.3
oauthlib @ file:///home/conda/feedstock_root/build_artifacts/oauthlib_1622563202229/work
onnx==1.11.0
onnxconverter-common==1.9.0
onnxruntime==1.11.0
openpyxl==3.0.9
opt-einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1617859230218/work
outdated==0.2.1
packaging==21.2
pandas==1.3.4
pandas-flavor==0.2.0
patsy==0.5.2
Pillow==8.4.0
pingouin==0.5.1
plotly==5.4.0
protobuf==3.15.8
pyasn1==0.4.8
pyasn1-modules==0.2.7
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
PyJWT @ file:///home/conda/feedstock_root/build_artifacts/pyjwt_1634405536383/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1633192417276/work
pyparsing==2.4.7
PyQt5==5.15.6
PyQt5-Qt5==5.15.2
PyQt5-sip==12.9.0
pyrsistent==0.18.0
PySocks @ file:///Users/runner/miniforge3/conda-bld/pysocks_1635862544510/work
pystrum @ file:///Users/evan/Desktop/multimodal-registration/pystrum
python-dateutil==2.8.2
pytz==2021.3
pyu2f @ file:///home/conda/feedstock_root/build_artifacts/pyu2f_1604248910016/work
PyWavelets==1.2.0
PyYAML==6.0
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1608156231189/work
requests-oauthlib @ file:///home/conda/feedstock_root/build_artifacts/requests-oauthlib_1595492159598/work
retrying==1.3.3
rsa @ file:///home/conda/feedstock_root/build_artifacts/rsa_1614171254180/work
scikit-image==0.18.3
scikit-learn==1.0.1
scipy @ file:///Users/runner/miniforge3/conda-bld/scipy_1636410502957/work
seaborn==0.11.2
six==1.16.0
statsmodels==0.13.1
tabulate==0.8.9
tenacity==8.0.1
tensorboard @ file:///home/conda/feedstock_root/build_artifacts/tensorboard_1636578784510/work/tensorboard-2.7.0-py3-none-any.whl
tensorboard-data-server @ file:///Users/runner/miniforge3/conda-bld/tensorboard-data-server_1636046140314/work/tensorboard_data_server-0.6.0-py3-none-macosx_10_9_x86_64.whl
tensorboard-plugin-wit @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-plugin-wit_1611075653546/work/tensorboard_plugin_wit-1.8.0-py3-none-any.whl
tensorflow==2.7.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.22.0
termcolor==1.1.0
tf2onnx==1.9.3
threadpoolctl==3.0.0
tifffile==2021.11.2
tqdm==4.62.3
transforms3d==0.3.1
typing-extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1602702424206/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1632350318291/work
voxelmorph @ file:///Users/evan/Desktop/multimodal-registration/voxelmorph
webcolors==1.11.1
Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1621518206714/work
wquantiles==0.6
wrapt @ file:///Users/runner/miniforge3/conda-bld/wrapt_1624971809674/work
xarray==2022.3.0
yarl @ file:///Users/runner/miniforge3/conda-bld/yarl_1636046920505/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1633302054558/work

Variable-input ONNX models (I assume this is not possible based on @EvanBeal's investigations, but I want to quickly check this)

It may be possible for certain models, and I might have missed how to do it. However, I don't think it is possible for the model of this registration project, because the model involves some specific transformations with interpolation computed on grids (used in the model to rescale the produced deformation field), as well as some custom layers that require a specific input size to be set.

Converting the Tensorflow models to PyTorch models somehow
I mention this because I notice the VoxelMorph repo has some torch scripts -- however, they seem to be much more limited, and don't contain SynthMorph training scripts, so I assume this isn't viable?

This could definitely be a good idea and I think it can be viable (the parts of SynthMorph that are not implemented in PyTorch concern the training of the model)! If it is not problematic to have a PyTorch dependency, then it may be possible to convert the trained h5 model to PyTorch and then use the PyTorch implementation of the model used for registration (VxmDense) to enable data of any size as inputs. However, I have never tested the PyTorch implementations, and from what I see (PyTorch vs Tensorflow), it seems like the two implementations of the VxmDense model are not exactly the same (but it may concern only optional parts), so I would need to investigate if it's possible to use the weights obtained on the model implemented with TF to build a model implemented with PyTorch.
If this works, I think we would still need to depend on VoxelMorph to build the model architecture with the specific input size of the data to register, but it would remove the Tensorflow dependency.

@joshuacwnewton
Member Author

joshuacwnewton commented Apr 19, 2022

It depends on which version of Tensorflow @EvanBeal used in his development. (This is relevant for #3367.)

I am using Tensorflow and Keras 2.7.0

This is good to hear! This means that (worst case scenario) if we had to go back to Tensorflow, we could at least use an up-to-date version of Tensorflow, which won't block Python 3.8/3.9 upgrades.

If it is not problematic to have a PyTorch dependency, then it may be possible to convert the trained h5 model to PyTorch

Just to clarify this one part: SCT already depends on PyTorch through Ivadomed via our sct_deepseg functions. So, using PyTorch for this project would fit nicely with the current and future direction for SCT.

so I would need to investigate if it's possible to use the weights obtained on the model implemented with TF to build a model implemented with PyTorch.

Just to check -- Would it be possible to retrain the model using the Torch training scripts? (I'm guessing not, because the tensorflow-specific train_synthmorph.py script was used, and this has no torch analog?)

So, if I'm understanding correctly, the only option is to build the same model in both Tensorflow and PyTorch, then assign the weights layer by layer manually? (e.g. https://www.adrian.idv.hk/2020-12-31-torch2tf/, https://datascience.stackexchange.com/a/40511)

In that case, it might not be worth the time/effort... and it may just be better for SCT to depend on Tensorflow again.

@EvanBeal
Contributor

Just to clarify this one part: SCT already depends on PyTorch through Ivadomed via our sct_deepseg functions. So, using PyTorch for this project would fit nicely with the current and future direction for SCT.

Yes, perfect!

Just to check -- Would it be possible to retrain the model using the Torch training scripts? (I'm guessing not, because the tensorflow-specific train_synthmorph.py script was used, and this has no torch analog?)

So no, this is not possible. Your guess is correct: the training of the SynthMorph method involves some parts that are only built with TF and not PyTorch (generation of synthetic label maps and gray-scale images).

So, if I'm understanding correctly, the only option is to build the same model in both Tensorflow and PyTorch, then assign the weights layer by layer manually? (e.g. https://www.adrian.idv.hk/2020-12-31-torch2tf/, https://datascience.stackexchange.com/a/40511)

Indeed, it seems from these sources that the weights need to be assigned manually. However, I think there might be other possibilities (like this one), but I don't know if they are applicable in our case. Even if the weights need to be set for each layer, I think it could be interesting to try, as it would only need to be done once to then have a PyTorch model that could be really useful.

In that case, it might not be worth the time/effort... and it may just be better for SCT to depend on Tensorflow again.

I will try to look as soon as possible at how feasible it is to go with PyTorch, and I will document it in this issue so we can determine the best option ;)

@EvanBeal
Contributor

EvanBeal commented Apr 25, 2022

Conversion of the registration models from Keras .h5 to PyTorch models

Using the strategy of assigning the weights layer by layer, together with the PyTorch architecture of the registration model available in VoxelMorph (VxmDense), I was able to convert the TensorFlow models to PyTorch.

This was done with the following code:
import os
import numpy as np

# Import the Keras model and copy the weights
import voxelmorph as vxm
tf_model = vxm.networks.VxmDense.load('model/model_velres3264_2000.h5', input_model=None)
weights = tf_model.get_weights()

# import VoxelMorph with pytorch backend
import importlib
import torch
os.environ['VXM_BACKEND'] = 'pytorch'
importlib.reload(vxm)

# Build a Torch model and set the weights from the Keras model
reg_args = dict(
    inshape=[160, 160, 192],
    int_steps=5,
    int_downsize=2,   # same as int_resolution=2, in keras model
    unet_half_res=True,  # same as svf_resolution=2, in keras model
    nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
)
# Create the PyTorch model
pt_model = vxm.networks.VxmDense(**reg_args)

# Load the weights onto the PyTorch model
i = 0
i_max = len(list(pt_model.named_parameters()))
torchparam = pt_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        print("{:20s} {}".format(k, v.shape))
        if k.split('.')[-1] == 'weight':
            torchparam[k] = torch.from_numpy(np.transpose(weights[i]))
        else:
            torchparam[k] = torch.from_numpy(weights[i])
        i += 1

pt_model.load_state_dict(torchparam)

# Save the model
torch.save(pt_model.state_dict(), 'pt_model.pt')
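As a quick smoke test (a sketch, reusing only the calls shown above), the converted model can be run on random tensors to confirm that the forward pass produces a moved image and a dense field:

import numpy as np
import torch

# Random inputs with the layout expected by the PyTorch model: (batch, channels, x, y, z)
dummy_moving = torch.from_numpy(np.random.rand(1, 1, 160, 160, 192).astype(np.float32))
dummy_fixed = torch.from_numpy(np.random.rand(1, 1, 160, 160, 192).astype(np.float32))

pt_model.eval()
moved, warp = pt_model(dummy_moving, dummy_fixed, registration=True)
print(moved.shape, warp.shape)  # moved should match the input shape; warp should carry 3 displacement components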

Input data of any size

Once this was done, I faced the problem of how to use data of any size as inputs now that we have these PyTorch models. It appears that this is possible by building the model architecture for the input size of interest and then setting the weights layer by layer from the trained model (somewhat similar to what I was doing with the TensorFlow models).

This can be done with code similar to the following:

# Choose any input shape (multiple of 16) - with real data, choose the size of the data to do the registration on the whole data without cropping
input_shape = [224, 224, 160]

# Set the parameters of the registration model
reg_args = dict(
    inshape=input_shape,
    int_steps=5,
    int_downsize=2,
    unet_half_res=True,
    nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
)

# Create the PyTorch model
pt_model = vxm.networks.VxmDense(**reg_args)

# Load the weights of the trained PyTorch model
trained_state_dict = torch.load('pt_model.pt')

# Initialize the new PyTorch model (with the input size of interest) with the weights of the trained model
weights = []
for k in trained_state_dict:
    weights.append(trained_state_dict[k])
i = 0
i_max = len(list(pt_model.named_parameters()))
torchparam = pt_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        print("{:20s} {}".format(k, v.shape))
        torchparam[k] = weights[i]
        i += 1

pt_model.load_state_dict(torchparam)
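Since this weight-copying loop is repeated for every model instance, it could be factored into a small helper (a sketch; the function name is mine, and torch/vxm are assumed to be imported as above):

def load_weights_into(model, state_dict_path):
    # Copy trained weights, layer by layer, into a freshly built VxmDense of any input shape
    trained_state_dict = torch.load(state_dict_path)
    weights = [trained_state_dict[k] for k in trained_state_dict]
    n_params = len(list(model.named_parameters()))
    torchparam = model.state_dict()
    for i, k in enumerate(torchparam):
        if i < n_params:
            torchparam[k] = weights[i]
    model.load_state_dict(torchparam)
    return model

# e.g. pt_model = load_weights_into(vxm.networks.VxmDense(**reg_args), 'pt_model.pt')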

Results

Once the models were converted and I had figured out how to use data of any size as inputs, I tested these new models to see if I could obtain the same results as with the TF models. Unfortunately, this is not the case. I don't know what is wrong in the process, but when I try to register images using the PyTorch models instead of the TF models, the results are not the same and it leads to poor registration. Therefore, either I manage to identify the part of the process that might be problematic and am able to fix it (e.g. something not linked to the PyTorch model implementation done in VoxelMorph), or we have to consider another solution than using PyTorch models.

Here is an example of results obtained:

[Three animated GIFs showing example registration results]

And here is the code that I used (cascaded registration):
import os

import numpy as np
import nibabel as nib
from nilearn.image import resample_img

import torch
# import VoxelMorph with pytorch backend
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm

# Load preprocessed data (scaled between 0 and 1 and with the moving data in the space of the fixed one)
fixed = nib.load("data_processed_time_analysis/data1/sub-vallHebron06/anat/sub-vallHebron06_T1w.nii.gz")
moving = nib.load("data_processed_time_analysis/data1/sub-vallHebron06/anat/sub-vallHebron06_T2w.nii.gz")
# N.B.
# These data are in my local computer but any data could be used to perform the same analysis.
# It only needs to be scaled and set in a common space (e.g. using sct_register_multimodal with -identity 1)

# Define the input shape of the model (smallest multiple of 16 above the input data shape)
# Ensure that the volumes can be used in the registration model
fx_img_shape = fixed.get_fdata().shape
mov_img_shape = moving.get_fdata().shape
max_img_shape = max(fx_img_shape, mov_img_shape)
new_img_shape = (int(np.ceil(max_img_shape[0] / 16)) * 16, int(np.ceil(max_img_shape[1] / 16)) * 16,
                 int(np.ceil(max_img_shape[2] / 16)) * 16)

# Pad the volumes to the max shape
fx_paded = resample_img(fixed, target_affine=fixed.affine, target_shape=new_img_shape, interpolation='continuous')
mov_paded = resample_img(moving, target_affine=moving.affine, target_shape=new_img_shape, interpolation='continuous')
input_shape = list(new_img_shape)

# Set the parameters of the registration model
reg_args = dict(
    # inshape=input_shape,
    inshape=[192, 192, 192],
    int_steps=5,
    int_downsize=2,
    unet_half_res=True,
    nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
)
# Create the PyTorch model and specify the device
device = 'cpu'

# ---- First Model ---- #
pt_first_model = vxm.networks.VxmDense(**reg_args)
trained_state_dict_first_model = torch.load('pt_first_model.pt')
# Load the weights to the PyTorch model
weights_first_model = []
for k in trained_state_dict_first_model:
    weights_first_model.append(trained_state_dict_first_model[k])
i = 0
i_max = len(list(pt_first_model.named_parameters()))
torchparam = pt_first_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        # print("{:20s} {}".format(k, v.shape))
        torchparam[k] = weights_first_model[i]
        i += 1
pt_first_model.load_state_dict(torchparam)
pt_first_model.eval()

# ---- Second Model ---- #
pt_second_model = vxm.networks.VxmDense(**reg_args)
trained_state_dict_second_model = torch.load('pt_second_model.pt')
# Load the weights to the PyTorch model
weights_second_model = []
for k in trained_state_dict_second_model:
    weights_second_model.append(trained_state_dict_second_model[k])
i = 0
i_max = len(list(pt_second_model.named_parameters()))
torchparam = pt_second_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        torchparam[k] = weights_second_model[i]
        i += 1
pt_second_model.load_state_dict(torchparam)
pt_second_model.eval()

# Prepare the data for inference
data_moving = np.expand_dims(mov_paded.get_fdata()[:192, :192, :192].squeeze(), axis=(0, -1)).astype(np.float32)
data_fixed = np.expand_dims(fx_paded.get_fdata()[:192, :192, :192].squeeze(), axis=(0, -1)).astype(np.float32)
# Set up tensors and permute for inference
input_moving = torch.from_numpy(data_moving).to(device).float().permute(0, 4, 1, 2, 3)
input_fixed = torch.from_numpy(data_fixed).to(device).float().permute(0, 4, 1, 2, 3)

# Predict using cascaded networks
moved, warp_tensor = pt_first_model(input_moving, input_fixed, registration=True)
warp_data_first = warp_tensor[0].permute(1, 2, 3, 0).detach().numpy()

moved_final, warp_tensor = pt_second_model(moved, input_fixed, registration=True)
warp_data_second = warp_tensor[0].permute(1, 2, 3, 0).detach().numpy()

# Saved the moved data to directly observe the results
moved_first_reg_data = moved[0][0].detach().numpy()
nib.save(nib.Nifti1Image(moved_first_reg_data, fixed.affine), 'moved_first_reg_pt.nii.gz')
moved_data = moved_final[0][0].detach().numpy()
nib.save(nib.Nifti1Image(moved_data, fixed.affine), 'moved_pt.nii.gz')

# Warping field
# Modify the warp data so it can be used with sct_apply_transfo()
# (add a time dimension, change the sign of some axes and set the intent code to vector)

# Change the sign of the vectors and the order of the axes components to be correctly used with sct_apply_transfo
# and to get the same results with sct_apply_transfo() and when using model.predict() or vxm.networks.Transform()
orientation_conv = "LPS"
fx_im_orientation = list(nib.aff2axcodes(fixed.affine))
opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
perm = [0, 1, 2]
inversion = [1, 1, 1]
for i, character in enumerate(orientation_conv):
    try:
        perm[i] = fx_im_orientation.index(character)
    except ValueError:
        perm[i] = fx_im_orientation.index(opposite_character[character])
        inversion[i] = -1

# First Warping Field
# Add the time dimension
warp_data_exp = np.expand_dims(warp_data_first, axis=3)
warp_data_exp_copy = np.copy(warp_data_exp)
warp_data_exp[..., 0] = inversion[0] * warp_data_exp_copy[..., perm[0]]
warp_data_exp[..., 1] = inversion[1] * warp_data_exp_copy[..., perm[1]]
warp_data_exp[..., 2] = inversion[2] * warp_data_exp_copy[..., perm[2]]
warp = nib.Nifti1Image(warp_data_exp, fixed.affine)
warp.header['intent_code'] = 1007

# Save the warping field that can be later used with sct_apply_transfo
nib.save(warp, f'warp_field_first_reg_pt.nii.gz')

# Second Warping Field
# Add the time dimension
warp_data_exp = np.expand_dims(warp_data_second, axis=3)
warp_data_exp_copy = np.copy(warp_data_exp)
warp_data_exp[..., 0] = inversion[0] * warp_data_exp_copy[..., perm[0]]
warp_data_exp[..., 1] = inversion[1] * warp_data_exp_copy[..., perm[1]]
warp_data_exp[..., 2] = inversion[2] * warp_data_exp_copy[..., perm[2]]
warp = nib.Nifti1Image(warp_data_exp, fixed.affine)
warp.header['intent_code'] = 1007

# Save the warping field that can be later used with sct_apply_transfo
nib.save(warp, f'warp_field_second_reg_pt.nii.gz')

Note: I did not use the input image size but reduced it to 192 x 192 x 192 because otherwise my computer's RAM was overloaded and the process stopped for that reason. I did not encounter this problem with the TF models.
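One generic PyTorch memory saver that might be worth trying here (a standard inference-time pattern, untested on this model): wrapping the forward passes in torch.no_grad() so that autograd buffers are not kept:

# Hypothetical adaptation of the prediction step above
with torch.no_grad():
    moved, warp_tensor = pt_first_model(input_moving, input_fixed, registration=True)
    warp_data_first = warp_tensor[0].permute(1, 2, 3, 0).detach().numpy()
    moved_final, warp_tensor = pt_second_model(moved, input_fixed, registration=True)
    warp_data_second = warp_tensor[0].permute(1, 2, 3, 0).detach().numpy()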

@jcohenadad
Member

Thank you very much for doing these thorough and well-documented investigations, @EvanBeal! 🙏 It is quite frustrating that the registration results are not comparable between pytorch and tensorflow. I cross my fingers that the 'bug' (if we can call it a bug) is easy to spot 🤞

Note: I did not use the input image size but reduced it to 192 x 192 x 192 because otherwise my computer's RAM was overloaded and the process stopped for that reason. I did not encounter this problem with the TF models.

Hopefully this is related to the 'bug'. If RAM is more of a limitation with pytorch, then this is another argument for keeping TF, which would be annoying.

@EvanBeal
Contributor

It is quite frustrating that the registration results are not comparable between pytorch and tensorflow. I cross my fingers that the 'bug' (if we can call it a bug) is easy to spot 🤞

After spending several hours trying to find an error in the code and looking for potential alternatives for converting from TensorFlow to PyTorch, I could not find anything relevant that would explain the observed differences and the poor registration results obtained with PyTorch.

Therefore, as a last attempt to use PyTorch, I opened an issue (issue 425) on the VoxelMorph repository to explain the situation and see if they have any idea what the reason might be or if they have ever tried to do such a conversion.

@jcohenadad
Member

Thank you for the thorough digging, @EvanBeal . Let's hope that the VoxelMorph folks will be able to pinpoint what the problem could be 🤞.

@EvanBeal
Contributor

Conversion of the registration models from Keras .h5 to PyTorch models (Corrected)

Good news! The VoxelMorph folks were able to find the issue, and their suggestion solved the problem.

Suggestion: You may not want to reverse the order of all dimensions of the kernel weights with weights[i].T, but instead try to only move the input/output dimensions with something like torch.movedim(weights[i], (-1, -2), (0, 1)).

Therefore, using the following code makes it possible to obtain PyTorch models that perform exactly like the TensorFlow ones!

import os

# Import the Keras model and copy the weights
import voxelmorph as vxm
tf_model = vxm.networks.VxmDense.load('model/tf_registration_model.h5', input_model=None)
weights = tf_model.get_weights()

# import VoxelMorph with pytorch backend
import importlib
import torch
os.environ['VXM_BACKEND'] = 'pytorch'
importlib.reload(vxm)

# Build a Torch model and set the weights from the Keras model
reg_args = dict(
    inshape=[160, 160, 192],
    int_steps=5,
    int_downsize=2,   # same as int_resolution=2, in keras model
    unet_half_res=True,  # same as svf_resolution=2, in keras model
    nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
)
# Create the PyTorch model
pt_model = vxm.networks.VxmDense(**reg_args)

# Load the weights onto the PyTorch model
i = 0
i_max = len(list(pt_model.named_parameters()))
torchparam = pt_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        print("{:20s} {}".format(k, v.shape))
        if k.split('.')[-1] == 'weight':
            torchparam[k] = torch.movedim(torch.tensor(weights[i]), (-1, -2), (0, 1))
        else:
            torchparam[k] = torch.tensor(weights[i])
        i += 1


pt_model.load_state_dict(torchparam)
torch.save(pt_model.state_dict(), 'pt_registration_model.pt')

Availability of PyTorch models

As a next step, I created a release on the multimodal registration project (r20220512) so the PyTorch models are now publicly available.

Integration in SCT

We should now be able to work on integrating the code into SCT. To perform registration on data of any input size with the cascaded registration models, the following code needs to be used.

Code for cascaded registration using PyTorch models
import os

import numpy as np
import nibabel as nib
from nilearn.image import resample_img

import torch
# import VoxelMorph with pytorch backend
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm

# Load preprocessed data (scaled between 0 and 1 and with the moving data in the space of the fixed one)
fixed = nib.load("data_processed_time_analysis/data2/sub-geneva06/anat/sub-geneva06_T1w.nii.gz")
moving = nib.load("data_processed_time_analysis/data2/sub-geneva06/anat/sub-geneva06_T2w.nii.gz")
# N.B.
# These data are in my local computer but any data could be used to perform the same analysis.
# It only needs to be scaled and set in a common space (e.g. using sct_register_multimodal with -identity 1)

# Define the input shape of the model (smallest multiple of 16 above the input data shape)
# Ensure that the volumes can be used in the registration model
fx_img_shape = fixed.get_fdata().shape
mov_img_shape = moving.get_fdata().shape
max_img_shape = max(fx_img_shape, mov_img_shape)
new_img_shape = (int(np.ceil(max_img_shape[0] / 16)) * 16, int(np.ceil(max_img_shape[1] / 16)) * 16,
                 int(np.ceil(max_img_shape[2] / 16)) * 16)

# Pad the volumes to the max shape
fx_paded = resample_img(fixed, target_affine=fixed.affine, target_shape=new_img_shape, interpolation='continuous')
mov_paded = resample_img(moving, target_affine=moving.affine, target_shape=new_img_shape, interpolation='continuous')
input_shape = list(new_img_shape)

# Set the parameters of the registration model
reg_args = dict(
    inshape=input_shape,
    int_steps=5,
    int_downsize=2,
    unet_half_res=True,
    nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
)
# Create the PyTorch model and specify the device
device = 'cpu'

# ---- First Model ---- #
pt_first_model = vxm.networks.VxmDense(**reg_args)
trained_state_dict_first_model = torch.load('pt_cascaded_first_model.pt')
# Load the weights to the PyTorch model
weights_first_model = []
for k in trained_state_dict_first_model:
    weights_first_model.append(trained_state_dict_first_model[k])
i = 0
i_max = len(list(pt_first_model.named_parameters()))
torchparam = pt_first_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        torchparam[k] = weights_first_model[i]
        i += 1
pt_first_model.load_state_dict(torchparam)
pt_first_model.eval()

# ---- Second Model ---- #
pt_second_model = vxm.networks.VxmDense(**reg_args)
trained_state_dict_second_model = torch.load('pt_cascaded_second_model.pt')
# Load the weights to the PyTorch model
weights_second_model = []
for k in trained_state_dict_second_model:
    weights_second_model.append(trained_state_dict_second_model[k])
i = 0
i_max = len(list(pt_second_model.named_parameters()))
torchparam = pt_second_model.state_dict()
for k, v in torchparam.items():
    if i < i_max:
        torchparam[k] = weights_second_model[i]
        i += 1
pt_second_model.load_state_dict(torchparam)
pt_second_model.eval()

# Prepare the data for inference
data_moving = np.expand_dims(mov_paded.get_fdata().squeeze(), axis=(0, -1)).astype(np.float32)
data_fixed = np.expand_dims(fx_paded.get_fdata().squeeze(), axis=(0, -1)).astype(np.float32)
# Set up tensors and permute for inference
input_moving = torch.from_numpy(data_moving).to(device).float().permute(0, 4, 1, 2, 3)
input_fixed = torch.from_numpy(data_fixed).to(device).float().permute(0, 4, 1, 2, 3)

# Predict using cascaded networks
moved, warp_tensor = pt_first_model(input_moving, input_fixed, registration=True)
warp_data_first = warp_tensor[0].permute(1, 2, 3, 0).detach().numpy()

# Saved the moved data after the first step of the process to directly observe the results
moved_first_reg_data = moved[0][0].detach().numpy()
moved_nifti = nib.Nifti1Image(moved_first_reg_data, fixed.affine)
nib.save(moved_nifti, 'moved_first_reg.nii.gz')

moved_final, warp_tensor = pt_second_model(moved, input_fixed, registration=True)
warp_data_second = warp_tensor[0].permute(1, 2, 3, 0).detach().numpy()
# Saved the moved data at the end of the registration process
moved_data = moved_final[0][0].detach().numpy()
nib.save(nib.Nifti1Image(moved_data, fixed.affine), 'moved.nii.gz')

# Warping field
# Modify the warp data so it can be used with sct_apply_transfo()
# (add a time dimension, change the sign of some axes and set the intent code to vector)

# Change the sign of the vectors and the order of the axes components to be correctly used with sct_apply_transfo
# and to get the same results with sct_apply_transfo() and when using model.predict() or vxm.networks.Transform()
orientation_conv = "LPS"
fx_im_orientation = list(nib.aff2axcodes(fixed.affine))
opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
perm = [0, 1, 2]
inversion = [1, 1, 1]
for i, character in enumerate(orientation_conv):
    try:
        perm[i] = fx_im_orientation.index(character)
    except ValueError:
        perm[i] = fx_im_orientation.index(opposite_character[character])
        inversion[i] = -1

# First Warping Field
# Add the time dimension
warp_data_exp = np.expand_dims(warp_data_first, axis=3)
warp_data_exp_copy = np.copy(warp_data_exp)
warp_data_exp[..., 0] = inversion[0] * warp_data_exp_copy[..., perm[0]]
warp_data_exp[..., 1] = inversion[1] * warp_data_exp_copy[..., perm[1]]
warp_data_exp[..., 2] = inversion[2] * warp_data_exp_copy[..., perm[2]]
warp = nib.Nifti1Image(warp_data_exp, fixed.affine)
warp.header['intent_code'] = 1007

# Save the warping field that can be later used with sct_apply_transfo
nib.save(warp, f'warp_field_first_reg.nii.gz')

# Second Warping Field
# Add the time dimension
warp_data_exp = np.expand_dims(warp_data_second, axis=3)
warp_data_exp_copy = np.copy(warp_data_exp)
warp_data_exp[..., 0] = inversion[0] * warp_data_exp_copy[..., perm[0]]
warp_data_exp[..., 1] = inversion[1] * warp_data_exp_copy[..., perm[1]]
warp_data_exp[..., 2] = inversion[2] * warp_data_exp_copy[..., perm[2]]
warp = nib.Nifti1Image(warp_data_exp, fixed.affine)
warp.header['intent_code'] = 1007

# Save the warping field that can be later used with sct_apply_transfo
nib.save(warp, f'warp_field_second_reg.nii.gz')

Therefore, (an adapted version of) this code will need to be integrated into a function in SCT to be used via sct_register_multimodal.
A new dependency on VoxelMorph will thus be created.

With the TensorFlow dependency removed, the two warping fields resulting from the two registration steps are no longer composed inside this code. I don't think this is a problem because we could give both deformation fields to sct_apply_transfo so that the composition of the fields (and thus the transformations) happens there, but if you think a better solution could be used, let me know.

Note: I did not use the input image size but reduced it to 192 x 192 x 192 because otherwise my computer's RAM was overloaded and the process stopped for that reason. I did not encounter this problem with the TF models.

As a negative point, this problem remains: I am not able to run the whole process on my own computer for data of size 192x256x320 because the RAM is overloaded. Only the first registration step is performed and then the process is killed.

@joshuacwnewton
Member Author

Oh my goodness, this is wonderful news! Bravo! Thank you for taking the time and effort to investigate, and kudos for being able to navigate such a tricky investigation.

(Truthfully, I had been feeling a bit unsure whether this would even be possible to do! So, I am very very relieved to hear that you were able to identify the issue! 🎉)

Note: I did not use the input image size but reduced it to 192 x 192 x 192 because otherwise my computer's RAM was overloaded and the process stopped for that reason. I did not encounter this problem with the TF models.

As a negative point, this problem remains and I am not able to run the whole process on my own computer for data of size 192x256x320 because the RAM is overloaded. Only the first registration step is performed and then the process is killed.

I'm very curious about this point. I would like to try on my computer as well to see if I encounter the same issue.

But this may also be something worth bringing up in a new issue on the VoxelMorph repo? Maybe they have encountered these resource-usage issues too, and could have some insight into the differences between their TF and PyTorch model architectures.

@jcohenadad
Member

OMG! I'm so happy to read this. Thank you for having persisted, @EvanBeal; it did pay off in the end!

@jcohenadad
Member

I don't think this is a problem because we could give both deformation fields to sct_apply_transfo so that the composition of the fields (and thus the transformations) happens there, but if you think a better solution could be used, let me know.

yup! that seems to be the right way to do it, as we currently do for multi-step approaches with other algorithms (ANTs, centermassrot, etc.).
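For illustration, applying both fields in a single call could look something like this (file names are placeholders; this assumes sct_apply_transfo accepts multiple fields after -w, as for other multi-step outputs):

sct_apply_transfo -i moving.nii.gz -d fixed.nii.gz -w warp_field_first_reg.nii.gz warp_field_second_reg.nii.gz -o moving_reg.nii.gz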

@joshuacwnewton
Member Author

I would like to try on my computer as well to see if I encounter the same issue.

So, I tried following along with @EvanBeal's test script from the "Code for cascaded registration using PyTorch models" dropdown in #3760 (comment).

However, I ran into a new issue:

  • @EvanBeal's model definition uses the parameter unet_half_res, which was introduced in Feb 2021 (voxelmorph/voxelmorph@579a995).

    reg_args = dict(
        inshape=input_shape,
        int_steps=5,
        int_downsize=2,
        unet_half_res=True,
        nb_unet_features=([256, 256, 256, 256], [256, 256, 256, 256, 256, 256])
    )
  • But, VoxelMorph only has a single release on PyPI (voxelmorph-0.1) from Sept 2020.

  • So, when you run pip install voxelmorph, you currently don't get access to the newest features from the master branch that @EvanBeal used in development.

I have reported this upstream in the hopes that the VoxelMorph team can release a new version on PyPI.

But, right now, this poses a problem: If SCT tries to include voxelmorph as a dependency in its requirements.txt, then we won't have access to the features that we need in order to load @EvanBeal's model.

@kousu
Contributor

kousu commented May 17, 2022

But, right now, this poses a problem: If SCT tries to include voxelmorph as a dependency in its requirements.txt, then we won't have access to the features that we need in order to load @EvanBeal's model.

Actually... because we're using requirements.txt and not pip, we can just add this:

voxelmorph@git+https://github.com/voxelmorph/voxelmorph@579a995492bddfe9ce38161e58cf260fc155c4fd

I'm still going to ask them to publish to pypi (voxelmorph/voxelmorph#432), because we'll need that for #1526, but we're not stuck in the meantime.

And in the long run, when we run into situations like that, we can always fork and publish dependencies ourselves as 'neuropoly-$PKG'.

@joshuacwnewton
Member Author

joshuacwnewton commented May 20, 2022

Note: I did not use the input image size but reduced it to 192 x 192 x 192 because otherwise my computer's RAM was overloaded and the process stopped for that reason. I did not encounter this problem with the TF models.

As a negative point, this problem remains and I am not able to run the whole process on my own computer for data of size 192x256x320 because the RAM is overloaded. Only the first registration step is performed and then the process is killed.

To follow up on the RAM issues, @EvanBeal has created voxelmorph/voxelmorph#434, and I adapted Evan's cascaded registration script into a reproducible example in voxelmorph/voxelmorph#434 (comment).

(This work is also currently located in the jn/3760-add-voxelmorph-registration branch.)

@kousu
Contributor

kousu commented Jun 9, 2022

I'm still going to ask them to publish to pypi (voxelmorph/voxelmorph#432), because we'll need that for #1526

This is done: https://pypi.org/project/voxelmorph/0.2/

@joshuacwnewton
Member Author

joshuacwnewton commented Jun 13, 2022

Given that both the PyTorch and VoxelMorph issues have been (more or less) resolved, I imagine we can start looking at integrating the VoxelMorph approach into SCT.

Here are some rough, cursory TODOs, along with some pointers to relevant areas of SCT's codebase:

  • Adding voxelmorph to -param algo= choices
    • For sct_register_multimodal, -param step argument validity isn't currently enforced.
    • In other words, there is no argparse choices validation going on here.
    • So, right now, I don't think there is any need to manually add voxelmorph to a list of valid algo choices.
  • In register.py, create a function (similar to the existing functions) that takes in src/dest images and returns the warping fields.
  • In the register() function, check param.algo, then call the newly-created registration function (see the sketch below).
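A very rough skeleton of what that could look like (names and the signature are assumptions, loosely based on the conventions above, not the actual code from PR #3807):

# In registration/register.py: a function that takes src/dest and returns the warping fields
def register_step_dl_multimodal_cascaded_reg(src, dest, step):
    # preprocess src/dest, build the cascaded VxmDense models for the padded input shape,
    # run the two-step inference, and write sct_apply_transfo-compatible warping fields
    return 'warp_field_first_reg.nii.gz', 'warp_field_second_reg.nii.gz'

# In register(): dispatch on the chosen algorithm, e.g.
# if step.algo == 'dl':
#     warp_fields = register_step_dl_multimodal_cascaded_reg(src, dest, step)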

Aside: Observations about register_wrapper()

If you dig into how sct_register_multimodal actually performs registration, the trace is a bit complicated. We go from:

Now, it's a bit odd that sct_register_multimodal is calling a function from sct_register_to_template. (I would have assumed that register_wrapper() would be in the API, and called by both registration CLI scripts.)

To me, this is just begging for a CLI/API refactor. But, for pragmatic purposes, I think it's probably safe to ignore register_wrapper() for the purposes of this change, and instead jump straight into mimicking the existing conventions that are inside register().

@jcohenadad
Member

Now, it's a bit odd that sct_register_multimodal is calling a function from sct_register_to_template. (I would have assumed that register_wrapper() would be in the API, and called by both registration CLI scripts.)

you're absolutely right. As you probably figured, this is a legacy issue. Many years ago we didn't have APIs; all the algos were in the scripts. Then, progressively, things were moved around. I remember having worked on these scripts/APIs quite extensively, and there are still a lot of 'todo's to make it proper

@EvanBeal
Contributor

To follow up on @joshuacwnewton's comment, I have worked on integrating the deep learning multimodal registration method into SCT, following what is mentioned in the comment. This work can be found on this branch.

From the tests that I have done today, it seems to be quite functional now, with a usage that is similar to that of the other registration algorithms. (commit)

What has been done

The files registration/register.py and scripts/sct_register_to_template.py have been modified to perform multimodal registration using deep learning models through the sct_register_multimodal command line, by specifying -param algo=dl,….
[N.B. we can of course change the name of the algorithm for this method; it is set to dl for now]

This is done by adding an if condition in the register() function of sct_register_to_template that detects when the user chooses the deep learning method for registration and calls a new function (register_step_dl_multimodal_cascaded_reg) in register.py. This function prepares the data with some preprocessing steps and then calls another new function in register.py, register_dl_multimodal_cascaded_reg, that performs registration using the deep learning models.

How it will be used in the future

There is no need for prior preprocessing steps, as these are directly included in the method.
Therefore, to use this method we will simply do:

sct_register_multimodal -i moving_data.nii.gz -d fixed_data.nii.gz -param step=1,type=im,algo=dl

The deep learning registration models are currently not present in SCT, so they will need to be added somewhere so that the registration can be performed directly, without asking the user to download the models from somewhere else. My guess is that they should be added in the data folder, under a new folder with a name like deepreg_models?

How to test it now

To test the method now, you need to download the deep learning models, put them in a folder called test_data_models in your SCT folder, and then use the code present on this branch (with the required packages installed).

First, download and extract cascaded_models_pytorch.zip (from here) to a test_data_models/ folder inside your SCT folder (same location as the data, documentation, or dev folders).

Then, choose some data to do your test. For instance, considering the spine generic dataset you can choose to register the T2w data of sub-geneva06 to the T1w of the same subject.

I highly recommend cropping the data before using it; otherwise it is very likely that the process will be killed because it will use too much memory.

So, if you use these data you can for example do the following:

sct_crop_image -i data/sub-geneva06/anat/sub-geneva06_T1w.nii.gz -o data/sub-geneva06/anat/sub-geneva06_T1w_crop.nii.gz -xmin 73 -xmax -68 -ymin 50 -ymax -101 -zmin 32 -zmax -105
sct_crop_image -i data/sub-geneva06/anat/sub-geneva06_T2w.nii.gz -o data/sub-geneva06/anat/sub-geneva06_T2w_crop.nii.gz -ymin 85 -ymax -99 -zmin 50 -zmax -41

Finally, you can perform the registration:

sct_register_multimodal -i data/sub-geneva06/anat/sub-geneva06_T2w_crop.nii.gz -d data/sub-geneva06/anat/sub-geneva06_T1w_crop.nii.gz -param step=1,type=im,algo=dl

N.B. I haven't performed a lot of tests yet, so you are likely to encounter bugs, but these are some first results to share with you and to let you know the current state of the integration.

@joshuacwnewton
Member Author

joshuacwnewton commented Jun 13, 2022

From the tests that I have done today, it seems to be quite functional now, with a usage that is similar to that of the other registration algorithms. (commit)

Wonderful!! Thank you so much for promptly following through on integrating your model into SCT.

One thing to mention: I realize now that my original branch, jn/3760-add-voxelmorph-registration, was mostly meant for debugging purposes (so that we could report the memory issues to the VoxelMorph developers). Because of this, the branch contains some unnecessary debugging code that you should be able to get rid of entirely. :)

Apart from that, could you please open a pull request then request my review? I'd love to start going over your work so that we can get it merged. 😄

I highly recommend cropping the data before using it; otherwise it is very likely that the process will be killed because it will use too much memory.

One thing that I ended up doing when debugging the memory issues was to use sct_deepseg_sc and sct_create_mask to crop around the spinal cord:

# Cropping around the spinal cord (T1w)
sct_deepseg_sc -i T1w.nii.gz -c t1 -centerline cnn
sct_create_mask -i T1w.nii.gz -p centerline,T1w_seg.nii.gz -size 35mm -f cylinder -o mask_T1w.nii.gz
sct_crop_image -i T1w.nii.gz -m mask_T1w.nii.gz

# Cropping around the spinal cord (T2w)
sct_deepseg_sc -i T2w.nii.gz -c t2 -centerline cnn
sct_create_mask -i T2w.nii.gz -p centerline,T2w_seg.nii.gz -size 35mm -f cylinder -o mask_T2w.nii.gz
sct_crop_image -i T2w.nii.gz -m mask_T2w.nii.gz

I think this might be a little more friendly if we end up testing multiple different images, since we won't have to rely on hardcoded coordinate values that are specific to a single image.

@EvanBeal
Contributor

Because of this, the branch contains some unnecessary debugging code that you should be able to get rid of entirely. :)

Yes, alright! I have deleted cascaded_registration.py, which was indeed used as a temporary solution.

Apart from that, could you please open a pull request then request my review? I'd love to start going over your work so that we can get it merged. 😄

Yep, done with PR #3807

One thing that I ended up doing when debugging the memory issues (voxelmorph/voxelmorph#434 (comment)) was to use sct_deepseg_sc and sct_create_mask to crop around the spinal cord. I think this might be a little more friendly if we end up testing multiple different images, since we won't have to rely on hardcoded coordinate values that are specific to a single image.

Yes, totally agree! Here it was just to provide a specific example, but to be more general and to easily test the method on multiple different images, your solution is waaay better ;)
