The RBF interpolation from scipy is widely used in the Python community. However, it is limited to a set of pre-defined kernel functions and offers no parallelization.

Through JAX, I was able to write a program that is:

  • Flexible: you can choose both the kernel and how the norm is computed
  • Differentiable: gradients come essentially for free via automatic differentiation
  • Parallelizable: it takes advantage of JAX's vectorized, parallel evaluation
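To make these three points concrete, here is a minimal sketch of what such an interpolator can look like in JAX. This is an illustration, not the actual program: the Gaussian kernel, the shape parameter, and the Euclidean norm are assumptions, and both the kernel and the norm are plain callables you can swap out.

```python
import jax
import jax.numpy as jnp

# Use double precision: the RBF system can be ill-conditioned in float32.
jax.config.update("jax_enable_x64", True)

def make_rbf_interpolator(centers, values, kernel, norm=jnp.linalg.norm):
    """Build an RBF interpolant from (centers, values).

    `kernel` maps a distance to a scalar; `norm` maps a difference
    vector to a distance. Both are ordinary callables (flexibility).
    """
    # Pairwise distance matrix between centers, built with vmap.
    dists = jax.vmap(lambda a: jax.vmap(lambda b: norm(a - b))(centers))(centers)
    # Solve the linear system K w = y for the RBF weights.
    weights = jnp.linalg.solve(kernel(dists), values)

    def interpolate(x):
        # Distances from the query point to every center.
        d = jax.vmap(lambda c: norm(x - c))(centers)
        return kernel(d) @ weights

    return interpolate

# Example kernel (an assumption): Gaussian with shape parameter 0.5.
gaussian = lambda r: jnp.exp(-((r / 0.5) ** 2))

key = jax.random.PRNGKey(0)
centers = jax.random.uniform(key, (20, 2))
values = jnp.sin(centers[:, 0]) * jnp.cos(centers[:, 1])

interp = make_rbf_interpolator(centers, values, gaussian)
grad_interp = jax.grad(interp)    # differentiability: gradient of the interpolant
batch_interp = jax.vmap(interp)   # parallelism: evaluate many points at once
```

The whole pipeline stays inside JAX's tracing machinery, so `jax.grad` and `jax.vmap` compose directly with the returned closure, and `jax.jit` can be added on top for compiled evaluation.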

Feel free to leave a comment if you spot any mistakes.