I've found functions in scipy, e.g. scipy.cluster.vq.kmeans, that crash within their first few lines when they are called with JAX arrays inside @jax.jit:
>>> import jax
>>> import jax.numpy as xp
>>> from scipy.cluster.vq import kmeans
>>> a = xp.asarray([[1.,2.],[3.,4.]])
>>> jax.jit(kmeans)(a, 2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function kmeans at /home/crusaderky/github/scipy/build-install/lib/python3.12/site-packages/scipy/cluster/vq.py:332 for jit. This concrete value was not available in Python because it depends on the value of the argument obs.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
The issue is that there is a health check enabled by default, check_finite=True, which triggers code that converts a value computed from the array to a Python bool. A JAX jitted array crashes on bool(); a dask array is quietly computed when you do so, which is possibly even worse.
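A minimal sketch of the failure mode, independent of scipy. The function `require_finite` is a hypothetical stand-in for a check_finite-style validation, not scipy's actual code; it only assumes that the check ends up calling bool() on an array:

```python
import jax
import jax.numpy as jnp


def require_finite(x):
    # Hypothetical validation helper (not scipy's implementation).
    # `not <array>` implicitly calls bool() on a 0-d boolean array.
    if not jnp.all(jnp.isfinite(x)):
        raise ValueError("array must not contain infs or NaNs")
    return x


a = jnp.asarray([[1.0, 2.0], [3.0, 4.0]])

# Eager mode: the array has concrete contents, so bool() works.
eager_ok = require_finite(a) is a

# Under jit: x is a tracer without concrete contents, so bool() fails.
try:
    jax.jit(require_finite)(a)
    jit_raised = False
except jax.errors.TracerBoolConversionError:
    jit_raised = True
```

The same function therefore works eagerly and fails under jit, which is why the default value of the flag matters here.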
There are two issues here:
- the default behaviour of the function is to inspect the contents of the array, and
- the error message is incomprehensible to an end user, as it is triggered by code deep inside the scipy implementation.
My proposal:
- in array-api-compat, add two functions:

    def is_jax_jitted_array(x):
        return isinstance(x, DynamicJaxprTracer)

    def is_material_array(x):
        """Return True if x has contents in memory at the moment of calling this function,
        which are cheap to retrieve as long as they're small in size.
        Return False if x is a future or it would be otherwise impossible or expensive to
        read its contents, regardless of their size.
        """
        return not is_dask_array(x) and not is_jax_jitted_array(x)

- in scipy, change
    kmeans(..., check_finite=True)
  to
    kmeans(..., check_finite=None)
  which will mean "check if possible", and replace

    if check_finite:
        _check_finite(x, xp)

  with

    def _check_material_array(x: Array, check: bool | None, check_name: str) -> bool:
        if check is None:
            return is_material_array(x)
        if check and not is_material_array(x):
            raise TypeError(
                f"Can't check non-material array {type(x)}. "
                f"Please set {check_name} to None or False."
            )
        return check

    ...
    if _check_material_array(x, check_finite, "check_finite"):
        _check_finite(x, xp)

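To make the proposed three-state semantics concrete, here is a small pure-Python sketch. `LazyArray` and the stubbed `is_material_array` are hypothetical stand-ins (the real helper does not exist in array-api-compat yet), and `describe` is only an illustration helper:

```python
class LazyArray:
    """Hypothetical stand-in for a dask array or a traced JAX array."""


def is_material_array(x):
    # Stub for the proposed helper (an assumption, not an existing API).
    return not isinstance(x, LazyArray)


def _check_material_array(x, check, check_name):
    if check is None:
        # "Check if possible": only check arrays whose contents are available.
        return is_material_array(x)
    if check and not is_material_array(x):
        raise TypeError(
            f"Can't check non-material array {type(x)}. "
            f"Please set {check_name} to None or False."
        )
    return check


def describe(x, check):
    """Report what a caller would do for a given array and check_finite value."""
    try:
        return "check" if _check_material_array(x, check, "check_finite") else "skip"
    except TypeError:
        return "error"


eager, lazy = [1.0, 2.0], LazyArray()
outcomes = {
    (kind, check): describe(x, check)
    for kind, x in (("eager", eager), ("lazy", lazy))
    for check in (None, True, False)
}
```

With these stand-ins, eager input is checked for None and True and skipped for False, while lazy input is skipped for None and False and raises a TypeError that names the parameter for True.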
However, @jakevdp mentioned elsewhere that DynamicJaxprTracer is not part of the public API of JAX and there is no public method to test for jitting. Not sure I can see a way forward without this information.
Also CC @rgommers @lucascolley