# Example workflow using fit

## Introduction

In this notebook, we are going to use the `fit` function to train a UniRep model.

## Imports

Here are the imports that we are going to need for the notebook.

```
from jax.random import PRNGKey
# Newer JAX releases moved stax to jax.example_libraries.stax.
from jax.experimental.stax import Dense, Softmax, serial
from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.utils import load_params
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates
```

## Sequences

We'll prepare a set of dummy training sequences and a small holdout set.

In your *actual* use case, you'll probably need to find a way to load your sequences into memory as a **list of strings**. (We try our best to stick with Python idioms.)
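If your sequences live in a FASTA file, a minimal reader that produces the required list of strings might look like the sketch below. It is plain Python, not a `jax_unirep` utility; the `read_fasta` name and the file path in the usage comment are hypothetical. Biopython's `SeqIO` module handles the same job more robustly if you already depend on it.

```python
from typing import Iterable, List


def read_fasta(lines: Iterable[str]) -> List[str]:
    """Collect the sequences from FASTA-formatted lines into a list of strings.

    Header lines start with '>' and are skipped. A sequence may be wrapped
    over several lines, so consecutive non-header lines are concatenated.
    Residues are upper-cased.
    """
    sequences: List[str] = []
    current: List[str] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        if line.startswith(">"):
            # A new header closes the previous record, if any.
            if current:
                sequences.append("".join(current))
                current = []
        else:
            current.append(line.upper())
    if current:
        sequences.append("".join(current))
    return sequences


# Usage with a real file (hypothetical path):
# with open("my_sequences.fasta") as handle:
#     sequences = read_fasta(handle)
```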

```
sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
holdout_sequences = [
    "HASTA",
    "VISTA",
    "ALAVA",
    "LIMED",
    "HAST",
    "HASVALTA",
] * 5
```
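Here the holdout set is written out by hand. With real data you would more likely carve it out of your own sequence collection. A minimal, reproducible way to do that with only the standard library is sketched below; `train_holdout_split` is a hypothetical helper, not part of `jax_unirep`.

```python
import random
from typing import List, Sequence, Tuple


def train_holdout_split(
    seqs: Sequence[str], holdout_fraction: float = 0.2, seed: int = 42
) -> Tuple[List[str], List[str]]:
    """Shuffle a copy of `seqs` with a fixed seed and split it in two.

    Returns (train, holdout). The same seed always yields the same split.
    """
    if not 0.0 <= holdout_fraction <= 1.0:
        raise ValueError("holdout_fraction must be between 0 and 1")
    shuffled = list(seqs)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    n_holdout = round(len(shuffled) * holdout_fraction)
    return shuffled[n_holdout:], shuffled[:n_holdout]


all_seqs = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
train_seqs, holdout_seqs = train_holdout_split(all_seqs, holdout_fraction=0.2)
```

Note that the dummy list repeats each sequence five times, so identical strings can end up on both sides of the split. With real data, deduplicate first if you want the holdout set to contain only unseen sequences.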

## Example 1: Default mLSTM model

In this first example, we'll use the default mLSTM1900 model with the pre-trained weights that ship with the package.

Nothing needs to be passed in except for:

- The sequences to evotune against, and
- The number of epochs.

It's the easiest way to get up and running, since it requires the least setup.

```
# First way: Use the default mLSTM1900 weights with mLSTM1900 model.
tuned_params = fit(sequences, n_epochs=2)
```

## Example 2: Pre-built model architectures

The second way is to use one of the pre-built evotuning models. Pre-trained weights for the three model architectures from the paper (1900, 256 and 64 hidden units) are shipped with the repo. Alternatively, you can use JAX's PRNG keys to initialize random parameters reproducibly.

In this example, we'll use the `mlstm64` model. The `mlstm256` model is also available, and it might give you better performance, though at the price of longer training time.
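To get a feel for the size difference behind that trade-off, you can count the weights of a single mLSTM layer. The sketch below uses the standard mLSTM equations from Krause et al. (2016) and assumes the input is the 10-dimensional amino-acid embedding the paper models use. It is an estimate, not a readout of `jax_unirep`'s actual parameter tree, which may contain extra terms such as normalization gains.

```python
def mlstm_param_count(input_dim: int, hidden_dim: int) -> int:
    """Weights and biases of a textbook mLSTM layer (Krause et al., 2016).

    m_t   = (W_mx x_t) * (W_mh h_{t-1})           -> d*h + h*h
    gates = W_hx x_t + W_hm m_t + b, for the input, forget and output
            gates plus the candidate update         -> 4 * (d*h + h*h + h)
    Any extras a library adds (e.g. normalization gains) are not counted.
    """
    d, h = input_dim, hidden_dim
    multiplicative = d * h + h * h
    gates = 4 * (d * h + h * h + h)
    return multiplicative + gates


# Assumes a 10-dimensional embedding feeds the mLSTM, as in the paper models.
for hidden in (64, 256, 1900):
    print(f"mLSTM{hidden}: {mlstm_param_count(10, hidden):,} parameters")
# mLSTM64: 23,936 parameters
# mLSTM256: 341,504 parameters
# mLSTM1900: 18,152,600 parameters
```

Because the hidden-to-hidden terms grow with the square of the hidden size, going from 64 to 256 hidden units makes the layer roughly 14 times larger, which is consistent with the longer training time mentioned above.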

```
init_fun, apply_fun = mlstm64()
# init_fun always requires a PRNGKey,
# and input_shape should be set to (-1, 26).
# This creates randomly initialized parameters.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
# Alternatively, load the pre-trained paper weights
# (this overwrites the random params created above).
params = load_params(paper_weights=64)
# Now we tune the params.
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
```
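One plausible reading of `input_shape=(-1, 26)` is that each residue enters the model as a 26-wide one-hot vector, with the leading `-1` leaving the other axis unspecified. The sketch below illustrates that encoding in plain Python. The 26-symbol vocabulary is entirely hypothetical (25 letters plus a padding symbol); `jax_unirep` uses its own internal mapping, which is what you should rely on in practice.

```python
from typing import List

# HYPOTHETICAL 26-symbol vocabulary: a padding symbol plus the letters A-Z
# without J. This is NOT jax_unirep's mapping; it only illustrates the width.
VOCAB = ["-"] + [c for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ" if c != "J"]
INDEX = {symbol: i for i, symbol in enumerate(VOCAB)}


def one_hot(seq: str) -> List[List[int]]:
    """Encode `seq` as len(seq) rows, each a one-hot vector of width len(VOCAB)."""
    rows = []
    for residue in seq:
        row = [0] * len(VOCAB)
        row[INDEX[residue]] = 1  # raises KeyError for symbols outside VOCAB
        rows.append(row)
    return rows


encoded = one_hot("HASTA")
print(len(encoded), len(encoded[0]))  # 5 26
```

Stacking the rows gives a `(sequence_length, 26)` array, so under this reading sequences of any length match the `(-1, 26)` input shape.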

## Example 3: Build your own model

Finally, the modular style of `jax-unirep` allows you to easily try out your own model architectures. For example, you could change the number of initial embedding dimensions or the mLSTM architecture. Let's try a model with 20 initial embedding dimensions instead of 10, and two stacked mLSTMs with 512 hidden units each:

```
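model_layers = (
    AAEmbedding(20),  # 20 embedding dimensions instead of the default 10
    mLSTM(512),  # first mLSTM, 512 hidden units
    mLSTMHiddenStates(),  # extract the hidden states from the mLSTM output
    mLSTM(512),  # second, stacked mLSTM
    mLSTMHiddenStates(),
    Dense(25),  # output layer
    Softmax,
)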
init_fun, apply_fun = serial(*model_layers)
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
```
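The `_` discarded from `init_fun(...)` is the model's output shape. In stax, every layer is an `(init_fun, apply_fun)` pair, where `init_fun(rng, input_shape)` returns `(output_shape, params)`, and `serial` threads each layer's output shape into the next layer's `init_fun` and chains the `apply_fun`s. The toy below mimics that contract with plain Python lists so the mechanics are visible without JAX. `toy_dense` and `toy_serial` are hypothetical names, and the simplification is deliberate: the real `stax.serial` splits the PRNGKey once per layer, whereas this sketch just offsets an integer seed.

```python
import random
from typing import Any, Callable, List, Tuple

Shape = Tuple[int, ...]
InitFun = Callable[[int, Shape], Tuple[Shape, Any]]
ApplyFun = Callable[..., Any]


def toy_dense(out_dim: int) -> Tuple[InitFun, ApplyFun]:
    """Hypothetical stand-in for a Dense layer, using plain Python lists."""

    def init_fun(seed: int, input_shape: Shape) -> Tuple[Shape, Any]:
        rng = random.Random(seed)
        in_dim = input_shape[-1]
        weights = [[rng.gauss(0.0, 0.1) for _ in range(out_dim)] for _ in range(in_dim)]
        biases = [0.0] * out_dim
        # Same leading axes, last axis replaced by out_dim.
        return input_shape[:-1] + (out_dim,), (weights, biases)

    def apply_fun(params: Any, inputs: List[List[float]], **kwargs) -> List[List[float]]:
        weights, biases = params
        # Row-wise matrix product: out[j] = sum_i x[i] * weights[i][j] + b[j]
        return [
            [sum(x * w[j] for x, w in zip(row, weights)) + biases[j] for j in range(len(biases))]
            for row in inputs
        ]

    return init_fun, apply_fun


def toy_serial(*layers: Tuple[InitFun, ApplyFun]) -> Tuple[InitFun, ApplyFun]:
    """Chain (init_fun, apply_fun) pairs in the spirit of stax.serial, simplified."""
    init_funs, apply_funs = zip(*layers)

    def init_fun(seed: int, input_shape: Shape) -> Tuple[Shape, List[Any]]:
        params = []
        for i, init in enumerate(init_funs):
            # Each layer is initialized from the previous layer's output shape.
            input_shape, layer_params = init(seed + i, input_shape)
            params.append(layer_params)
        return input_shape, params

    def apply_fun(params: List[Any], inputs: Any, **kwargs) -> Any:
        for apply, layer_params in zip(apply_funs, params):
            inputs = apply(layer_params, inputs, **kwargs)
        return inputs

    return init_fun, apply_fun


init_fun, apply_fun = toy_serial(toy_dense(8), toy_dense(3))
out_shape, params = init_fun(42, (-1, 26))
print(out_shape)  # (-1, 3)
outputs = apply_fun(params, [[1.0] * 26, [0.0] * 26])
```

This is why swapping layers in and out of `model_layers` is cheap: as long as each layer honors the `(init_fun, apply_fun)` contract, `serial` works out the shapes and parameter structure for you.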

## Obviously...

...you would probably swap in/out a different set of sequences and train for longer :).