multidimensional jax.isin()

Problem Description:

I am trying to filter an array of triples.
The criterion is whether another array of triples contains at least one
triple with the same first and third elements as the one being tested. For example:

import jax.numpy as jnp
array1 = jnp.array(
  [
    [0,1,2],
    [1,0,2],
    [0,3,3],
    [3,0,1],
    [0,1,1],
    [1,0,3],
  ]
)
array2 = jnp.array([[0,1,3],[0,3,2]])
# the mask used to filter array1 should look like this:
jnp.array([True,False,True,False,False,False])

What would be a computationally efficient way to compute this mask using JAX?
I am looking forward to your input.

Solution – 1

You can do this by reducing over a broadcasted equality check (this version compares all three entries of each triple):

import jax.numpy as jnp
array1 = jnp.array(
  [
    [0,1,2],
    [1,0,2],
    [0,3,3],
    [3,0,1],
    [0,1,1],
    [1,0,3],
  ]
)
array2 = jnp.array([[0,1,2],[0,3,2]])  # first row changed to [0,1,2] so it matches the first row of array1 exactly

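# (6, 1, 3) == (1, 2, 3) broadcasts to (6, 2, 3); all(-1) requires every entry to match,
# any(-1) asks whether at least one row of array2 matched
mask = (array1[:, None] == array2[None, :]).all(-1).any(-1)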
print(mask)
# [ True False False False False False]
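The code above compares complete triples. To apply the question's actual criterion (first and third entries only) with the original array2, one option is to select columns 0 and 2 before doing the same broadcasted comparison. This is a minimal sketch; the names `cols`, `keys1`, `keys2` and `filtered` are illustrative only.

```python
import jax.numpy as jnp

array1 = jnp.array([
    [0, 1, 2],
    [1, 0, 2],
    [0, 3, 3],
    [3, 0, 1],
    [0, 1, 1],
    [1, 0, 3],
])
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])  # the question's original array2

# Keep only the first and third entries of each triple.
cols = jnp.array([0, 2])
keys1 = array1[:, cols]  # shape (6, 2)
keys2 = array2[:, cols]  # shape (2, 2)

# (6, 1, 2) == (1, 2, 2) broadcasts to (6, 2, 2):
# all(-1) -> (6, 2): does row i of array1 match row j of array2 on both entries?
# any(-1) -> (6,):   does row i match at least one row of array2?
mask = (keys1[:, None] == keys2[None, :]).all(-1).any(-1)
print(mask)
# [ True False  True False False False]

# Apply the mask. Boolean indexing is fine outside jax.jit.
filtered = array1[mask]
print(filtered)
# [[0 1 2]
#  [0 3 3]]
```

Note that `array1[mask]` produces an array whose shape depends on the data, so it works eagerly but not inside `jax.jit`; computing `mask` itself is jit-compatible.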

XLA doesn't have a binary-search-like primitive, so in general the best approach is to build the full equality array and reduce it. That intermediate boolean array has shape (len(array1), len(array2), 3), so its size grows with the product of the two array lengths. On an accelerator such as a GPU or TPU, this kind of vectorized operation parallelizes well, so in practice it is computed quite efficiently.
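Since the title asks for a "multidimensional" `jnp.isin`, another formulation is to pack each (first, third) pair into a single integer key and then call the ordinary 1-D `jnp.isin`. This is a sketch under stated assumptions: the entries are non-negative integers small enough that `first * base + third` does not overflow the integer dtype, and `first_third_mask` is a hypothetical helper name. Whether it is any faster than the broadcasted comparison depends on how `jnp.isin` is implemented in your JAX version, so treat it as an alternative formulation rather than an optimization.

```python
import jax
import jax.numpy as jnp


@jax.jit
def first_third_mask(a, b):
    """Hypothetical helper: True where a row of `a` shares (first, third) with some row of `b`.

    Assumes non-negative integer entries and no overflow when packing keys.
    """
    # Base strictly larger than any third entry, so distinct pairs give distinct keys.
    base = jnp.maximum(a[:, 2].max(), b[:, 2].max()) + 1
    keys_a = a[:, 0] * base + a[:, 2]
    keys_b = b[:, 0] * base + b[:, 2]
    # 1-D membership test on the packed keys.
    return jnp.isin(keys_a, keys_b)


array1 = jnp.array([[0, 1, 2], [1, 0, 2], [0, 3, 3], [3, 0, 1], [0, 1, 1], [1, 0, 3]])
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])
print(first_third_mask(array1, array2))
# [ True False  True False False False]
```

The packing is only valid because `0 <= third < base`, which makes `first * base + third` unique for each pair; negative entries would break that guarantee.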
