## Multinomial Logistic Regression in JAX

Classification is a classic machine learning problem that we can tackle with logistic regression. When we distinguish between more than two classes, the model is called multinomial logistic regression. In this post, I will show how to fit one in JAX using the well-known Fisher’s Iris dataset (every R user should be familiar with this one). First, we load the required libraries and the data. Since this is a classification task, we have a set of predictors (aka....
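To make the idea concrete before the real data comes in, here is a minimal sketch of multinomial (softmax) logistic regression in JAX. It assumes an Iris-like layout of 4 predictors and 3 classes, but because the data-loading step is not shown here, it uses synthetic Gaussian clusters as a stand-in for Iris. The names `predict_logits`, `loss_fn`, and `update`, the learning rate, and the step count are hypothetical choices for illustration, not the method used later in this post. The model is linear (one weight column and bias per class), the loss is softmax cross-entropy, and `jax.grad` provides the gradients for plain gradient descent.

```python
import jax
import jax.numpy as jnp

# Assumed sizes matching the Iris layout: 4 predictors, 3 classes
NUM_FEATURES = 4
NUM_CLASSES = 3

def predict_logits(params, X):
    # Linear model: one weight column and one bias term per class
    W, b = params
    return X @ W + b

def loss_fn(params, X, y):
    # Softmax cross-entropy: mean negative log-likelihood of the true class
    log_probs = jax.nn.log_softmax(predict_logits(params, X), axis=-1)
    one_hot = jax.nn.one_hot(y, NUM_CLASSES)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

@jax.jit
def update(params, X, y, lr):
    # One step of plain gradient descent on all parameters
    grads = jax.grad(loss_fn)(params, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic stand-in for Iris (an assumption): three Gaussian clusters, 50 samples each
key = jax.random.PRNGKey(0)
k_data, k_init = jax.random.split(key)
n_per_class = 50
centers = jnp.array([[0.0, 0.0, 0.0, 0.0],
                     [3.0, 3.0, 0.0, 0.0],
                     [0.0, 3.0, 3.0, 3.0]])
noise = jax.random.normal(k_data, (NUM_CLASSES, n_per_class, NUM_FEATURES))
X = (centers[:, None, :] + 0.5 * noise).reshape(-1, NUM_FEATURES)
y = jnp.repeat(jnp.arange(NUM_CLASSES), n_per_class)

# Small random weights, zero biases
params = (0.01 * jax.random.normal(k_init, (NUM_FEATURES, NUM_CLASSES)),
          jnp.zeros(NUM_CLASSES))

initial_loss = float(loss_fn(params, X, y))
for _ in range(500):
    params = update(params, X, y, 0.1)
final_loss = float(loss_fn(params, X, y))

# Predicted class = index of the largest logit (equivalently, largest softmax probability)
predictions = jnp.argmax(predict_logits(params, X), axis=-1)
accuracy = float(jnp.mean(predictions == y))
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, accuracy {accuracy:.2f}")
```

Swapping in the real Iris predictors and integer-coded species labels for `X` and `y` should be all that changes. Note that the accuracy above is measured on the training data itself, so a held-out split would be needed for an honest estimate.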