## ----include = FALSE----------------------------------------------------------
# Vignette-wide knitr chunk options: collapse each chunk's code and output
# into a single block, and prefix printed output lines with "#>".
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

# Fix the RNG state so the simulated data and model fits below are
# reproducible from build to build.
set.seed(123)

## -----------------------------------------------------------------------------
library(mlstm)

# Simulated corpus size: D documents, V vocabulary terms, K topics.
D <- 50
V <- 200
K <- 5

# Every document receives NZ_per_doc *distinct* terms, so NZ is the exact
# number of non-zero cells in the document-term matrix.
NZ_per_doc <- 20
stopifnot(NZ_per_doc <= V)
NZ <- D * NZ_per_doc

# Sparse triplet form of the document-term counts, one row per non-zero cell:
#   d = 0-based document index, v = 0-based term index, c = count (>= 1).
# Terms are sampled without replacement *within* each document so no (d, v)
# cell appears twice; sampling with replacement over the whole corpus would
# give duplicate cells in most documents.
term_ids <- unlist(
  lapply(seq_len(D), function(doc) sample.int(V, NZ_per_doc)),
  use.names = FALSE
)
count <- cbind(
  d = as.integer(rep(0:(D - 1), each = NZ_per_doc)),
  v = as.integer(term_ids - 1L),
  c = as.integer(rpois(NZ, 3) + 1L)
)

# Two continuous document-level responses, drawn independently of the text
# (illustration only).
Y <- cbind(
  y1 = rnorm(D),
  y2 = rnorm(D)
)

dim(count)
head(count)
dim(Y)

## -----------------------------------------------------------------------------
# Unsupervised baseline: plain LDA via run_lda_gibbs() (a Gibbs sampler, per
# the function name). No response is supplied here.
# NOTE(review): alpha / beta are presumably the document-topic and
# topic-word prior hyperparameters -- confirm in ?run_lda_gibbs.
mod_lda <- run_lda_gibbs(
  count = count, K = K,
  alpha = 0.1, beta = 0.01,
  n_iter = 20, verbose = FALSE
)

# Inspect the structure of the two estimated components.
str(mod_lda$theta)
str(mod_lda$phi)

## -----------------------------------------------------------------------------
# Single-response supervised model: use only the first response column.
y <- Y[, "y1"]

# NOTE(review): presumably sets the thread count for the package's parallel
# routines; two is the usual ceiling for code run during CRAN checks.
set_threads(2)

mod_stm <- run_stm_vi(
  count = count, y = y, K = K,
  alpha = 0.1, beta = 0.01,
  min_iter = 10, max_iter = 50,
  verbose = FALSE
)

# In-sample fitted values: each row of nd divided by the matching ndsum entry,
# multiplied by the coefficients eta. NOTE(review): presumably nd holds
# per-document topic counts (documents x topics) and ndsum their row totals.
theta_stm <- mod_stm$nd / mod_stm$ndsum
y_hat <- (theta_stm %*% mod_stm$eta)[, 1]
cor(y, y_hat)

## ----eval = FALSE-------------------------------------------------------------
#  plot(mod_stm$elbo_trace, type = "l")
#  plot(mod_stm$label_loglik_trace, type = "l")

## -----------------------------------------------------------------------------
# Prior hyperparameters passed to run_mlstm_vi(): a length-K zero vector, the
# scalar K + 2, and the K x K identity matrix.
# NOTE(review): these look like a normal-inverse-Wishart-style prior (mean,
# degrees of freedom, scale) -- confirm their roles in ?run_mlstm_vi.
mu <- rep(0, K)
upsilon <- K + 2
Omega <- diag(K)

# Fit both responses (all columns of Y) jointly.
mod_mlstm <- run_mlstm_vi(
  count = count,
  Y = Y,
  K = K,
  alpha = 0.1,
  beta = 0.01,
  mu = mu,
  upsilon = upsilon,
  Omega = Omega,
  max_iter = 50,
  min_iter = 10,
  verbose = FALSE
)

# In-sample fitted values, one column per response: each row of nd divided by
# the matching ndsum entry, multiplied by the coefficient matrix eta.
# NOTE(review): presumably nd holds per-document topic counts and ndsum their
# row totals.
Y_hat <- ((mod_mlstm$nd / mod_mlstm$ndsum) %*% mod_mlstm$eta)

# Per-response fit: correlate each observed column with its own fitted
# column. cor(Y, Y_hat) would return the full cross-correlation matrix, whose
# off-diagonal entries (e.g. y1 vs fitted y2) do not measure fit.
vapply(
  setNames(seq_len(ncol(Y)), colnames(Y)),
  function(j) cor(Y[, j], Y_hat[, j]),
  numeric(1)
)

## ----eval = FALSE-------------------------------------------------------------
#  plot(mod_mlstm$elbo_trace, type = "l")
#  plot(mod_mlstm$label_loglik_trace, type = "l")

## -----------------------------------------------------------------------------
# Record the R version, platform, and attached packages used for this build.
print(utils::sessionInfo())

