I am trying to filter an array of triples.
The criterion by which I want to filter is whether another array of triples contains at least one element with the same first and third elements. E.g.:
```python
import jax.numpy as jnp

array1 = jnp.array(
    [
        [0, 1, 2],
        [1, 0, 2],
        [0, 3, 3],
        [3, 0, 1],
        [0, 1, 1],
        [1, 0, 3],
    ]
)
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])

# the mask to filter array1 should look like this:
jnp.array([True, False, True, False, False, False])
```
What would be a computationally efficient way to achieve this mask using jax?
I am looking forward to your input.
Solution – 1
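You can do this by reducing over a broadcasted equality check. The version below compares whole rows; to match only on the first and third elements, apply the same pattern to just those columns: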
```python
import jax.numpy as jnp

array1 = jnp.array(
    [
        [0, 1, 2],
        [1, 0, 2],
        [0, 3, 3],
        [3, 0, 1],
        [0, 1, 1],
        [1, 0, 3],
    ]
)
array2 = jnp.array([[0, 1, 2], [0, 3, 2]])  # note adjustment to match first entry of array1

# (6, 1, 3) == (1, 2, 3) broadcasts to (6, 2, 3);
# all() over each triple, then any() over the rows of array2.
mask = (array1[:, None] == array2[None, :]).all(-1).any(-1)
print(mask)
# [ True False False False False False]
```
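For the criterion exactly as asked (same first and third element) and the question's original array2, a minimal sketch is to select columns 0 and 2 before broadcasting. The names keys1, keys2 and cols are just illustrative; the reduction is the same all/any pattern as above, applied to pairs instead of full triples:

```python
import jax.numpy as jnp

array1 = jnp.array(
    [
        [0, 1, 2],
        [1, 0, 2],
        [0, 3, 3],
        [3, 0, 1],
        [0, 1, 1],
        [1, 0, 3],
    ]
)
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])  # unmodified array2 from the question

# Keep only the first and third element of every triple.
cols = jnp.array([0, 2])
keys1 = array1[:, cols]  # shape (6, 2)
keys2 = array2[:, cols]  # shape (2, 2)

# (6, 1, 2) == (1, 2, 2) broadcasts to (6, 2, 2);
# all() over each (first, third) pair, then any() over the rows of array2.
mask = (keys1[:, None, :] == keys2[None, :, :]).all(-1).any(-1)
print(mask.tolist())  # [True, False, True, False, False, False]

filtered = array1[mask]  # boolean-mask indexing works eagerly, but not inside jax.jit
```

The last line is outside jit on purpose: boolean-mask indexing produces a data-dependent shape, so under jax.jit you would keep the mask and use something like jnp.where instead.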
XLA doesn't have any binary-search-like primitive, so the best approach in general is to generate the full equality matrix (shape (len(array1), len(array2), 3) before the reductions) and reduce over it. If you're running the code on an accelerator like a GPU or TPU, this sort of vectorized operation parallelizes well, so it is fast in practice.
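If the len(array1) × len(array2) intermediate ever gets too large to materialize, one hedged alternative is to pack each (first, third) pair into a single integer key, sort the keys of array2, and look up array1's keys with jnp.searchsorted, which JAX builds out of ordinary lax operations rather than a native XLA primitive. This is a sketch under stated assumptions: first_third_match_mask is a hypothetical helper name, entries are assumed to be non-negative integers, and the packed key (first * base + third) must fit in the integer dtype (int32 by default). Whether it beats the broadcast on a GPU/TPU depends on the sizes involved, so benchmark before switching.

```python
import jax
import jax.numpy as jnp


@jax.jit
def first_third_match_mask(array1, array2):
    """Hypothetical helper: True where array1's (first, third) pair occurs in array2.

    Assumes non-negative integer entries small enough that the packed keys
    do not overflow the integer dtype.
    """
    # Packing base: strictly larger than any third-column value in either array,
    # so first * base + third is unique for every (first, third) pair.
    base = jnp.maximum(array1[:, 2].max(), array2[:, 2].max()) + 1
    keys1 = array1[:, 0] * base + array1[:, 2]
    keys2 = jnp.sort(array2[:, 0] * base + array2[:, 2])

    # Binary-search each key of array1 into the sorted keys of array2.
    idx = jnp.searchsorted(keys2, keys1)
    # searchsorted returns len(keys2) for keys past the end; clamp before gathering.
    idx = jnp.clip(idx, 0, keys2.shape[0] - 1)
    return keys2[idx] == keys1


array1 = jnp.array([[0, 1, 2], [1, 0, 2], [0, 3, 3], [3, 0, 1], [0, 1, 1], [1, 0, 3]])
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])
print(first_third_match_mask(array1, array2).tolist())
# [True, False, True, False, False, False]
```

This costs roughly O((N + M) log M) work and O(N + M) memory instead of the O(N·M) equality tensor, at the price of the packing assumptions above.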