JAX offers a range of practical utility functions, including some for data preparation. One of these is jax.nn.one_hot, which performs classic one-hot encoding. Unfortunately, I was not able to find a suitable multi-hot equivalent for multi-label applications. However, the functionality is easy to implement directly in JAX:

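import jax
import jax.numpy as jnp
from functools import partial

# num_classes determines the output shape, so it must be static under jit
@partial(jax.jit, static_argnames=("num_classes",))
def multi_hot(labels, num_classes: int):
    # Pick one identity row per label, shape (len(labels), num_classes), then
    # sum the rows into one multi-hot vector (a repeated label is counted twice)
    return jnp.take(jnp.eye(num_classes), jnp.asarray(labels), axis=0).sum(axis=0)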
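The function above encodes a single sample. In training you usually need a whole batch, and under jit every distinct input shape triggers a new trace, so ragged label lists are typically padded to a fixed length first. The following is a minimal sketch of a batched variant, assuming a padding value of -1 (a convention chosen here for illustration; the name multi_hot_batched is likewise hypothetical). It builds on jax.nn.one_hot, which encodes indices outside [0, num_classes) as all-zero rows, so the padding slots contribute nothing. It also takes the maximum over the label axis instead of the sum, so a repeated label still yields 1.0.

```python
import jax
import jax.numpy as jnp
from functools import partial


@partial(jax.jit, static_argnames=("num_classes",))
def multi_hot_batched(labels, num_classes: int):
    """Multi-hot encode a padded batch of label lists (illustrative sketch).

    labels: int array of shape (batch, max_labels). Unused slots are padded
    with -1 (an assumed convention); jax.nn.one_hot maps out-of-range
    indices such as -1 to all-zero rows.
    Returns a float array of shape (batch, num_classes) containing 0.0 or 1.0.
    """
    # One-hot encode every label: shape (batch, max_labels, num_classes)
    one_hot = jax.nn.one_hot(labels, num_classes)
    # Reduce over the label axis with max, so duplicate labels stay at 1.0
    return one_hot.max(axis=-2)


# Sample 0 has labels {0, 2} plus one padding slot; sample 1 repeats label 1
labels = jnp.array([[0, 2, -1],
                    [1, 1, 3]])
# Expected rows: [1, 0, 1, 0] and [0, 1, 0, 1]
print(multi_hot_batched(labels, num_classes=4))
```

A side benefit of this design is that it avoids materializing the num_classes by num_classes identity matrix that jnp.eye creates, which can become sizable when there are many classes.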