By: Steven Whitaker
Re-posted from: https://glcs.hashnode.dev/optimization-mri
Welcome to our first Julia for Devs blog post! This will be an ongoing series of posts where our team will discuss how we have used Julia to solve real-life problems. So, let’s get started!
In this Julia for Devs post, we will discuss using Julia to optimize scan parameters in magnetic resonance imaging (MRI).
If you are interested only in seeing example code showing how to use Julia to optimize your own cost function, feel free to skip ahead.
While we will discuss MRI specifically, the concepts (and even some of the code) apply in many situations where optimization is useful, particularly when there is a model for an observable signal. The model could be, e.g., a mathematical equation or a trained neural network, and the observable signal (aka the output of the model) could be, e.g., the image intensity of a medical imaging scan or the energy needed to heat a building.
Problem Setup
In this post, the signal model we will use is a mathematical equation that gives the image intensity of a balanced steady-state free precession (bSSFP) MRI scan. This model has two sets of inputs: those that are user-defined and those that are not. In MRI, the user-defined inputs are scan parameters, such as how long to run the scan or how much power to use to generate the MRI signal. The other inputs are called tissue parameters because they are properties intrinsic to the imaged tissue (e.g., muscle or brain).
Tissue parameters are often assumed to take a predefined value for each tissue type. However, because they can vary with health and age, even within a single tissue type, tissue parameters are sometimes treated as unknown values that need to be estimated from data.
The problem we will discuss in this post is how to estimate tissue parameters from a set of bSSFP scans. Then we will discuss how to optimize the scan parameters of that set of bSSFP scans to improve the tissue parameter estimates.
To estimate tissue parameters,we need the following:
- a signal model, and
- an estimation algorithm.
Signal Model
Here’s the signal model:
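```julia
# bSSFP signal for scan parameters TR (s) and flip (degrees)
# and tissue parameters T1 and T2 (s)
function bssfp(TR, flip, T1, T2)
    E1 = exp(-TR / T1)
    E2 = exp(-TR / T2)
    (s, c) = sincosd(flip)  # sine and cosine of the flip angle (given in degrees)
    num = s * (1 - E1)
    den = 1 - c * (E1 - E2) - E1 * E2
    signal = num / den
    return signal
end
```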
This function returns the bSSFP signal as a function of two scan parameters (`TR`, a value in units of seconds, and `flip`, a value in units of degrees) and two tissue parameters (`T1` and `T2`, each a value in units of seconds).
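Written out as an equation, with $\alpha$ denoting `flip`, the `bssfp` function computes

$$
S(\mathrm{TR}, \alpha; T_1, T_2) = \frac{\sin\alpha \, (1 - E_1)}{1 - (E_1 - E_2)\cos\alpha - E_1 E_2},
\qquad E_1 = e^{-\mathrm{TR}/T_1}, \quad E_2 = e^{-\mathrm{TR}/T_2}.
$$

The code applies no overall scaling factor, so $S$ is a relative, unitless signal. Only $\mathrm{TR}$ and $\alpha$ are under the user’s control; $T_1$ and $T_2$ are the quantities we will try to estimate.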
Estimation Algorithm
There are various estimation algorithms out there, but, for simplicity in this post, we will stick with a grid search. We will compute the bSSFP signal over many pairs of `T1` and `T2` values (here, 40 log-spaced values of each, keeping only pairs with `T1 > T2`). We will compare these computed signals to the actual observed signal, and the pair of `T1` and `T2` values that results in the closest match (the smallest sum of squared differences) will be chosen as the estimates of the tissue parameters. Here’s the algorithm in code:
```julia
function gridsearch(signal, TR, flip)
    # Log-spaced candidate values (s)
    T1_grid = exp10.(range(log10(0.5), log10(3), 40))
    T2_grid = exp10.(range(log10(0.005), log10(0.7), 40))
    T1_est = T1_grid[1]
    T2_est = T2_grid[1]
    best = Inf
    residual = similar(signal)
    for (T1, T2) in Iterators.product(T1_grid, T2_grid)
        T1 > T2 > 0 || continue  # only keep pairs with T1 > T2 > 0
        residual .= signal .- bssfp.(TR, flip, T1, T2)
        err = residual' * residual  # sum of squared differences
        if err < best
            best = err
            T1_est = T1
            T2_est = T2
        end
    end
    isinf(best) && error("no valid grid points; consider changing `T1_grid` and/or `T2_grid`")
    return (T1_est, T2_est)
end
```
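`gridsearch` re-simulates every grid signal each time it is called, and `run` (in the results section) calls it once per pixel. A common alternative, sometimes called dictionary matching, simulates the grid signals once and reuses them for every observed signal. Below is a minimal sketch of that variant, assuming the same grids, the same `T1 > T2 > 0` rule, and the same first-minimum tie-breaking order; `build_dictionary` and `dictionary_match` are hypothetical names, and `bssfp` is repeated only so the block runs on its own.

```julia
# bSSFP signal model (same equation as `bssfp` above)
function bssfp(TR, flip, T1, T2)
    E1 = exp(-TR / T1)
    E2 = exp(-TR / T2)
    (s, c) = sincosd(flip)
    return s * (1 - E1) / (1 - c * (E1 - E2) - E1 * E2)
end

# Hypothetical helper: simulate one column of signals (one row per scan)
# for every valid (T1, T2) grid point, in the same order as `gridsearch`.
function build_dictionary(TR, flip;
        T1_grid = exp10.(range(log10(0.5), log10(3), 40)),
        T2_grid = exp10.(range(log10(0.005), log10(0.7), 40)))
    pairs = [(T1, T2) for (T1, T2) in Iterators.product(T1_grid, T2_grid) if T1 > T2 > 0]
    isempty(pairs) && error("no valid grid points; consider changing `T1_grid` and/or `T2_grid`")
    D = reduce(hcat, [bssfp.(TR, flip, T1, T2) for (T1, T2) in pairs])
    return (D, pairs)
end

# Hypothetical helper: return the grid point whose simulated signal has the
# smallest sum of squared differences from the observed signal.
function dictionary_match(signal, D, pairs)
    errs = vec(sum(abs2, D .- signal; dims = 1))
    return pairs[argmin(errs)]
end

# Usage: build the dictionary once, then match as many signals as needed
TR = [0.005, 0.005, 0.005]
flip = [30, 60, 80]
(D, pairs) = build_dictionary(TR, flip)
(T1_est, T2_est) = dictionary_match(bssfp.(TR, flip, 1.0, 0.08), D, pairs)
```

The trade-off is memory: the dictionary stores one column per valid grid point (up to 40 × 40 columns here), in exchange for not re-evaluating `bssfp` for every pixel.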
Cost Function
Now, to optimize scan parameters, we need a function to optimize, also called a cost function. In this case, we want to minimize the error between what our estimator tells us and the true tissue parameter values. But because we need to optimize scan parameters before running any scans, we will simulate bSSFP MRI scans using the given sets of scan parameters and average the estimator error over several sets of tissue parameters. Here’s the code for the cost function:
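```julia
using Statistics

function estimator_error(TR, flip)
    # True tissue parameters (s) over which the estimator error is averaged
    T1_true = [0.8, 1.0, 1.5]
    T2_true = [0.05, 0.08, 0.1]
    noise_level = 0.01
    T1_avg = mean(T1_true)
    T2_avg = mean(T2_true)
    signal = float.(TR)  # buffer with one entry per scan (overwritten below)
    return mean(Iterators.product(T1_true, T2_true)) do (T1, T2)
        T2 > T1 && return 0.0  # pairs with T2 > T1 are skipped (counted as zero error)
        # Simulate noisy scans, then estimate T1 and T2 from them
        signal .= bssfp.(TR, flip, T1, T2) .+ noise_level .* randn.()
        (T1_est, T2_est) = gridsearch(signal, TR, flip)
        T1_err = (T1_est - T1)^2
        T2_err = (T2_est - T2)^2
        return (T1_err / T1_avg + T2_err / T2_avg) / 2
    end
end
```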
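In equation form, `estimator_error` computes the following (the symbols are introduced here only for readability):

$$
C(\mathbf{TR}, \boldsymbol{\alpha}) = \frac{1}{9} \sum_{T_1 \in \mathcal{T}_1} \sum_{T_2 \in \mathcal{T}_2} \frac{1}{2} \left[ \frac{(\hat{T}_1 - T_1)^2}{\bar{T}_1} + \frac{(\hat{T}_2 - T_2)^2}{\bar{T}_2} \right]
$$

where $\mathcal{T}_1 = \{0.8, 1.0, 1.5\}$ s and $\mathcal{T}_2 = \{0.05, 0.08, 0.1\}$ s, $\bar{T}_1 = 1.1$ s and $\bar{T}_2 \approx 0.0767$ s are their means, and $(\hat{T}_1, \hat{T}_2)$ is the grid-search estimate from simulated scans with parameters $(\mathbf{TR}, \boldsymbol{\alpha})$ plus Gaussian noise of standard deviation 0.01. Every true $T_2$ is below every true $T_1$, so the `T2 > T1` branch never triggers and all 9 combinations contribute. Because new noise is drawn on every call, $C$ is itself random: evaluating it twice at the same scan parameters generally gives different values.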
Optimization
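Now we are finally ready to optimize! We will use Optimization.jl. Many optimizers are available, but because we have a non-convex optimization problem, we will use the adaptive particle swarm global optimization algorithm (`ParticleSwarm` from Optim.jl, available through OptimizationOptimJL). Particle swarm also doesn’t need gradients, which suits our cost function: it involves a discrete grid search and random noise. Here’s a function that does the optimization: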
```julia
using Optimization, OptimizationOptimJL
using Random

function scan_optimization(TR_init, flip_init)
    Random.seed!(0)  # both the particle swarm and the simulated noise are random
    N_scans = length(TR_init)
    length(flip_init) == N_scans || error("`TR_init` and `flip_init` have different lengths")
    # The optimization variable is x = [TR; flip]; the parameter p is unused
    cost_fun = (x, p) -> estimator_error(x[1:N_scans], x[N_scans+1:end])
    # Box constraints: TR in seconds, flip angle in degrees
    (min_TR, max_TR) = (0.001, 0.02)
    (min_flip, max_flip) = (1, 90)
    constraints = (;
        lb = [fill(min_TR, N_scans); fill(min_flip, N_scans)],
        ub = [fill(max_TR, N_scans); fill(max_flip, N_scans)],
    )
    f = OptimizationFunction(cost_fun)  # no AD backend: particle swarm needs no gradients
    prob = OptimizationProblem(f, [TR_init; flip_init]; constraints...)
    sol = solve(prob, ParticleSwarm(lower = prob.lb, upper = prob.ub, n_particles = 3))
    TR_opt = sol.u[1:N_scans]
    flip_opt = sol.u[N_scans+1:end]
    return (TR_opt, flip_opt)
end
```
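If you only want a starting point for optimizing your own cost function, the pattern in `scan_optimization` boils down to the sketch below. It uses the same Optimization.jl calls as above (`OptimizationFunction`, `OptimizationProblem` with `lb`/`ub`, and `ParticleSwarm`); `toy_cost`, `optimize_toy`, the target point, the bounds, and the `maxiters` value are hypothetical placeholders.

```julia
using Optimization, OptimizationOptimJL
using Random

# Hypothetical cost function: squared distance between x and a target point p.
# Replace this with your own function of (x, p).
toy_cost(x, p) = sum(abs2, x .- p)

function optimize_toy(target)
    Random.seed!(0)  # particle swarm is stochastic; fix the seed for repeatability
    n = length(target)
    f = OptimizationFunction(toy_cost)  # no AD backend: particle swarm needs no gradients
    prob = OptimizationProblem(f, zeros(n), target; lb = fill(-1.0, n), ub = fill(1.0, n))
    alg = ParticleSwarm(lower = prob.lb, upper = prob.ub, n_particles = 10)
    sol = solve(prob, alg; maxiters = 1000)
    return sol.u
end

x_best = optimize_toy([0.3, -0.2])
```

Since the toy cost is minimized exactly at `p`, the returned point should be close to `[0.3, -0.2]`. To adapt the sketch, change the cost function, the initial guess, and the bounds; the `p` argument is a convenient place for fixed data that the optimizer should not change.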
Results
Let’s see how scan parameter optimization can improve tissue parameter estimates. We will use the following function to take scan parameters, simulate bSSFP scans of a 128 × 128 object with two regions of different tissue parameters, and then estimate the tissue parameters from those scans. We will compare the results of this function using manually chosen scan parameters to those using optimized scan parameters.
```julia
using LinearAlgebra
using Plots

# Show the true (left) and estimated (right) maps side by side
function plot_true_vs_est(T_true, T_est, title, rmse, clim)
    return heatmap([T_true T_est];
        title,
        clim,
        xlabel = "RMSE = $(round(rmse; digits = 4)) s",
        xticks = [],
        yticks = [],
        showaxis = false,
        aspect_ratio = 1,
    )
end

function run(TR, flip)
    # Simulated object: an inner disk and an outer ring with different
    # (T1, T2) values (s), surrounded by background (0, 0)
    nx = ny = 128
    object = map(Iterators.product(range(-1, 1, nx), range(-1, 1, ny))) do (x, y)
        r = hypot(x, y)
        if r < 0.5
            return (0.8, 0.08)
        elseif r < 0.8
            return (1.3, 0.09)
        else
            return (0.0, 0.0)
        end
    end
    T1_true = first.(object)
    T2_true = last.(object)
    # Simulate noisy bSSFP signals (one value per scan) in each object pixel
    noise_level = 0.001
    signal = map(T1_true, T2_true) do T1, T2
        (T1, T2) == (0.0, 0.0) && return 0.0
        bssfp.(TR, flip, T1, T2) .+ noise_level .* randn.()
    end
    # Estimate T1 and T2 pixel by pixel, skipping the background
    T1_est = zeros(nx, ny)
    T2_est = zeros(nx, ny)
    for i in eachindex(signal, T1_est, T2_est)
        signal[i] == 0.0 && continue
        (T1_est[i], T2_est[i]) = gridsearch(signal[i], TR, flip)
    end
    # Root-mean-square error over the object pixels
    m = T1_true .!= 0.0
    T1_rmse = norm(T1_true[m] - T1_est[m]) / sqrt(count(m))
    T2_rmse = norm(T2_true[m] - T2_est[m]) / sqrt(count(m))
    p_T1 = plot_true_vs_est(T1_true, T1_est, "True vs Estimated T1", T1_rmse, (0, 2.5))
    p_T2 = plot_true_vs_est(T2_true, T2_est, "True vs Estimated T2", T2_rmse, (0, 0.25))
    return (p_T1, p_T2)
end
```
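Both `estimator_error` and `run` add noise with `noise_level .* randn.()`. This relies on Julia’s broadcast fusion: inside a fused dotted expression, the zero-argument call `randn.()` is evaluated once per element, so each scan gets an independent noise sample. The small standalone check below uses made-up signal values as a stand-in for `bssfp.(TR, flip, T1, T2)` and contrasts this with the undotted `randn()`, which draws a single value shared by every element.

```julia
using Random

Random.seed!(1)
clean = [0.1, 0.2, 0.3, 0.4]  # made-up stand-in for bssfp.(TR, flip, T1, T2)
noise_level = 0.01

# Fused broadcast: randn.() is called once per element, giving independent noise per entry
noisy_fused = clean .+ noise_level .* randn.()

# Not fused: randn() runs once, so every entry gets the same noise value
noisy_shared = clean .+ noise_level * randn()

offsets_fused = noisy_fused .- clean
offsets_shared = noisy_shared .- clean
println(offsets_fused)
println(offsets_shared)
```

Writing `noise_level * randn()` by mistake would add the same offset to every scan, which is a scaled copy of a single draw rather than independent per-scan noise.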
First, let’s see the results with no optimization:
```julia
TR_init = [0.005, 0.005, 0.005]
flip_init = [30, 60, 80]
(p_T1_init, p_T2_init) = run(TR_init, flip_init)
```
```julia
display(p_T1_init)
```
data:image/s3,"s3://crabby-images/02b6f/02b6f92c2f5d2b5547335b3c2c4729d5a7216f58" alt="T1 estimates using initial scan parameters"
```julia
display(p_T2_init)
```
data:image/s3,"s3://crabby-images/c9529/c9529cc248b5ab350fd868565c907298aff81854" alt="T2 estimates using initial scan parameters"
Now let’s optimize the scan parameters and then see how the tissue parameter estimates look:
```julia
(TR_opt, flip_opt) = scan_optimization(TR_init, flip_init)
(p_T1_opt, p_T2_opt) = run(TR_opt, flip_opt)
```
```julia
display(p_T1_opt)
```
data:image/s3,"s3://crabby-images/da44f/da44fb6fefbf1bf101a31ecc6d4983e9b61a9442" alt="T1 estimates using optimized scan parameters"
```julia
display(p_T2_opt)
```
data:image/s3,"s3://crabby-images/bf05b/bf05b0783f102730bef2a1d8d8d23ff782546be6" alt="T2 estimates using optimized scan parameters"
We can see that the optimized scans result in much better tissue parameter estimates!
Summary
In this post, we saw how Julia can be used to optimize MRI scan parameters to improve tissue parameter estimates. Even though we discussed MRI specifically, the concepts presented here easily extend to other applications where signal models are known and optimization is required.
Note that most of the code for this post was taken from a Dash app I helped create for JuliaHub. Feel free to check it out if you want to see this code in action!
Additional Links
- Optimizing MRI Scans Dash App
  - Web app showcasing the capabilities discussed in this post.
- Optimizing MRI Scans Pluto Notebook
  - Pluto notebook accompanying the Dash app mentioned above. (Unfortunately, the data used in the notebook didn’t get set up properly, so most of the code can’t run correctly.)
- Links to other blog posts discussing how to do MRI in Julia:
- Optimization.jl Docs
  - Official documentation for Optimization.jl.