Modeling multi-channel NCC-based registration #313
Hi @neel-dey, thanks, these are very good points. The multi-modal NCC was only added experimentally (and we should note this in the code), and we haven't done thorough testing. I think your domain shift (counter)example is a good one. I suspect we could still implement a 'split' NCC without actually splitting the computation. @brf2 and @ahoopes might be interested in this discussion as well.
Makes sense, thanks for the response. Implementing channel-wise NCC directly should just need two changes in https://github.com/voxelmorph/voxelmorph/blob/master/voxelmorph/tf/losses.py#L37 at L37 and L51, such that (changes marked with "# CHANGED HERE"):

    import numpy as np
    import tensorflow as tf
    import tensorflow.keras.backend as K


    class NCC:
        """
        Local (over window) normalized cross correlation loss.
        """

        def __init__(self, win=None, eps=1e-5):
            self.win = win
            self.eps = eps

        def ncc(self, I, J):
            # get dimension of volume
            # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
            ndims = len(I.get_shape().as_list()) - 2
            assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims

            # set window size
            if self.win is None:
                self.win = [9] * ndims

            # get convolution function
            conv_fn = getattr(tf.nn, 'conv%dd' % ndims)

            # compute CC squares
            I2 = I * I
            J2 = J * J
            IJ = I * J

            # compute filters
            in_ch = J.get_shape().as_list()[-1]
            sum_filt = tf.ones([*self.win, 1, in_ch])  # CHANGED HERE
            strides = 1
            if ndims > 1:
                strides = [1] * (ndims + 2)

            # compute local sums via convolution
            padding = 'SAME'
            I_sum = conv_fn(I, sum_filt, strides, padding)
            J_sum = conv_fn(J, sum_filt, strides, padding)
            I2_sum = conv_fn(I2, sum_filt, strides, padding)
            J2_sum = conv_fn(J2, sum_filt, strides, padding)
            IJ_sum = conv_fn(IJ, sum_filt, strides, padding)

            # compute cross correlation
            win_size = np.prod(self.win)  # CHANGED HERE
            u_I = I_sum / win_size
            u_J = J_sum / win_size

            cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size  # TODO: simplify this
            I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
            J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

            cc = cross * cross / (I_var * J_var + self.eps)

            # return mean cc for each entry in batch
            return tf.reduce_mean(K.batch_flatten(cc), axis=-1)

        def loss(self, y_true, y_pred):
            return -self.ncc(y_true, y_pred)

I haven't assessed this super carefully or critically, but numerical tests worked fine. If it looks correct, I can send a PR.
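As a side note on the "# TODO: simplify this" line in the snippet above, the expansion appears to reduce algebraically. A short derivation follows, using shorthand that is not in the original code: $N$ stands for `win_size`, $S_I$ for `I_sum`, $S_{IJ}$ for `IJ_sum`, and so on, with $u_I = S_I / N$ and $u_J = S_J / N$.

```latex
\begin{aligned}
\mathrm{cross} &= S_{IJ} - u_J S_I - u_I S_J + u_I u_J N
               = S_{IJ} - \frac{S_I S_J}{N}, \\
I_{\mathrm{var}} &= S_{I^2} - 2 u_I S_I + u_I^2 N
                 = S_{I^2} - \frac{S_I^2}{N}, \\
J_{\mathrm{var}} &= S_{J^2} - 2 u_J S_J + u_J^2 N
                 = S_{J^2} - \frac{S_J^2}{N}.
\end{aligned}
```

This is only an algebraic observation about the three expressions, not a tested change to the loss.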
Great, opened PR #314.
Sure. I don’t think I have used multi-channel NCC yet, but certainly we will want it!
Hi Adrian & co.,
For multi-channel registration (e.g., RGB image registration or 4D registration of subject A with T1 and T2 <---> subject B with T1 and T2), vxm implements 4D windows for local NCC (e.g., with window size [9, 9, 9, 2] for T1+T2).
I wonder if this may be a problem when dealing with domain shifts (e.g., scanner differences) in a heterogeneous dataset. Typically, 3D NCC handles this by standardizing local statistics and is mostly insensitive to domain shift. However, a domain shift may not transform T1 and T2 intensities in the same way, and this changes the statistics of the 4D window.
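To make the mechanism concrete, here is a minimal NumPy sketch on synthetic data. It is not the vxm implementation: it computes a single squared NCC over one hypothetical window rather than sliding windows via convolution, and the helper name `ncc`, the random "anatomy", and the gain/offset values are all assumptions chosen for illustration.

```python
import numpy as np

def ncc(a, b, eps=1e-5):
    """Squared normalized cross-correlation over every element of a and b."""
    a = a - a.mean()
    b = b - b.mean()
    return (a * b).sum() ** 2 / ((a * a).sum() * (b * b).sum() + eps)

rng = np.random.default_rng(0)

# One hypothetical 9x9x9 window with two channels (stand-ins for T1 and T2).
fixed = rng.normal(size=(9, 9, 9, 2))

# Simulated domain shift: each channel gets its own gain and offset
# (values chosen arbitrarily for illustration).
gain = np.array([1.0, 3.0])
offset = np.array([0.0, 5.0])
moving = fixed * gain + offset

# Joint statistics over a single [9, 9, 9, 2] window.
joint = ncc(fixed, moving)

# Separate statistics per channel, then averaged.
split = np.mean([ncc(fixed[..., c], moving[..., c]) for c in range(2)])

print(f"joint 4D NCC: {joint:.3f}")
print(f"split 3D NCC: {split:.3f}")
```

Because each channel of `moving` is an exact affine function of the same channel of `fixed`, the split value stays close to 1, while the joint value drops well below 1 even though the two windows are perfectly aligned.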
In practice, when training for multi-channel templates on a dataset with multiple centers, the NCC loss values had high variance and depended strongly on the center (which eventually led to divergence). This effect went away once I used a separate 3D NCC term for each modality (ANTs uses separate NCC terms as well). I imagine that if the batch size were large enough, this would not be an issue, but we're stuck with a low number for 3D MRI. :)
Here's a minimal example demonstrating that 4D NCC is sensitive to domain shifts, whereas 3D NCC on each channel is relatively insensitive. The example uses ICBM 2009a Nonlinear Asymmetric T1+T2 as image 1 and NIH's pediatric template as image 2.
This yields output:
Do you have any thoughts on this phenomenon and if 4D NCC would be better than split 3D NCC in other applications?
Thanks!