Classification is a classic machine learning problem that we can tackle using logistic regression. If we distinguish between more than two classes, we call it a multinomial logistic regression. In this post, I will show how this can be done using JAX based on the well-known Fisher’s Iris dataset (every R user should be familiar with this one).

First, we load the required libraries and the data. Since this is a classification task, we have a set of predictors (a.k.a. features) and a label for each sample.

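import jax
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
from sklearn import datasets

iris = datasets.load_iris()

xs = jnp.asarray(iris.data)                  # 150 samples x 4 features
n_classes = int(iris.target.max()) + 1       # 3 species
ys = jax.nn.one_hot(iris.target, n_classes)  # one-hot labels, 150 x 3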

Standardised features often facilitate the learning process. In this case, however, I did not notice any major differences. Nevertheless, we can standardise the features as follows.

xs = (xs - xs.mean(axis=0))/xs.std(axis=0)
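Standardising all 150 samples at once, as above, lets the rows that later end up in the test set influence the mean and standard deviation. A common alternative is to split first, compute the statistics on the training rows only, and reuse them for the test rows and any new sample. The following is a minimal sketch of that idea; the helper names `fit_standardizer` and `apply_standardizer` are hypothetical, and the random matrices merely stand in for the Iris features.

```python
import jax
import jax.numpy as jnp

def fit_standardizer(x_train):
    # Hypothetical helper: per-feature statistics from the training rows only.
    return x_train.mean(axis=0), x_train.std(axis=0)

def apply_standardizer(x, mean, std):
    # Reuse the training statistics for any input (test rows, new samples).
    return (x - mean) / std

# Stand-in data (assumption): 135 training rows and 15 test rows, 4 features each.
key_train, key_test = jax.random.split(jax.random.PRNGKey(0))
x_train_raw = 5.0 + 2.0 * jax.random.normal(key_train, (135, 4))
x_test_raw = 5.0 + 2.0 * jax.random.normal(key_test, (15, 4))

mean, std = fit_standardizer(x_train_raw)
x_train_std = apply_standardizer(x_train_raw, mean, std)
x_test_std = apply_standardizer(x_test_raw, mean, std)
```

Only the training features end up with exactly zero mean and unit standard deviation; the test features are close to, but not exactly, standardised, which mirrors how the model will see genuinely new data.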

In any case, we should divide the data into a training set and a test set. This split allows us to assess the model’s performance on data that the model has not seen before.

X_train, X_test, y_train, y_test = train_test_split(
    xs, ys, test_size=0.1, random_state=42
)


We model our multinomial logistic regression as a linear transformation whose output we then pass through a (non-linear) activation function. The matrix multiplication maps the four predictors to a vector with one score (logit) per class, i.e., to a three-dimensional space here. The bias plays the role of the intercept in a linear model and provides a baseline score for each class. The activation function used here is the normalized exponential function (softmax). It turns the scores into values between 0 and 1 that add up to 1 and can therefore be read as probabilities.
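For a single sample, with x the row vector of its four features, W the 4 × K weight matrix, b the bias vector of length K, and K = 3 classes, the model described above can be summarised as:

```latex
\hat{\mathbf{y}} = \operatorname{softmax}(\mathbf{x} W + \mathbf{b}),
\qquad
\operatorname{softmax}(\mathbf{z})_k = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}},
\quad k = 1, \dots, K
```

Every entry lies between 0 and 1 because exponentials are positive and each numerator also appears in the shared denominator, and the entries sum to 1 because they all share that denominator.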

Consequently, a minimalist implementation of multinomial logistic regression looks like this:

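import jax
import jax.numpy as jnp

class MultinomialLogisticRegressor():
    def __init__(self, w, b):
        self.w = w  # weights, shape (n_features, n_classes)
        self.b = b  # biases, shape (n_classes,)

    def predict(self, x):
        # Class scores (logits), then softmax -> one probability per class
        return jax.nn.softmax(jnp.dot(x, self.w) + self.b, axis=-1)

    @staticmethod
    def cross_entropy(logprobs, targets):
        # Mean negative log-probability of the true class;
        # expects log-probabilities and one-hot targets.
        target_class = jnp.argmax(targets, axis=1)
        nll = jnp.take_along_axis(
            logprobs, jnp.expand_dims(target_class, axis=1), axis=1
        )
        ce = -jnp.mean(nll)
        return ce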
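To see what `cross_entropy` computes, the sketch below evaluates the same `take_along_axis` lookup on two hand-made samples and compares it with the equivalent one-hot formulation (multiply by the targets and sum over classes). The probabilities are made up for illustration. Note that the lookup expects log-probabilities, so softmax outputs have to go through `jnp.log` first.

```python
import jax
import jax.numpy as jnp

# Two samples, three classes: made-up predicted probabilities and one-hot targets.
probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])
targets = jax.nn.one_hot(jnp.array([0, 2]), 3)
logprobs = jnp.log(probs)

# Variant 1: pick the log-probability of the true class per row.
target_class = jnp.argmax(targets, axis=1)
picked = jnp.take_along_axis(logprobs, jnp.expand_dims(target_class, axis=1), axis=1)
ce_take = -jnp.mean(picked)

# Variant 2: one-hot targets zero out every class except the true one.
ce_onehot = -jnp.mean(jnp.sum(targets * logprobs, axis=1))

# Both equal -(log 0.7 + log 0.6) / 2, roughly 0.434.
print(ce_take, ce_onehot)
```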

The parameters are stored in the class and set via the constructor. This way, we can set the weights manually and try out the inference. Of course, these weights are arbitrary (all ones), but the output is already a valid probability distribution; since every class receives the same score, it is uniform.

MultinomialLogisticRegressor(
    jnp.ones([4, n_classes]),
    jnp.zeros([n_classes]),
).predict(jnp.array([[42., 0., 42., 42.]]))
DeviceArray([[0.33333334, 0.33333334, 0.33333334]], dtype=float32)

To get more meaningful weights, we now train the network.

key = jax.random.PRNGKey(123)

params = {
    'w': jax.random.normal(key, shape=(4, n_classes)),  # random starting point
    'b': jnp.zeros([n_classes])
}

I choose a random starting point and, using the gradient of the loss function, update the weights (i.e., parameters) step by step.

def loss_fn(params, xs, ys):
    my_regressor = MultinomialLogisticRegressor(params['w'], params['b'])
    # cross_entropy expects log-probabilities, hence jnp.log
    return MultinomialLogisticRegressor.cross_entropy(
        jnp.log(my_regressor.predict(xs)), ys
    )

grad_fn = jax.grad(loss_fn)

for i in range(1000):
    if i % 100 == 0:
        print(loss_fn(params, X_train, y_train))
    grads = grad_fn(params, X_train, y_train)
    for name in params.keys():
        params[name] -= 0.1 * grads[name]
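The loop above takes `jnp.log` of the softmax output and updates each parameter by hand. A common, numerically more robust variant applies `jax.nn.log_softmax` directly to the logits, obtains loss and gradients in one call with `jax.value_and_grad`, updates all parameters with `jax.tree_util.tree_map`, and compiles the step with `jax.jit`. The sketch below shows this under stated assumptions: `logits_fn`, `stable_loss_fn` and `train_step` are hypothetical names, and three Gaussian blobs stand in for the Iris data.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # same step size as in the loop above

def logits_fn(params, x):
    # Linear part of the model: one score per class.
    return x @ params['w'] + params['b']

def stable_loss_fn(params, x, y):
    # log_softmax on the logits instead of log(softmax(...)) in two steps.
    logprobs = jax.nn.log_softmax(logits_fn(params, x), axis=-1)
    return -jnp.mean(jnp.sum(y * logprobs, axis=-1))

@jax.jit
def train_step(params, x, y):
    # One full-batch gradient-descent step; also returns the current loss.
    loss, grads = jax.value_and_grad(stable_loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

# Stand-in data (assumption): 3 Gaussian blobs in 4 dimensions, 30 samples each.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
centers = jnp.array([[-2.0, 0.0, 0.0, 0.0],
                     [0.0, 2.0, 0.0, 0.0],
                     [2.0, 0.0, 0.0, 0.0]])
labels = jnp.repeat(jnp.arange(3), 30)
x = centers[labels] + 0.5 * jax.random.normal(key_x, (90, 4))
y = jax.nn.one_hot(labels, 3)

params = {'w': 0.01 * jax.random.normal(key_w, (4, 3)), 'b': jnp.zeros(3)}
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

Because the parameters are a plain dictionary (a pytree), `tree_map` updates `w` and `b` in one expression, and adding further parameters would not require touching the update code.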

Since we only have a few data points here, I use all the training data at once in every step. Training data for neural networks are usually much larger, so only a subset (a mini-batch) is used per step. One complete pass over the training data is called an epoch; since every step here is such a full pass, we train for 1000 epochs. The step size (learning rate) of 0.1 is comparatively large and may need to be adjusted.
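For larger datasets, the usual pattern is to shuffle the training rows once per epoch and then walk through them in mini-batches, performing one gradient update per batch. The following is a minimal sketch of the batching part only; `minibatch_indices` is a hypothetical helper, and 135 rows with a batch size of 32 are example numbers.

```python
import jax
import jax.numpy as jnp

def minibatch_indices(key, n_samples, batch_size):
    # Hypothetical helper: one shuffled pass (epoch) over the data,
    # yielding index arrays of at most batch_size rows each.
    perm = jax.random.permutation(key, n_samples)
    for start in range(0, n_samples, batch_size):
        yield perm[start:start + batch_size]

# 135 rows in batches of 32: 4 full batches plus one batch of 7 rows per epoch.
batches = list(minibatch_indices(jax.random.PRNGKey(0), 135, 32))

key = jax.random.PRNGKey(1)
n_updates = 0
for epoch in range(3):
    key, epoch_key = jax.random.split(key)  # fresh order every epoch
    for idx in minibatch_indices(epoch_key, 135, 32):
        # Here one would compute grads on X_train[idx], y_train[idx] and update params.
        n_updates += 1
```

With this scheme, one epoch consists of five updates instead of one, so the number of epochs and the learning rate usually have to be tuned together.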

After training, we can test inference with the learned weights, for example on the sample at index 120 of our dataset.

y_hat = MultinomialLogisticRegressor(params['w'], params['b']).predict(
    xs[120].reshape(1, -1)
)
y_hat
DeviceArray([[6.8913272e-05, 7.8267194e-02, 9.2166394e-01]], dtype=float32)

As the output shows, the last class has by far the highest score, so the sample is clearly assigned to class 3 (index 2, Iris virginica), which is its true species.
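Turning scores into a class label is just an `argmax` per row. The sketch below maps that index to a species name; the name list is an assumption that follows the order of scikit-learn's `iris.target_names`, and `to_labels` is a hypothetical helper. The scores are the ones printed above for sample 120.

```python
import jax.numpy as jnp

# Assumption: same order as sklearn's iris.target_names.
CLASS_NAMES = ['setosa', 'versicolor', 'virginica']

def to_labels(probs):
    # Most probable class per row, translated to a species name.
    return [CLASS_NAMES[int(i)] for i in jnp.argmax(probs, axis=1)]

# Scores for sample 120 as printed above.
scores = jnp.array([[6.8913272e-05, 7.8267194e-02, 9.2166394e-01]])
print(to_labels(scores))
```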

y_hat = MultinomialLogisticRegressor(params['w'], params['b']).predict(X_test)

# accuracy
(y_test.argmax(axis=1) == y_hat.argmax(axis=1)).mean()
DeviceArray(1., dtype=float32)

We can also run inference on the whole holdout (test) set, which allows us to determine the accuracy. This simple network achieves a remarkable 100 % accuracy on the test data. However, the dataset is small (the test set holds only 15 of the 150 samples), and in classification tasks one typically has to pay attention to other aspects as well (e.g., are all classes represented equally often?).
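To check class balance and see which classes get confused, a confusion matrix is more informative than accuracy alone. The sketch below builds one with `jnp.bincount`; the function name and the labels are made up for illustration and are not the Iris results above.

```python
import jax.numpy as jnp

def confusion_matrix(y_true, y_pred, n_classes):
    # Rows: true class, columns: predicted class.
    idx = y_true * n_classes + y_pred
    counts = jnp.bincount(idx, length=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)

# Made-up labels for illustration.
y_true = jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = jnp.array([0, 0, 1, 2, 1, 2, 2, 2, 1])

cm = confusion_matrix(y_true, y_pred, 3)
class_counts = cm.sum(axis=1)                # samples per true class
per_class_recall = jnp.diag(cm) / class_counts
```

Applied to the post's variables, the call would be `confusion_matrix(y_test.argmax(axis=1), y_hat.argmax(axis=1), n_classes)`, and the row sums show directly whether each class is represented in the test set.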

In addition, we can take a closer look at the learned parameters: the weight matrix `w` has one row per feature and one column per class, and `b` contains one bias per class.

params

{'w': DeviceArray([[-0.8394912 ,  0.10752536,  0.73196554],
              [ 0.99018306, -1.0996652 ,  0.10948081],
              [-1.3265549 ,  0.13573228,  1.1908214 ],
              [-1.2383505 , -0.38045356,  1.6188046 ]], dtype=float32),
 'b': DeviceArray([-0.27729714,  0.7294254 , -0.45212856], dtype=float32)}
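When reading these weights, keep in mind that softmax only depends on the differences between the class scores. Adding the same constant to every class in a row of `w` (or to every entry of `b`) therefore leaves all predictions unchanged, so a feature's weight is best interpreted relative to the other classes rather than in absolute terms. The sketch below checks this numerically with random, hypothetical parameters.

```python
import jax
import jax.numpy as jnp

def predict(w, b, x):
    # Same model as above: softmax over the class scores.
    return jax.nn.softmax(x @ w + b, axis=-1)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (4, 3))
b = jnp.array([0.1, -0.2, 0.3])
x = jax.random.normal(key_x, (5, 4))

# One arbitrary constant per feature, broadcast over the 3 classes,
# plus one constant added to every bias.
shift = jnp.array([[1.0], [-2.0], [0.5], [3.0]])  # shape (4, 1)
w_shifted = w + shift
b_shifted = b + 10.0

p = predict(w, b, x)
p_shifted = predict(w_shifted, b_shifted, x)
```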

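Finally, we calculate the perplexity. It is the exponential of the cross-entropy, is often used for language models, and can be computed directly from the cross-entropy loss. Because our loss uses the natural logarithm, the matching base is e; base 2 is only appropriate when the cross-entropy is measured in bits. A perfect model has a perplexity of 1, and a model that always predicts the uniform distribution over our three classes has a perplexity of 3. The perplexity is usually calculated on the holdout set rather than on the training data.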
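The two reference points and the role of the logarithm base can be checked with a few lines. This is a small sketch; `perplexity` is a hypothetical helper that assumes a cross-entropy measured in nats, as produced by `jnp.log`.

```python
import math
import jax.numpy as jnp

def perplexity(ce_nats):
    # Cross-entropy in nats (natural log) -> exponentiate with base e.
    return jnp.exp(ce_nats)

n_classes = 3
uniform_ce = -math.log(1.0 / n_classes)  # loss of a model that always predicts 1/3
perfect_ce = 0.0                         # loss of a model that is always certain and right

# Base 2 gives the same value only if the loss is first converted to bits.
ce_bits = uniform_ce / math.log(2)
print(perplexity(uniform_ce), perplexity(perfect_ce), 2 ** ce_bits)
```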

jnp.exp(loss_fn(params, X_test, y_test))