itk中互信息配准二维图像(仅平移)

原文地址:
https://examples.itk.org/src/registration/common/mutualinformation/mutualinformation#Resample-the-moving-image

https://examples.itk.org/src/registration/common/mutualinformation/documentation

效果图:

原始图像对比:

[图1：固定图像、浮动图像及其归一化平滑结果对比]

在 x、y 方向各平移 -20 到 20 像素范围内的互信息分布（二维热力图）:

[图2：互信息分布二维热力图]

在 x、y 方向各平移 -20 到 20 像素范围内的互信息分布（三维曲面）:

[图3：互信息分布三维曲面]

梯度下降优化的配准轨迹（叠加在互信息分布上）:

[图4：梯度下降配准轨迹]

最终配准结果:

[图5：固定图像与配准后重采样的浮动图像对比]

源码

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import itk


# Image types: 2-D images with float pixels; fixed and moving images share one type.
dim = 2
ImageType = itk.Image[itk.F, dim]
FixedImageType = ImageType
MovingImageType = ImageType

# Directory containing the ITK example data; adjust to the local checkout.
baseDir = r'D:\learn\itk\itk5.3.0\ITK-5.3.0\Examples\Data'

# os.path.join instead of concatenating a hard-coded "\" separator, so the
# paths stay valid when baseDir points at a non-Windows directory
# (on Windows the resulting strings are identical).
fixed_img_path = os.path.join(baseDir, "BrainT1SliceBorder20.png")
moving_img_path = os.path.join(baseDir, "BrainProtonDensitySliceShifted13x17y.png")

# Read both images with float pixels.
fixed_img = itk.imread(fixed_img_path, itk.F)
moving_img = itk.imread(moving_img_path, itk.F)

# Normalize intensities, then apply a discrete Gaussian (variance 2.0).
# The registrations below run on these smoothed images.
fixed_normalized_image = itk.normalize_image_filter(fixed_img)
fixed_smoothed_image = itk.discrete_gaussian_image_filter(fixed_normalized_image, variance=2.0)

moving_normalized_image = itk.normalize_image_filter(moving_img)
moving_smoothed_image = itk.discrete_gaussian_image_filter(moving_normalized_image, variance=2.0)

# 2x2 preview: raw fixed/moving images on top, normalized + smoothed versions below.
plt.subplot(221)
plt.imshow(itk.GetArrayFromImage(fixed_img), 'gray')
plt.title('fixed_img')

plt.subplot(222)
plt.imshow(itk.GetArrayFromImage(moving_img), 'gray')
plt.title('moving_img')

plt.subplot(223)
plt.imshow(itk.GetArrayFromImage(fixed_smoothed_image), 'gray')
plt.title('fixed_smoothed_image')

plt.subplot(224)
plt.imshow(itk.GetArrayFromImage(moving_smoothed_image), 'gray')
plt.title('moving_smoothed_image')

plt.show()

# Registration component types, shared by the exhaustive search and the
# gradient-descent run.
TransformType = itk.TranslationTransform[itk.D, dim]
InterpolatorType = itk.LinearInterpolateImageFunction[ImageType, itk.D]
MetricType = itk.MutualInformationImageToImageMetric[ImageType, ImageType]
OptimizerType = itk.GradientDescentOptimizer
ExhaustiveOptimizerType = itk.ExhaustiveOptimizer
RegistrationType = itk.ImageRegistrationMethod[ImageType, ImageType]

# Exhaustive search of the metric over x/y translations (plots the
# MutualInformationImageToImageMetric surface).
# Intended sampling: -20..20 along each axis, step window_size / n_steps = 0.2
# (assumes the optimizer's default step length; confirm if that is changed).
window_size = [20, 20]  # farthest offset from the initial position, per axis
n_steps = [100, 100]    # number of steps taken per axis

transform = TransformType.New()
metric = MetricType.New()
optimizer = ExhaustiveOptimizerType.New()
registrar = RegistrationType.New()
interpolator = InterpolatorType.New()

# Wire the pipeline: smoothed images, translation transform, linear
# interpolation, mutual-information metric, exhaustive optimizer.
registrar.SetFixedImage(fixed_smoothed_image)
registrar.SetMovingImage(moving_smoothed_image)
registrar.SetOptimizer(optimizer)
registrar.SetTransform(transform)
registrar.SetInterpolator(interpolator)
registrar.SetMetric(metric)
registrar.SetFixedImageRegion(fixed_img.GetBufferedRegion())
registrar.SetInitialTransformParameters(transform.GetParameters())

# Metric: 100 spatial samples, standard deviation 0.4 for both images.
metric.SetNumberOfSpatialSamples(100)
metric.SetFixedImageStandardDeviation(0.4)
metric.SetMovingImageStandardDeviation(0.4)

# Optimizer: step count per axis, and per-axis scale window_size / n_steps so
# the steps span the requested window. The scales array is fetched, resized to
# one entry per axis, filled, and handed back.
optimizer.SetNumberOfSteps(n_steps)
scales = optimizer.GetScales()
scales.SetSize(2)
scales.SetElement(0, window_size[0] / n_steps[0])
scales.SetElement(1, window_size[1] / n_steps[1])
optimizer.SetScales(scales)

# Observer output: maps each sampled translation (as a tuple) to its metric value.
surface = {}


def print_iteration():
    """Record the exhaustive optimizer's current sample into ``surface``.

    Registered as an ``itk.IterationEvent`` observer. Despite its name it does
    not print anything; it only stores position -> metric value.
    """
    position = tuple(optimizer.GetCurrentPosition())
    surface[position] = optimizer.GetCurrentValue()


# Run the exhaustive search; every evaluated position is stored by print_iteration.
optimizer.AddObserver(itk.IterationEvent(), print_iteration)
registrar.Update()

# Positions and values of the largest and smallest metric values on the grid.
max_position = list(optimizer.GetMaximumMetricValuePosition())
min_position = list(optimizer.GetMinimumMetricValuePosition())
max_val = optimizer.GetMaximumMetricValue()
min_val = optimizer.GetMinimumMetricValue()

# One value per line, in the order: max position, min position, max value, min value.
print(max_position, min_position, max_val, min_val, sep="\n")

# Grid axes: the sorted distinct x offsets and y offsets that were sampled.
x_vals = [sorted({position[axis] for position in surface}) for axis in range(0, 2)]

# np.meshgrid (default 'xy' indexing) yields X[row][col] = x_vals[0][col] and
# Y[row][col] = x_vals[1][row]. Z must use the same layout (rows <-> y,
# columns <-> x) so that Z[row][col] is the metric at (X[row][col], Y[row][col]).
# Surface keys are (x, y) tuples as stored by the observer.
X, Y = np.meshgrid(x_vals[0], x_vals[1])
Z = np.array([[surface[(x, y)] for x in x_vals[0]] for y in x_vals[1]])

# Plot the surface as a 2D heat map; y axis inverted to match image coordinates.
fig = plt.figure()
plt.gca().invert_yaxis()
ax = plt.gca()
surf = ax.scatter(X, Y, c=Z, cmap=cm.coolwarm)
# Mark where the search found the largest (^) and smallest (v) metric values.
ax.plot(max_position[0], max_position[1], "k^")
ax.plot(min_position[0], min_position[1], "kv")
plt.show()

# Plot the same sampled metric surface in 3D (translation x, y vs. metric value).
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm)
plt.show()


# Second registration: gradient descent on the metric, starting from a freshly
# created translation transform, for n_iterations iterations.
n_iterations = 200
transform = TransformType.New()
metric = MetricType.New()
optimizer = OptimizerType.New()
registrar = RegistrationType.New()
interpolator = InterpolatorType.New()

# Metric: 100 spatial samples, standard deviation 0.4 for both images.
metric.SetNumberOfSpatialSamples(100)
metric.SetFixedImageStandardDeviation(0.4)
metric.SetMovingImageStandardDeviation(0.4)

# Optimizer: learning rate 15, fixed iteration budget, maximize the metric.
optimizer.SetLearningRate(15)
optimizer.SetNumberOfIterations(n_iterations)
optimizer.MaximizeOn()

# Wire the pipeline on the smoothed images.
registrar.SetFixedImage(fixed_smoothed_image)
registrar.SetMovingImage(moving_smoothed_image)
registrar.SetOptimizer(optimizer)
registrar.SetTransform(transform)
registrar.SetInterpolator(interpolator)
registrar.SetMetric(metric)
registrar.SetFixedImageRegion(fixed_img.GetBufferedRegion())
registrar.SetInitialTransformParameters(transform.GetParameters())

# Optimizer trajectory. Key 0 is the starting point (0, 0); later keys are
# GetCurrentIteration() + 1 at each iteration event.
descent_data = {0: (0, 0)}


def log_iteration():
    """Store the gradient-descent optimizer's current position in ``descent_data``.

    Registered as an ``itk.IterationEvent`` observer.
    """
    step = optimizer.GetCurrentIteration() + 1
    descent_data[step] = tuple(optimizer.GetCurrentPosition())


# Run gradient descent; log_iteration records the position at every iteration.
optimizer.AddObserver(itk.IterationEvent(), log_iteration)
registrar.Update()
print(f"Its: {optimizer.GetCurrentIteration()}")
print(f"Final Value: {optimizer.GetValue()}")
print(f"Final Position: {list(registrar.GetLastTransformParameters())}")

# descent_data holds the start (key 0) plus one entry per logged iteration
# (keys 1..N). Walk the keys actually recorded, in order. This includes the
# final position and cannot raise KeyError if fewer iterations were logged.
steps = sorted(descent_data)
x_vals = [descent_data[i][0] for i in steps]
y_vals = [descent_data[i][1] for i in steps]

fig = plt.figure()
# Note: We invert the y-axis to represent the image coordinate system
plt.gca().invert_yaxis()
ax = plt.gca()

# Background: the metric surface sampled by the exhaustive search.
surf = ax.scatter(X, Y, c=Z, cmap=cm.coolwarm)

# Trajectory as one connected polyline (same picture as plotting each segment).
plt.plot(x_vals, y_vals, "wx-")
plt.plot(x_vals[0], y_vals[0], "bo")    # starting position
plt.plot(x_vals[-1], y_vals[-1], "ro")  # last recorded position

# Largest (^) and smallest (v) metric values from the exhaustive search.
plt.plot(max_position[0], max_position[1], "k^")
plt.plot(min_position[0], min_position[1], "kv")
plt.show()
print(max_position)
print(min_position)

# Resample the original (unsmoothed) moving image onto the fixed image's grid
# (size, origin, spacing, direction). The transform is the instance given to
# the gradient-descent registrar. Output pixels that fall outside the moving
# image get the default value 100.
ResampleFilterType = itk.ResampleImageFilter[MovingImageType, FixedImageType]
resample = ResampleFilterType.New()
resample.SetTransform(transform)
resample.SetInput(moving_img)
resample.SetSize(fixed_img.GetLargestPossibleRegion().GetSize())
resample.SetOutputOrigin(fixed_img.GetOrigin())
resample.SetOutputSpacing(fixed_img.GetSpacing())
resample.SetOutputDirection(fixed_img.GetDirection())
resample.SetDefaultPixelValue(100)
resample.Update()

# Side-by-side: fixed image vs. the resampled (registered) moving image.
plt.subplot(121)
plt.imshow(itk.GetArrayFromImage(fixed_img), 'gray')
plt.title('fixed_img')

plt.subplot(122)
plt.imshow(itk.GetArrayFromImage(resample.GetOutput()), 'gray')
plt.title('resample image')

plt.show()

相关标签：ITK、itk、互信息