How to improve Julia's performance using just-in-time compilation (JIT)

Problem Description:

I have been playing with JAX (an automatic differentiation library in Python) and Zygote (an automatic differentiation library in Julia) to implement the Gauss-Newton minimisation method.
I came upon the @jit macro in JAX that runs my Python code in around 0.6 seconds, compared to ~60 seconds for the version that does not use @jit.
Julia ran the code in around 40 seconds. Is there an equivalent of @jit in Julia or Zygote that results in better performance?

Here is the code I used:

Python

from jax import grad, jit, jacfwd
import jax.numpy as jnp
import numpy as np
import time

def gaussian(x, params):
    amp = params[0]
    mu  = params[1]
    sigma = params[2]
    amplitude = amp/(jnp.abs(sigma)*jnp.sqrt(2*np.pi))
    arg = ((x-mu)/sigma)
    return amplitude*jnp.exp(-0.5*(arg**2))

def myjacobian(x, params):
    return jacfwd(gaussian, argnums = 1)(x, params)

def op(jac):
    # pseudoinverse-style operator: (J^T J)^(-1) J^T
    return jnp.matmul(
        jnp.linalg.inv(jnp.matmul(jnp.transpose(jac), jac)),
        jnp.transpose(jac))
                         
def res(x, data, params):
    return data - gaussian(x, params)

@jit
def step(x, data, params):
    # one Gauss-Newton update: params += (J^T J)^(-1) J^T r
    residuals = res(x, data, params)
    jacobian_operation = op(myjacobian(x, params))
    temp = jnp.matmul(jacobian_operation, residuals)
    return params + temp

N = 2000
x = np.linspace(start = -100, stop = 100, num= N)
data = gaussian(x, [5.65, 25.5, 37.23])

ini = jnp.array([0.9, 5., 5.0])
t1 = time.time()
for i in range(5000):
    ini = step(x, data, ini)
t2 = time.time()
print('t2-t1: ', t2-t1)
ini

Julia

using Zygote

function gaussian(x::Union{Vector{Float64}, Float64}, params::Vector{Float64})
    amp = params[1]
    mu  = params[2]
    sigma = params[3]
    
    amplitude = amp/(abs(sigma)*sqrt(2*pi))
    arg = ((x.-mu)./sigma)
    return amplitude.*exp.(-0.5.*(arg.^2))
    
end

function myjacobian(x::Vector{Float64}, params::Vector{Float64})
    # build the N×3 Jacobian row by row: one reverse-mode sweep per element of x
    output = zeros(length(x), length(params))
    for (index, ele) in enumerate(x)
        output[index, :] = gradient(params -> gaussian(ele, params), params)[1]
    end
    return output
end

function op(jac::Matrix{Float64})
    # (JᵀJ)⁻¹ Jᵀ, applied to the residuals in step
    return inv(jac'*jac)*jac'
end

function res(x::Vector{Float64}, data::Vector{Float64}, params::Vector{Float64})
    return data - gaussian(x, params)
end

function step(x::Vector{Float64}, data::Vector{Float64}, params::Vector{Float64})
    residuals = res(x, data, params)
    jacobian_operation = op(myjacobian(x, params))
    
    temp = jacobian_operation*residuals
    return params + temp
end

N = 2000
x = collect(range(start = -100, stop = 100, length= N))
params = [5.65, 25.5, 37.23]
data = gaussian(x, params)

ini = [0.9, 5.0, 5.0]
@time for i in 1:5000
    ini = step(x, data, ini)
end
ini

Solution – 1

Your Julia code is doing a number of things that aren’t idiomatic and are hurting your performance. This won’t be a full overview, but it should give you a good place to start.

The first problem is that passing params as a Vector is a bad idea. A Vector has to be heap allocated, and the compiler doesn’t know how long it is. Instead, use a Tuple, which allows for a lot more optimisation. Secondly, don’t make gaussian act on a Vector of xs; write the scalar version and broadcast it (see the call-site sketch after the code below). Specifically, with these changes, you will have

function gaussian(x::Number, params::NTuple{3, Float64})
    amp, mu, sigma = params
    
    # The next 2 lines should probably be done outside this function, but I'll leave them here for now.
    amplitude = amp/(abs(sigma)*sqrt(2*pi))
    arg = ((x-mu)/sigma)
    return amplitude*exp(-0.5*(arg^2))
end
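
The call sites then broadcast the scalar gaussian. A minimal sketch of how that might look, using the values from the question (Ref shields the tuple so broadcasting only runs over x):

x = collect(range(-100, 100, length = 2000))
params = (5.65, 25.5, 37.23)        # Tuple: fixed length, no heap allocation

data = gaussian.(x, Ref(params))    # broadcast the scalar gaussian over x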

Solution – 2

One straightforward way to speed this up is to use ForwardDiff instead of Zygote, since you are taking the gradient of a vector of length 3, many times. Here this gets me from 16 seconds down to 3.5 seconds, with the last factor of 2 coming from Chunk(3) to improve type-stability. Perhaps this can be improved further.

using ForwardDiff

function myjacobian(x::Vector, params)
    # return rand(eltype(x), length(x), length(params))  # with no gradient, takes 0.5s
    output = zeros(eltype(x), length(x), length(params))
    config = ForwardDiff.GradientConfig(nothing, params, ForwardDiff.Chunk(3))
    for (i, xi) in enumerate(x)
        # grad = gradient(p -> gaussian(xi, p), params)[1]                 # original Zygote, takes 16s
        # grad = ForwardDiff.gradient(p -> gaussian(xi, p), params)        # ForwardDiff, takes 7s
        grad = ForwardDiff.gradient(p -> gaussian(xi, p), params, config)  # with Chunk(3), takes 3.5s
        copyto!(view(output, i, :), grad)  # view + copyto! also allows params::Tuple; OK for Zygote, no speed help
    end
    return output
end
# This needs gaussian.(x, Ref(params)) elsewhere to use on many x, same params
function gaussian(x::Real, params)
    # amp, mu, sigma = params  # with params::Vector this is slower, 19 sec
    amp = params[1]
    mu  = params[2]
    sigma = params[3]  # like this, 16 sec
    T = typeof(x)  # avoids having (2*pi)::Float64 promote everything
    amplitude = amp/(abs(sigma)*sqrt(2*T(pi)))
    arg = (x-mu)/sigma
    return amplitude * exp(-(arg^2)/2)
end
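
As the comment above says, the vectorised call sites now need broadcasting. A minimal sketch of how res and step from the question might be adapted (this wiring is mine, not part of the original answer; Ref keeps broadcasting from iterating over params):

# res and step from the question, adapted to the scalar gaussian
res(x, data, params) = data .- gaussian.(x, Ref(params))

function step(x, data, params)
    J = myjacobian(x, params)                # ForwardDiff version above
    return params .+ op(J) * res(x, data, params)
end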

However, this is still computing many small gradient arrays in a loop. It could easily compute one big gradient array instead.

While in general Julia is happy to compile loops to something fast, loops that allocate many small arrays tend to be a bad idea, and this is especially true for Zygote, which is fastest on Matlab-ish whole-array code.

Here’s how this looks; it gets me under 1 s for the whole program:

function gaussian(x::Real, amp::Real, mu::Real, sigma::Real)
    T = typeof(x)
    amplitude = amp/(abs(sigma)*sqrt(2*T(pi)))
    arg = (x-mu)/sigma
    return amplitude * exp(-(arg^2)/2)
end
function myjacobian2(x::Vector, params)  # with this, 0.9s
    amp = fill(params[1], length(x))
    mu  = fill(params[2], length(x))
    sigma = fill(params[3], length(x))  # use same sigma & different x value at each row:
    grads = gradient((amp, mu, sigma) -> sum(gaussian.(x, amp, mu, sigma)), amp, mu, sigma)
    hcat(grads...)
end
# Check that it agrees:
myjacobian2(x, params) ≈ myjacobian(x, params)

While this has little effect on the speed, I think you probably also want op(jac::Matrix) = Hermitian(jac'*jac) \ jac' rather than inv.
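
Spelled out, that replaces the explicit inverse with a factorised solve (a sketch; Hermitian comes from the LinearAlgebra standard library and marks jac'*jac as symmetric):

using LinearAlgebra

# Factorise-and-solve instead of forming an explicit inverse:
# generally cheaper and better conditioned than inv(jac' * jac) * jac'.
op(jac::Matrix) = Hermitian(jac' * jac) \ jac'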
