## ----setup, include=FALSE-----------------------------------------------------
# Vignette-wide chunk defaults: collapse source and output into one block,
# prefix printed output with "#>", and use 7 x 5 figures.
do.call(
  knitr::opts_chunk$set,
  list(collapse = TRUE, comment = "#>", fig.width = 7, fig.height = 5)
)

## ----packages-----------------------------------------------------------------
library(mpmaggregate)
library(knitr)
library(kableExtra)

## ----fetch-comadre------------------------------------------------------------
# Example matrices used in the general-to-general MPM aggregation examples below
# Population projection matrix for Rourea induta

# Not run: requires Rcompadre and internet access

# MatrixID = 246781
# Matrices can be retrieved with:
# library(Rcompadre)
# compadre <- cdb_fetch("compadre")
# mpm <- compadre[compadre$MatrixID == 246781, ]
# matA <- matA(mpm)[[1]]
# matU <- matU(mpm)[[1]]
# matF <- matF(mpm)[[1]]
# matC <- matC(mpm)[[1]]


# The matrices are defined locally so that Rcompadre and internet access
# are not required

# Population projection matrix A (14 x 14).
# byrow = TRUE, so each text line below is one matrix row (stage order as in
# the stage list further down). Following the usual MPM convention, entry
# [i, j] is the contribution of stage j to stage i. Row 1 is sexual
# reproduction (matches matF), row 2 clonal reproduction (matches matC), and
# rows 3-14 are survival/growth transitions (match matU); the sanity check
# after matF + matC verifies A = U + F.
matA <- matrix(c(
  0.000,0.000,0.000,0.000,0.002,0.007,0.021,0.055,0.102,0.139,0.176,0.191,0.203,0.214,  # 1 Seedling
  0.000,0.000,0.000,0.001,0.002,0.004,0.007,0.011,0.016,0.021,0.036,0.052,0.072,0.096,  # 2 Sucker
  0.833,0.043,0.883,0.151,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,  # 3
  0.000,0.391,0.043,0.791,0.079,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,  # 4
  0.000,0.348,0.000,0.047,0.810,0.139,0.024,0.024,0.056,0.000,0.000,0.000,0.000,0.000,  # 5
  0.000,0.000,0.000,0.000,0.087,0.733,0.082,0.071,0.000,0.000,0.000,0.000,0.000,0.000,  # 6
  0.000,0.087,0.000,0.000,0.000,0.129,0.765,0.024,0.000,0.013,0.000,0.000,0.000,0.000,  # 7
  0.000,0.000,0.000,0.000,0.000,0.000,0.106,0.762,0.028,0.000,0.000,0.000,0.000,0.000,  # 8
  0.000,0.000,0.000,0.000,0.000,0.000,0.012,0.119,0.778,0.000,0.000,0.000,0.000,0.000,  # 9
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.083,0.861,0.000,0.000,0.000,0.000,  # 10
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.127,0.891,0.000,0.000,0.000,  # 11
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.091,0.909,0.000,0.000,  # 12
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.091,0.918,0.000,  # 13
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.082,0.967   # 14
), nrow = 14, byrow = TRUE)

# Survival/transition matrix U (14 x 14).
# byrow = TRUE, so each text line below is one matrix row. Rows 3-14 are
# identical to matA; rows 1-2 (the reproduction rows, Seedling and Sucker)
# are all zero because those entries live in matF and matC instead.
matU <- matrix(c(
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,  # 1 Seedling
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,  # 2 Sucker
  0.833,0.043,0.883,0.151,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,  # 3
  0.000,0.391,0.043,0.791,0.079,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,  # 4
  0.000,0.348,0.000,0.047,0.810,0.139,0.024,0.024,0.056,0.000,0.000,0.000,0.000,0.000,  # 5
  0.000,0.000,0.000,0.000,0.087,0.733,0.082,0.071,0.000,0.000,0.000,0.000,0.000,0.000,  # 6
  0.000,0.087,0.000,0.000,0.000,0.129,0.765,0.024,0.000,0.013,0.000,0.000,0.000,0.000,  # 7
  0.000,0.000,0.000,0.000,0.000,0.000,0.106,0.762,0.028,0.000,0.000,0.000,0.000,0.000,  # 8
  0.000,0.000,0.000,0.000,0.000,0.000,0.012,0.119,0.778,0.000,0.000,0.000,0.000,0.000,  # 9
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.083,0.861,0.000,0.000,0.000,0.000,  # 10
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.127,0.891,0.000,0.000,0.000,  # 11
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.091,0.909,0.000,0.000,  # 12
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.091,0.918,0.000,  # 13
  0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.082,0.967   # 14
), nrow = 14, byrow = TRUE)

# Sexual reproduction matrix F (14 x 14).
# Only row 1 (Seedling production) is nonzero: seedlings are produced by
# stages 5-14. Every other entry is zero, so the matrix is built from zeros
# and only that row segment is filled in.
matF <- matrix(0, nrow = 14, ncol = 14)
matF[1, 5:14] <- c(
  0.002, 0.007, 0.021, 0.055, 0.102,
  0.139, 0.176, 0.191, 0.203, 0.214
)

# Clonal reproduction matrix C (14 x 14).
# Only row 2 (Sucker production) is nonzero: suckers are produced by
# stages 4-14. Every other entry is zero, so the matrix is built from zeros
# and only that row segment is filled in.
matC <- matrix(0, nrow = 14, ncol = 14)
matC[2, 4:14] <- c(
  0.001, 0.002, 0.004, 0.007, 0.011, 0.016,
  0.021, 0.036, 0.052, 0.072, 0.096
)

# Redefine matF as total reproduction: sexual (row 1) plus clonal (row 2).
matF <- matF + matC

# Sanity check: the projection matrix must decompose into survival (matU)
# plus total reproduction (matF). all.equal() returns TRUE or a character
# description of the difference, so wrap it in isTRUE() and give the check
# an explicit message instead of relying on a character value failing.
stopifnot(
  "matA must equal matU + matF" =
    isTRUE(all.equal(unname(matA), unname(matU + matF)))
)

# Stage classes of the 14-stage model (index, status, stage name):
#1                active          Seedling
#2                active            Sucker
#3                active          1-1.9 mm
#4                active          2-3.9 mm
#5                active          4-5.9 mm
#6                active          6-7.9 mm
#7                active          8-9.9 mm
#8                active        10-14.9 mm
#9                active        15-19.9 mm
#10               active          20-29 mm
#11               active          30-39 mm
#12               active          40-49 mm
#13               active          50-69 mm
#14               active            70+ mm

# Stage aggregation used in later examples.
# Seedling (stage 1) and Sucker (stage 2) stay as singleton groups; the twelve
# size classes (stages 3-14) are pooled into three consecutive blocks of four:
# 1-7.9 mm (stages 3-6), 8-29 mm (stages 7-10) and 30+ mm (stages 11-14).
groups <- c(
  list(1, 2),
  unname(split(as.numeric(3:14), rep(1:3, each = 4)))
)

## ----helpers------------------------------------------------------------------
# Null-coalescing operator: `x` unless it is NULL, otherwise `y`.
`%||%` <- function(x, y) {
  if (is.null(x)) y else x
}

# Net reproductive rate R0: spectral radius of the next-generation matrix
# K = F N, where N = (I - U)^-1 is the fundamental matrix of U.
get_R0 <- function(U, F) {
  N <- solve(diag(nrow(U)) - U)
  spectral_radius(F %*% N)
}

# Shared wrapper around generation_time(); `framework` selects the
# definition ("lambda" or "R0"). F is passed first, U second.
generation_time_in <- function(U, F, framework) {
  generation_time(F, U, framework = framework)$generation_time
}

# Generation time in the lambda framework (Ta).
get_Ta <- function(U, F) generation_time_in(U, F, "lambda")

# Cohort generation time in the R0 framework (Tc).
get_Tc <- function(U, F) generation_time_in(U, F, "R0")

## ----aggregate----------------------------------------------------------------
# Aggregate the 14-stage model onto `groups` with a given framework
# ("lambda" or "R0") and criterion ("standard" or "elasticity").
aggregate_with <- function(framework, criterion) {
  mpm_aggregate(
    matU = matU,
    matF = matF,
    groups = groups,
    framework = framework,
    criterion = criterion
  )
}

# All four framework x criterion combinations.
agg1 <- aggregate_with(framework = "lambda", criterion = "standard")
agg2 <- aggregate_with(framework = "lambda", criterion = "elasticity")
agg3 <- aggregate_with(framework = "R0", criterion = "standard")
agg4 <- aggregate_with(framework = "R0", criterion = "elasticity")

# Extract aggregated U, F and A from an aggregation result. Several candidate
# field names are tried in order, in case the names differ between package
# versions. Fields are looked up with `[[`, which matches names exactly:
# `$` partially matches on lists, so e.g. a field named "Fk" would silently
# be returned for the fallback name "F".

# Return the first non-NULL field of `x` among `fields`, or NULL if none.
first_field <- function(x, fields) {
  for (field in fields) {
    value <- x[[field]]
    if (!is.null(value)) {
      return(value)
    }
  }
  NULL
}

extract_U <- function(x) {
  first_field(x, c("matUk_agg", "matU_agg", "matUagg", "U"))
}
extract_F <- function(x) {
  first_field(x, c("matFk_agg", "matF_agg", "matFagg", "F"))
}
extract_A <- function(x) {
  first_field(x, c("matAk_agg", "matA_agg", "matAagg", "A"))
}

# Bundle the aggregated matrices as list(U, F, A). When the result carries no
# A but has both U and F, A is reconstructed as U + F.
as_ufA <- function(x) {
  U <- extract_U(x)
  F <- extract_F(x)
  A <- extract_A(x)
  if (is.null(A) && !is.null(U) && !is.null(F)) A <- U + F
  list(U = U, F = F, A = A)
}

# Original (unaggregated) model and the four aggregated models, each as a
# list(U = survival, F = total reproduction, A = projection) triple.
m0 <- list(U = matU, F = matF, A = matA)
m1 <- as_ufA(agg1)
m2 <- as_ufA(agg2)
m3 <- as_ufA(agg3)
m4 <- as_ufA(agg4)

# Aggregation effectiveness (rho^2) per model; the original model has none,
# hence the leading NA. The field is read with `[[` (exact match) and a
# missing value becomes NA, so `eff` always has one entry per model instead
# of silently shrinking and misaligning the results table.
eff <- c(NA, vapply(
  list(agg1, agg2, agg3, agg4),
  function(agg) {
    value <- agg[["effectiveness"]]
    if (is.null(value)) NA_real_ else value
  },
  numeric(1)
))

# Models keyed by syntactic names; model_labels maps the same keys to the
# HTML labels used in the results table.
models <- list(
  Original         = m0,
  Agg_lambda_std   = m1,
  Agg_lambda_elast = m2,
  Agg_R0_std       = m3,
  Agg_R0_elast     = m4
)

model_labels <- c(
  Original         = "Original",
  Agg_lambda_std   = "Agg &lambda; + standard",
  Agg_lambda_elast = "Agg &lambda; + elasticity",
  Agg_R0_std       = "Agg R<sub>0</sub> + standard",
  Agg_R0_elast     = "Agg R<sub>0</sub> + elasticity"
)

# Basic checks: every model exposes a projection matrix A, and every
# aggregated model (all but the 14-stage original) has one row per group.
stopifnot(all(vapply(models, function(m) !is.null(m$A), logical(1))))
stopifnot(all(vapply(
  models[names(models) != "Original"],
  function(m) nrow(m$A) == length(groups),
  logical(1)
)))

# Show aggregated matrices
print("Agg λ + standard")
m1$A
print("Agg λ + elasticity")
m2$A
print("Agg R0 + standard")
m3$A
print("Agg R0 + elasticity")
m4$A


## ----results-table------------------------------------------------------------

# Demographic summary of every model: asymptotic growth rate (lambda), net
# reproductive rate (R0), generation time in the lambda framework (Ta),
# cohort generation time in the R0 framework (Tc) and aggregation
# effectiveness (rho^2; NA for the unaggregated original). vapply() pins each
# summary to one number per model, so an unexpected return shape fails
# loudly instead of producing a list column.
results <- data.frame(
  Model  = unname(model_labels[names(models)]),
  lambda = vapply(models, function(m) spectral_radius(m$A), numeric(1)),
  R0     = vapply(models, function(m) get_R0(m$U, m$F), numeric(1)),
  Ta     = vapply(models, function(m) get_Ta(m$U, m$F), numeric(1)),
  Tc     = vapply(models, function(m) get_Tc(m$U, m$F), numeric(1)),
  Effectiveness = eff,
  row.names = NULL,
  stringsAsFactors = FALSE
)

n_rows <- nrow(results)
# Stage counts quoted in the caption, derived from the data so the caption
# stays correct if `groups` changes.
n_stages <- nrow(matA)
n_groups <- length(groups)

knitr::kable(
  results,
  col.names = c(
    "Model",
    "<em>&lambda;</em>",
    "<em>R</em><sub>0</sub>",
    "<em>T</em><sub>a</sub> (years)",
    "<em>T</em><sub>c</sub> (years)",
    "<em>&rho;</em><sup>2</sup>"
  ),
  digits = 3,
  caption = paste0(
    "<strong>Table 1.</strong> Comparison of demographic properties for the original ",
    "stage-structured matrix population model of <em>Rourea induta</em> ",
    "(COMPADRE MatrixID 246781) and four aggregated versions. The original ",
    n_stages, " stages are collapsed to ", n_groups,
    " while leaving the Seedling and Sucker stages unchanged. ",
    "&ldquo;Standard&rdquo; denotes use of a standard aggregator and ",
    "&ldquo;elasticity&rdquo; denotes elasticity-consistent aggregation. ",
    "<em>T</em><sub>a</sub> is generation time (years) in the <em>&lambda;</em> framework and ",
    "<em>T</em><sub>c</sub> is cohort generation time (years) in the ",
    "<em>R</em><sub>0</sub> framework. ",
    "<em>&rho;</em><sup>2</sup> quantifies aggregation effectiveness, with values closer to 1 ",
    "indicating closer agreement between the aggregated and reference models."
  ),
  format = "html",
  escape = FALSE
) |>
  kable_styling(full_width = TRUE) |>
  row_spec(0, extra_css = "border-bottom: 2px solid black;") |>
  row_spec(n_rows, extra_css = "border-bottom: 2px solid black;") |>
  footnote(
    general_title = "",
    general = paste0(
      "<span style='font-size: 90%;'>",
      "<strong>Note.</strong> The projection interval associated with the population ",
      "growth rate (<em>&lambda;</em>) is 1 year.</span>"
    ),
    footnote_as_chunk = TRUE,
    escape = FALSE
  )

## ----elasticities-------------------------------------------------------------

# Partitioning matrix induced by the stage groups. mpm_aggregate() builds it
# internally; it is reproduced here to map quantities from the original
# stage space into the aggregated space.
P <- mpm_partition(groups = groups, n = nrow(matA))

# Elasticity of lambda for a projection matrix A.
lambda_elasticity <- function(A) {
  mpm_elasticity(matA = A, framework = "lambda")$elasticity
}

# Elasticity matrices in the aggregated (k x k) space for every model. The
# original model's elasticities are mapped into that space with P.
E_A <- lambda_elasticity(matA)
E_list <- list(
  "Original"                = P %*% E_A %*% t(P),
  "Agg lambda + standard"   = lambda_elasticity(m1$A),
  "Agg lambda + elasticity" = lambda_elasticity(m2$A),
  "Agg R0 + standard"       = lambda_elasticity(m3$A),
  "Agg R0 + elasticity"     = lambda_elasticity(m4$A)
)

# Convert a matrix into a long data.frame with one row per nonzero entry:
# model name, row and column index, a "row,col" entry key and the value.
# Rows come out in which()'s column-major order; callers re-sort as needed.
to_long <- function(M, model_name) {
  nz <- which(M != 0, arr.ind = TRUE)
  rows <- nz[, 1]
  cols <- nz[, 2]
  data.frame(
    model = model_name,
    row = rows,
    col = cols,
    entry = paste(rows, cols, sep = ","),
    elasticity = M[nz],
    stringsAsFactors = FALSE
  )
}

# Stack every model's nonzero elasticities into one long table.
elast_df <- do.call(rbind, Map(to_long, E_list, names(E_list)))

# Keep positive values for log scale
elast_df <- elast_df[elast_df$elasticity > 0, ]

# Row-major ordering: row 1 entries first, then row 2, and so on.
elast_df <- elast_df[order(elast_df$row, elast_df$col), ]

models_order <- names(E_list)
entries <- unique(elast_df$entry)

# Matrix for barplot: rows = models, cols = entries. Each value is placed by
# (model, entry) lookup, so a model lacking an entry gets NA (no bar)
# instead of its later values silently shifting into the wrong columns.
elast_mat <- matrix(
  NA_real_,
  nrow = length(models_order),
  ncol = length(entries),
  dimnames = list(models_order, entries)
)
elast_mat[cbind(
  match(elast_df$model, models_order),
  match(elast_df$entry, entries)
)] <- elast_df$elasticity

## ----lambda-elasticity-plot, fig.width=12-------------------------------------
# Recover the (row, col) index pair encoded in each "row,col" column name of
# elast_mat, keeping the column order used by the barplot.
rc <- matrix(
  unlist(strsplit(colnames(elast_mat), ",")),
  ncol = 2,
  byrow = TRUE
)
r_vec <- as.integer(rc[, 1])
c_vec <- as.integer(rc[, 2])

# Plotmath label for aggregated matrix entry (r, c): row 1 holds sexual
# reproduction (F), row 2 clonal reproduction (C), and every other row a
# survival/growth transition (U).
make_entry_label <- function(r, c) {
  if (r == 1) {
    return(bquote(F[1, .(c)]))
  }
  if (r == 2) {
    return(bquote(C[2, .(c)]))
  }
  bquote(U[.(r), .(c)])
}

# One plotmath label per bar cluster, in the column order of elast_mat.
# mapply() returns a plain list of calls, which axis() coerces to character
# (literal text such as "F[1, 3]"); as.expression() turns it into an
# expression vector so the labels render with subscripts.
entry_labels <- as.expression(
  mapply(make_entry_label, r_vec, c_vec, SIMPLIFY = FALSE)
)


# One fill color per model (row of elast_mat), reused by the legend.
cols <- c("#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e")

# Enlarge the bottom margin for the rotated entry labels and the left margin
# for the y-axis title; `op` keeps the previous settings for restoring.
op <- par(mar = c(7, 6, 4, 2))

# Grouped bars: one cluster per matrix entry (column of elast_mat), one bar
# per model. With beside = TRUE, space = c(0.2, 1) leaves 0.2 bar widths
# between bars in a cluster and 1 bar width between clusters.
bp <- barplot(
  height = elast_mat,
  beside = TRUE,
  space  = c(0.2, 1),
  col    = cols,
  border = NA,
  log    = "y",
  xaxt   = "n",
  ylab   = expression("Elasticity of " * lambda * " (log scale)")
)

# `bp` holds bar midpoints, one column per cluster; label each cluster at
# its center.
axis(
  side = 1,
  at = colMeans(bp),
  labels = entry_labels,
  las = 2,
  cex.axis = 0.9
)

legend(
  x = "topleft",
  legend = expression(
    "Original",
    paste("Agg ", lambda, " + standard"),
    paste("Agg ", lambda, " + elasticity"),
    paste("Agg ", R[0], " + standard"),
    paste("Agg ", R[0], " + elasticity")
  ),
  fill = cols,
  inset = 0.02,
  bty = "n"
)

par(op)

## ----elasticities of R0-------------------------------------------------------
# Partitioning matrix induced by the stage groups, recomputed here so this
# section can be read independently of the lambda section.
P <- mpm_partition(groups = groups, n = nrow(matA))

# Normalized elasticity of R0 for a model given by its (F, U) pair.
r0_elasticity <- function(F_mat, U_mat) {
  mpm_elasticity(
    matF = F_mat,
    matU = U_mat,
    framework = "R0",
    normalize = TRUE
  )$elasticity
}

# Elasticity matrices in the aggregated (k x k) space for every model. The
# original model's elasticities are mapped into that space with P; this is
# the reference the aggregated models' elasticities are compared against.
E_A <- r0_elasticity(matF, matU)
E_list <- list(
  "Original"                = P %*% E_A %*% t(P),
  "Agg lambda + standard"   = r0_elasticity(m1$F, m1$U),
  "Agg lambda + elasticity" = r0_elasticity(m2$F, m2$U),
  "Agg R0 + standard"       = r0_elasticity(m3$F, m3$U),
  "Agg R0 + elasticity"     = r0_elasticity(m4$F, m4$U)
)

# Long-format view of a matrix: one row per nonzero entry with the model
# name, its row/column index, a "row,col" key and the value. Row order is
# the column-major order of which(); callers sort afterwards.
to_long <- function(M, model_name) {
  positions <- which(M != 0, arr.ind = TRUE)
  out <- data.frame(
    model = model_name,
    row = positions[, "row"],
    col = positions[, "col"],
    stringsAsFactors = FALSE
  )
  out$entry <- paste(out$row, out$col, sep = ",")
  out$elasticity <- M[positions]
  out
}

# Stack every model's nonzero elasticities into one long table.
elast_df <- do.call(rbind, Map(to_long, E_list, names(E_list)))

# Keep positive values for log scale
elast_df <- elast_df[elast_df$elasticity > 0, ]

# Row-major ordering: row 1 entries first, then row 2, and so on.
elast_df <- elast_df[order(elast_df$row, elast_df$col), ]

models_order <- names(E_list)
entries <- unique(elast_df$entry)

# Matrix for barplot: rows = models, cols = entries. Each value is placed by
# (model, entry) lookup, so a model lacking an entry gets NA (no bar)
# instead of its later values silently shifting into the wrong columns.
elast_mat2 <- matrix(
  NA_real_,
  nrow = length(models_order),
  ncol = length(entries),
  dimnames = list(models_order, entries)
)
elast_mat2[cbind(
  match(elast_df$model, models_order),
  match(elast_df$entry, entries)
)] <- elast_df$elasticity

## ----elasticity-R0-plot, fig.width=12, fig.height=6---------------------------
# Recover the (row, col) index pair encoded in each "row,col" column name of
# elast_mat2, keeping the column order used by the barplot.
rc <- t(vapply(strsplit(colnames(elast_mat2), ","), identity, character(2)))
r_vec <- as.integer(rc[, 1])
c_vec <- as.integer(rc[, 2])

# Build the plotmath label for aggregated entry (r, c). Row 1 is sexual
# reproduction (F), row 2 clonal reproduction (C); any other row is a
# survival/growth transition (U).
make_entry_label <- function(r, c) {
  label <- if (r == 1) {
    bquote(F[1, .(c)])
  } else if (r == 2) {
    bquote(C[2, .(c)])
  } else {
    bquote(U[.(r), .(c)])
  }
  label
}

# One plotmath label per bar cluster, in the column order of elast_mat2.
# mapply() returns a plain list of calls, which axis() coerces to character
# (literal text such as "F[1, 3]"); as.expression() turns it into an
# expression vector so the labels render with subscripts.
entry_labels <- as.expression(
  mapply(make_entry_label, r_vec, c_vec, SIMPLIFY = FALSE)
)


# Fill colors, one per model (row of elast_mat2), shared with the legend.
cols <- c("#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e")

# Extra bottom margin for the rotated entry labels and extra left margin for
# the y-axis title; the previous settings are kept in `op`.
op <- par(mar = c(7, 6, 4, 2))

# One cluster per matrix entry, one bar per model. space = c(0.2, 1) puts
# 0.2 bar widths between bars of a cluster and 1 bar width between clusters.
bp <- barplot(
  elast_mat2,
  ylab = expression("Normalized elasticity of " * R[0] * " (log scale)"),
  log = "y",
  beside = TRUE,
  space = c(0.2, 1),
  xaxt = "n",
  border = NA,
  col = cols
)

# Each column of `bp` holds one cluster's bar midpoints; center the label.
axis(
  side = 1,
  at = colMeans(bp),
  labels = entry_labels,
  las = 2,
  cex.axis = 0.9
)

legend(
  "topleft",
  legend = expression(
    "Original",
    paste("Agg ", lambda, " + standard"),
    paste("Agg ", lambda, " + elasticity"),
    paste("Agg ", R[0], " + standard"),
    paste("Agg ", R[0], " + elasticity")
  ),
  bty = "n",
  inset = 0.02,
  fill = cols
)

par(op)

## ----lambda-vs.-R0-elasticities-----------------------------------------------
# Compare elasticities of the top 3 vital rates across frameworks.
# The entries are selected by their "row,col" name rather than by column
# position, so the comparison does not depend on how many nonzero entries
# precede them. They are the diagonal entries of the three pooled size-class
# groups. Each result is the mean over models of the summed elasticity of
# these three entries.
top_entries <- c("3,3", "4,4", "5,5")
stopifnot(
  all(top_entries %in% colnames(elast_mat)),
  all(top_entries %in% colnames(elast_mat2))
)
# Top three elasticities in lambda framework
mean(rowSums(elast_mat[, top_entries]))
# Top three elasticities in R0 framework
mean(rowSums(elast_mat2[, top_entries]))

