## multidimensional jax.isin()

Problem Description:

I am trying to filter an array of triples.
The criterion is whether another array of triples contains at least one
element with the same first and third entries as the row being tested. E.g.:

``````python
import jax.numpy as jnp

array1 = jnp.array(
    [
        [0, 1, 2],
        [1, 0, 2],
        [0, 3, 3],
        [3, 0, 1],
        [0, 1, 1],
        [1, 0, 3],
    ]
)
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])
# the mask used to filter array1 should look like this:
jnp.array([True, False, True, False, False, False])
``````
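For clarity, the criterion can be written as a slow but unambiguous plain-Python reference, useful only for checking faster jax versions against. This is a minimal sketch; the function name `reference_mask` is hypothetical, and it expects plain nested lists rather than jax arrays.

```python
def reference_mask(array1, array2):
    """Slow reference: True where some row of array2 shares the row's first and third entries."""
    # collect the (first, third) pairs that occur in array2
    keys = {(row[0], row[2]) for row in array2}
    # a row of array1 passes if its own (first, third) pair is among them
    return [(row[0], row[2]) in keys for row in array1]


array1 = [
    [0, 1, 2],
    [1, 0, 2],
    [0, 3, 3],
    [3, 0, 1],
    [0, 1, 1],
    [1, 0, 3],
]
array2 = [[0, 1, 3], [0, 3, 2]]

print(reference_mask(array1, array2))
# [True, False, True, False, False, False]
```

To use it with the jax arrays from the question, convert them with `.tolist()` first so the set lookup works on plain Python ints.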

What would be a computationally efficient way to compute this mask using JAX?
I am looking forward to your input.

## Solution – 1

You can do this by reducing over a broadcasted equality check:

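``````python
import jax.numpy as jnp

# array1 as defined in the question
array2 = jnp.array([[0,1,2],[0,3,2]])  # note: changed from the question so its first row equals array1[0]

# compare every row pair elementwise, require all three entries to match (all),
# then keep the rows of array1 that match at least one row of array2 (any)
mask = (array1[:, None] == array2[None, :]).all(-1).any(-1)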
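
# Hedged variant, assuming the question's stated criterion is the intent:
# only the first and third entries (columns 0 and 2) must match, using the
# question's original array2. `cols`, `array2_orig` and `mask_13` are
# illustrative names, not part of the answer above.
cols = jnp.array([0, 2])
array2_orig = jnp.array([[0,1,3],[0,3,2]])
# array1[:, cols] has shape (6, 2), array2_orig[:, cols] has shape (2, 2);
# broadcasting gives (6, 2, 2), reduced to (6,) exactly as above
mask_13 = (array1[:, cols][:, None] == array2_orig[:, cols][None, :]).all(-1).any(-1)
# expected values: True, False, True, False, False, False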