Radial Basis Function (RBF) Interpolation with JAX
The RBF interpolation from SciPy is widely used in the Python community. However, the library is limited to a set of predefined kernel functions and lacks parallelization.
Using JAX, I built an implementation that is:
- Flexible: you can choose both the kernel and how the norm is computed
- Differentiable: gradients of the interpolant are easy to compute
- Parallelizable: it takes advantage of JAX's parallel evaluation
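The three properties above can be illustrated with a minimal sketch. It is not the implementation from this repository: the names `gaussian`, `euclidean`, `pairwise`, `fit` and `make_interpolant` are hypothetical, and the sketch assumes a Gaussian kernel and a plain dense solve of the interpolation system (no polynomial term, no smoothing). The kernel and the norm are ordinary Python callables (flexible), `jax.grad` differentiates the interpolant (differentiable), and `jax.vmap` evaluates many query points in one call (parallelizable).

```python
import jax
import jax.numpy as jnp

# Double precision keeps the dense kernel solve well behaved.
jax.config.update("jax_enable_x64", True)

def gaussian(r, epsilon=1.0):
    """Hypothetical kernel choice: phi(r) = exp(-(epsilon * r)^2)."""
    return jnp.exp(-(epsilon * r) ** 2)

def euclidean(x, y):
    """Hypothetical norm choice. Its gradient is NaN when x == y exactly."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))

def pairwise(norm, a, b):
    """Matrix D[i, j] = norm(a[i], b[j]), built with nested vmap."""
    return jax.vmap(lambda p: jax.vmap(lambda q: norm(p, q))(b))(a)

def fit(centers, values, kernel, norm):
    """Solve Phi w = f, where Phi[i, j] = kernel(norm(x_i, x_j))."""
    phi = kernel(pairwise(norm, centers, centers))
    return jnp.linalg.solve(phi, values)

def make_interpolant(centers, values, kernel=gaussian, norm=euclidean):
    """Return s(x) = sum_j w_j * kernel(norm(x, x_j)) for one query point x."""
    weights = fit(centers, values, kernel, norm)

    def s(x):
        r = jax.vmap(lambda c: norm(x, c))(centers)  # distances to every center
        return jnp.dot(kernel(r), weights)

    return s

# Usage: interpolate sin on 9 points in [0, 2] with a custom shape parameter.
centers = jnp.linspace(0.0, 2.0, 9).reshape(-1, 1)
values = jnp.sin(centers[:, 0])
s = make_interpolant(centers, values, kernel=lambda r: gaussian(r, epsilon=3.0))

query = jnp.array([0.3])   # not a center, so the Euclidean gradient is finite
ds = jax.grad(s)(query)    # derivative of the interpolant at x = 0.3
batch = jax.vmap(s)(jnp.linspace(0.0, 2.0, 50).reshape(-1, 1))  # 50 points at once
```

Because `norm` receives two points rather than a precomputed distance, a weighted or anisotropic distance can be swapped in without touching the solver. The dense solve costs O(n^3), so this sketch is only meant for small point sets, and gradients taken exactly at a center need a norm that is smooth at zero distance.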
Feel free to leave a comment if you spot any mistakes.