import pickle
from dataclasses import dataclass
from typing import IO, Any, Dict, List, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from optax import OptState, adam
from scipy import sparse as sps
from ..definitions import DenseScoreArray, InteractionMatrix, UserIndexArray
from ..optimization.parameter_range import CategoricalRange
from .base_earlystop import (
BaseEarlyStoppingRecommenderConfig,
BaseRecommenderWithEarlyStopping,
TrainerBase,
)
class BaseMLP:
    """Callable that assembles a tanh multi-layer perceptron from Haiku layers.

    The Haiku modules are created inside ``__call__`` (not in ``__init__``),
    so an instance only stores the layer widths.

    Args:
        output_dim: Width of the final linear layer.
        hidden_dims: Widths of the hidden linear layers, in order. Each hidden
            layer is followed by ``tanh``; the final layer has no activation.
    """

    def __init__(
        self,
        output_dim: int,
        hidden_dims: List[int],
    ):
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim

    def __call__(
        self,
        X: jnp.ndarray,
    ) -> jnp.ndarray:
        """Build the layer stack and apply it to ``X``.

        Args:
            X: Input array fed to the first linear layer.

        Returns:
            Output of the final ``hk.Linear(self.output_dim)`` layer.
        """
        stack = []
        for width in self.hidden_dims:
            # Linear layer immediately followed by its tanh activation.
            stack.extend((hk.Linear(width), jnp.tanh))
        # Output projection, left without a non-linearity.
        stack.append(hk.Linear(self.output_dim))
        network = hk.Sequential(stack)
        return network(X)
class DecoderNN:
    """Decoder head: a ``BaseMLP`` followed by a log-softmax over axis 1.

    Args:
        output_dim: Width of the MLP output (the axis the log-softmax
            normalizes over).
        hidden_dims: Hidden layer widths forwarded to ``BaseMLP``.
    """

    def __init__(self, output_dim: int, hidden_dims: List[int]):
        self.mlp = BaseMLP(output_dim=output_dim, hidden_dims=hidden_dims)

    def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
        """Return ``log_softmax(mlp(X))`` computed along axis 1."""
        logits = self.mlp(X)
        return jax.nn.log_softmax(logits, axis=1)
def l2_normalize(X: jnp.ndarray) -> jnp.ndarray:
    """Divide every row of the 2-D array ``X`` by its Euclidean norm.

    ``1e-8`` is added under the square root, so an all-zero row is returned
    as zeros instead of producing a division by zero.
    """
    squared_norms = (X**2).sum(axis=1)
    row_norms = jnp.sqrt(squared_norms + 1e-8)
    # Add a trailing axis so the per-row norms broadcast across columns.
    return X / row_norms[:, None]
class EncoderNN:
    """Encoder producing the two halves of a ``2 * latent_dim`` wide MLP output.

    Args:
        latent_dim: Width of each returned half.
        hidden_dims: Hidden layer widths forwarded to ``BaseMLP``.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: List[int],
    ):
        self.latent_dim = latent_dim
        # One MLP emits both halves side by side along axis 1.
        self.mlp = BaseMLP(2 * latent_dim, hidden_dims)

    def __call__(
        self, X: jnp.ndarray, dropout: float, train: bool
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Encode ``X`` into ``(mu, log_var)``.

        Args:
            X: Input rows; they are L2-normalized before anything else.
            dropout: Rate passed to ``hk.dropout``; only used when ``train``.
            train: When true, dropout is applied to the normalized input and
                one key is drawn via ``hk.next_rng_key()``.

        Returns:
            ``(mu, log_var)``: the first ``latent_dim`` columns of the MLP
            output and the remaining columns, respectively.
        """
        encoder_input = l2_normalize(X)
        if train:
            dropout_key = hk.next_rng_key()
            encoder_input = hk.dropout(dropout_key, dropout, encoder_input)
        hidden = self.mlp(encoder_input)
        split_at = self.latent_dim
        return hidden[:, :split_at], hidden[:, split_at:]
@dataclass
class MultVAEOutput:
    """Bundle of the arrays returned by one ``MultVAE`` forward pass.

    Field order matters: ``MultVAE.__call__`` builds this positionally.
    """

    # Decoder output after log-softmax.
    log_softmax: jnp.ndarray
    # Encoder mean (``mu``).
    mean: jnp.ndarray
    # Half of the encoder's log-variance, i.e. log of the standard deviation.
    log_stddev: jnp.ndarray
    # KL term, summed over axis 1 and averaged over axis 0 by ``MultVAE``.
    KL: jnp.ndarray
class MultVAE:
    """Variational autoencoder wiring an ``EncoderNN`` to a ``DecoderNN``.

    Args:
        n_obs: Output width of the decoder.
        latent_dim: Width of the latent code.
        enc_hidden_dims: Hidden layer widths of the encoder MLP.
        dec_hidden_dims: Hidden layer widths of the decoder MLP.
        dropout_p: Stored on the instance. ``__call__`` does not read it; the
            dropout rate actually used is the ``p`` argument of ``__call__``.
        l2_reg: Stored on the instance; not read anywhere in this class.
    """

    def __init__(
        self,
        n_obs: int,
        latent_dim: int,
        enc_hidden_dims: List[int],
        dec_hidden_dims: List[int],
        dropout_p: float = 0.5,
        l2_reg: float = 0.01,
    ):
        self.dropout_p = dropout_p
        self.l2_reg = l2_reg
        self._kl_coeff: float = 0.0
        self.encoder_network = EncoderNN(latent_dim, enc_hidden_dims)
        self.decoder_network = DecoderNN(n_obs, dec_hidden_dims)

    def set_kl_coeff(self, nv: float) -> None:
        """Overwrite the stored KL coefficient with ``nv``."""
        self._kl_coeff = nv

    def kl_coeff(self) -> float:
        """Return the stored KL coefficient (initially ``0.0``)."""
        return self._kl_coeff

    def __call__(self, X: jnp.ndarray, p: jnp.ndarray, train: bool) -> MultVAEOutput:
        """Run encoder, optional reparameterized sampling, and decoder.

        Args:
            X: Input rows handed to the encoder.
            p: Dropout rate forwarded to the encoder.
            train: When true the latent code is ``mu + eps * std`` with
                ``eps`` drawn from a standard normal; otherwise it is ``mu``.

        Returns:
            ``MultVAEOutput`` holding the decoder log-softmax, ``mu``,
            ``0.5 * log_var`` and the scalar KL term.
        """
        mu, log_var = self.encoder_network(X, p, train)
        # log_var = 2 * log(std), hence std = exp(log_var / 2).
        half_log_var = log_var * 0.5
        std = jnp.exp(half_log_var)

        # Element-wise KL term, summed over axis 1 and averaged over axis 0.
        kl_elementwise: jnp.ndarray = 0.5 * (-log_var + std**2 + mu**2 - 1)
        KL = kl_elementwise.sum(axis=1).mean(axis=0)

        if not train:
            z = mu
        else:
            noise = jax.random.normal(hk.next_rng_key(), mu.shape)
            z = mu + noise * std

        decoded: jnp.ndarray = self.decoder_network(z)
        return MultVAEOutput(decoded, mu, half_log_var, KL)
class MultVAETrainer(TrainerBase):
    """Minibatch trainer and scorer for the Haiku ``MultVAE`` in this module.

    The trainer owns the transformed model function, the Adam optimizer and
    its state, the model parameters, and a ``hk.PRNGSequence`` seeded with 42.

    Args:
        X: Interaction matrix; ``X.shape[0]`` is taken as the number of users
            and ``X.shape[1]`` as the number of items.
        dim_z: Latent dimension.
        enc_hidden_dim: Width of the single encoder hidden layer.
        dec_hidden_dim: Width of the single decoder hidden layer; ``None``
            means "same as ``enc_hidden_dim``".
        dropout_p: Dropout rate applied to the encoder input during training.
        l2_regularizer: Forwarded to ``MultVAE``.
            NOTE(review): no term of the loss built in ``_setup_jax_funcs``
            uses this value, so it currently has no effect on training.
        kl_anneal_goal: Final multiplier of the KL term.
        anneal_end_epoch: Number of epochs over which the KL multiplier grows
            linearly from 0 to ``kl_anneal_goal``. ``0`` disables annealing
            (the goal is used from the first update).
        minibatch_size: Number of rows per minibatch.
        learning_rate: Adam learning rate.
    """

    def __init__(
        self,
        X: InteractionMatrix,
        dim_z: int,
        enc_hidden_dim: int,
        dec_hidden_dim: Optional[int],
        dropout_p: float,
        l2_regularizer: float,
        kl_anneal_goal: float,
        anneal_end_epoch: int,
        minibatch_size: int,
        learning_rate: float,
    ):
        self.X = X
        self.n_users = X.shape[0]
        self.n_items = X.shape[1]
        self.minibatch_size = minibatch_size
        self.kl_anneal_goal = kl_anneal_goal
        self.anneal_end_epoch = anneal_end_epoch
        self.dropout_p = dropout_p
        self.l2_regularizer = l2_regularizer
        self.learning_rate = learning_rate
        self.rng_seq = hk.PRNGSequence(42)

        if dec_hidden_dim is None:
            dec_hidden_dim = enc_hidden_dim
        self.enc_hidden_dim = enc_hidden_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.dim_z = dim_z

        # Number of parameter updates after which the KL multiplier reaches
        # its goal. This is 0 when anneal_end_epoch is 0; see
        # _current_kl_coeff for how that case is handled.
        self.total_anneal_step = (anneal_end_epoch * self.n_users) / minibatch_size

        self._setup_jax_funcs()
        self._update_count = 0

        # Initialize parameters by tracing the model on a single all-zero row.
        params = self.vae_f.init(
            next(self.rng_seq),
            np.zeros((1, self.n_items), dtype=np.float32),
            self.dropout_p,
            True,
        )
        self.params = params
        self.opt_state = self.optimizer.init(self.params)

    def _setup_jax_funcs(self) -> None:
        """(Re)build the transformed model, the optimizer and the jitted update.

        Called from ``__init__`` and from ``__setstate__``, because these
        objects are dropped from the pickled state by ``__getstate__``.
        """
        vae_f = hk.transform(
            lambda X, p, train: MultVAE(
                self.n_items,
                self.dim_z,
                [self.enc_hidden_dim],
                [self.dec_hidden_dim],
                dropout_p=self.dropout_p,
                l2_reg=self.l2_regularizer,
            )(X, p, train)
        )
        self.optimizer = adam(learning_rate=self.learning_rate)

        def loss_fn(
            params: hk.Params,
            rng: jnp.ndarray,
            X: jnp.ndarray,
            kl_coeff: jnp.ndarray,
            dropout: float,
            train: bool,
        ) -> jnp.ndarray:
            """Negative log-likelihood plus ``kl_coeff`` times the KL term."""
            mvresult: MultVAEOutput = vae_f.apply(params, rng, X, dropout, train)
            neg_ll = -(mvresult.log_softmax * X).sum(axis=1).mean()
            neg_elbo = neg_ll + kl_coeff * mvresult.KL
            return neg_elbo

        # `dropout` and `train` are marked static for jit.
        loss_fn = jax.jit(loss_fn, static_argnums=(4, 5))

        def update(
            params: hk.Params,
            rng: jnp.ndarray,
            opt_state: OptState,
            X: jnp.ndarray,
            kl_coeff: jnp.ndarray,
            dropout: float,
        ) -> Tuple[hk.Params, OptState]:
            """Take one Adam step on ``loss_fn`` evaluated in training mode."""
            grads = jax.grad(loss_fn)(params, rng, X, kl_coeff, dropout, True)
            updates, new_optstate = self.optimizer.update(grads, opt_state)
            new_params = optax.apply_updates(
                params,
                updates,
            )
            return new_params, new_optstate

        update = jax.jit(update, static_argnums=(5,))

        self.vae_f = vae_f
        self.update_function = update

    def _current_kl_coeff(self) -> float:
        """Return the KL multiplier for the upcoming update.

        The multiplier grows linearly with ``_update_count`` and is capped at
        ``kl_anneal_goal``. When ``total_anneal_step`` is not positive
        (e.g. ``anneal_end_epoch == 0``) the goal is returned directly instead
        of dividing by zero.
        """
        if self.total_anneal_step <= 0:
            return self.kl_anneal_goal
        progress = min(1, self._update_count / self.total_anneal_step)
        return self.kl_anneal_goal * progress

    def get_score_cold_user(self, X: InteractionMatrix) -> DenseScoreArray:
        """Score every item for each row of ``X`` using the current parameters.

        Rows are processed in chunks of ``minibatch_size`` with the model in
        evaluation mode (``train=False``).

        Args:
            X: Interaction rows to score.

        Returns:
            float64 array of decoder log-softmax outputs, one row per row of
            ``X``. An input with zero rows yields an array of shape
            ``(0, n_items)``.
        """
        mb_arrays: List[DenseScoreArray] = []
        n_users: int = X.shape[0]
        X_csr: sps.csr_matrix = X.tocsr()
        for mb_start in range(0, n_users, self.minibatch_size):
            mb_end = min(n_users, mb_start + self.minibatch_size)
            X_mb = X_csr[mb_start:mb_end].astype(np.float32).toarray()
            X_jax = jnp.asarray(X_mb)
            mvresult: MultVAEOutput = self.vae_f.apply(
                self.params, next(self.rng_seq), X_jax, self.dropout_p, False
            )
            mb_arrays.append(np.asarray(mvresult.log_softmax, dtype=np.float64))
        if not mb_arrays:
            # np.concatenate raises ValueError on an empty sequence.
            return np.zeros((0, self.n_items), dtype=np.float64)
        score_concat: DenseScoreArray = np.concatenate(mb_arrays, axis=0)
        return score_concat

    def run_epoch(self) -> None:
        """Run one pass over all users in shuffled minibatches.

        The shuffle uses NumPy's global random state. Every minibatch performs
        one parameter update and increments ``_update_count``.
        """
        user_indices = np.arange(self.n_users)
        np.random.shuffle(user_indices)
        for mb_start in range(0, self.n_users, self.minibatch_size):
            current_kl_coeff = jnp.asarray(self._current_kl_coeff())
            mb_end = min(self.n_users, mb_start + self.minibatch_size)
            X = self.X[user_indices[mb_start:mb_end]].astype(np.float32)
            if sps.issparse(X):
                X = X.toarray()
            X_jax = jnp.asarray(X, dtype=jnp.float32)
            self.params, self.opt_state = self.update_function(
                self.params,
                next(self.rng_seq),
                self.opt_state,
                X_jax,
                current_kl_coeff,
                self.dropout_p,
            )
            self._update_count += 1

    def save_state(self, ofs: IO) -> None:
        """Pickle the model parameters (only ``self.params``) to ``ofs``."""
        pickle.dump(self.params, ofs)

    def load_state(self, ifs: IO) -> None:
        """Replace ``self.params`` with the object unpickled from ``ifs``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only pass
        streams that were produced by ``save_state`` from a trusted source.
        """
        self.params = pickle.load(ifs)

    def __getstate__(self) -> Any:
        """Return the instance dict without the objects rebuilt on load."""
        serialized = dict(**self.__dict__)
        # These are recreated by _setup_jax_funcs() in __setstate__.
        serialized.pop("update_function", None)
        serialized.pop("vae_f", None)
        serialized.pop("optimizer", None)
        return serialized

    def __setstate__(self, state: Dict[str, Any]) -> Any:
        """Restore attributes, then rebuild the model/optimizer/jit functions."""
        self.__dict__.update(state)
        self._setup_jax_funcs()
class MultVAEConfig(BaseEarlyStoppingRecommenderConfig):
    """Hyper-parameter defaults for ``MultVAERecommender``.

    Each field has the same name and default value as the corresponding
    keyword argument of ``MultVAERecommender.__init__``.
    """

    # Latent dimension.
    dim_z: int = 16
    # Width of the encoder's hidden layer.
    enc_hidden_dims: int = 256
    # Width of the decoder's hidden layer; None falls back to the encoder width.
    dec_hidden_dims: Optional[int] = None
    # Dropout rate.
    dropout_p: float = 0.5
    # L2 regularization coefficient.
    l2_regularizer: float = 0
    # Final multiplier of the KL term.
    kl_anneal_goal: float = 0.2
    # Epoch at which KL annealing completes.
    anneal_end_epoch: int = 50
    # Rows per minibatch.
    minibatch_size: int = 512
    # Adam learning rate.
    learning_rate: float = 1e-3
class MultVAERecommender(BaseRecommenderWithEarlyStopping):
    r"""JAX implementation of Mult-VAE, presented in `"Variational Autoencoders for Collaborative Filtering" <https://arxiv.org/abs/1802.05814>`_.

    Args:
        X_train_all:
            The source data.
        dim_z:
            The latent dimension. Defaults to 16.
        enc_hidden_dims:
            The encoder's intermediate layer dimension. Defaults to 256.
        dec_hidden_dims:
            The decoder's intermediate layer dimension. If ``None``, the
            encoder's value is used. Defaults to None.
        dropout_p:
            Dropout ratio. Defaults to 0.5.
        l2_regularizer:
            L2 regularization coefficient. Defaults to 0.
        kl_anneal_goal:
            beta of beta-VAE. Defaults to 0.2.
        anneal_end_epoch:
            The epoch to complete the annealing. Defaults to 50.
        minibatch_size (int, optional):
            Minibatch size. Defaults to 512.
        train_epochs:
            The number of epochs to run. Defaults to 300.
        learning_rate:
            Adam optimizer's learning rate. Defaults to 1e-3.
    """

    config_class = MultVAEConfig
    default_tune_range = [
        CategoricalRange("dim_z", [32, 64, 128, 256]),
        CategoricalRange("enc_hidden_dims", [128, 256, 512]),
        CategoricalRange("kl_anneal_goal", [0.1, 0.2, 0.4]),
    ]

    def __init__(
        self,
        X_train_all: InteractionMatrix,
        dim_z: int = 16,
        enc_hidden_dims: int = 256,
        dec_hidden_dims: Optional[int] = None,
        dropout_p: float = 0.5,
        l2_regularizer: float = 0,
        kl_anneal_goal: float = 0.2,
        anneal_end_epoch: int = 50,
        minibatch_size: int = 512,
        train_epochs: int = 300,
        learning_rate: float = 1e-3,
    ) -> None:
        super().__init__(
            X_train_all,
            train_epochs=train_epochs,
        )
        self.dim_z = dim_z
        self.enc_hidden_dims = enc_hidden_dims
        self.dec_hidden_dims = dec_hidden_dims
        self.kl_anneal_goal = kl_anneal_goal
        self.anneal_end_epoch = anneal_end_epoch
        self.minibatch_size = minibatch_size
        self.dropout_p = dropout_p
        self.l2_regularizer = l2_regularizer
        self.learning_rate = learning_rate

        # No trainer exists until one is created; scoring checks for this.
        self.trainer: Optional[MultVAETrainer] = None

    def _create_trainer(self) -> MultVAETrainer:
        """Build a ``MultVAETrainer`` from the stored hyper-parameters."""
        return MultVAETrainer(
            self.X_train_all,
            self.dim_z,
            self.enc_hidden_dims,
            self.dec_hidden_dims,
            self.dropout_p,
            self.l2_regularizer,
            self.kl_anneal_goal,
            self.anneal_end_epoch,
            self.minibatch_size,
            self.learning_rate,
        )

    def get_score(self, user_indices: UserIndexArray) -> DenseScoreArray:
        """Score all items for the given training users.

        The users' rows of ``X_train_all`` are passed to
        ``get_score_cold_user``.
        """
        return self.get_score_cold_user(self.X_train_all[user_indices])

    def get_score_cold_user(self, X: InteractionMatrix) -> DenseScoreArray:
        """Score all items for users described only by their interaction rows.

        Args:
            X: Interaction rows, one per user.

        Raises:
            RuntimeError: If no trainer has been created yet.
        """
        if self.trainer is None:
            raise RuntimeError("encoder called before training.")
        return self.trainer.get_score_cold_user(X)