## ----include = FALSE----------------------------------------------------------
# Global chunk options: merge source and output into a single block and
# prefix every output line with "#>".
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

## ----setup--------------------------------------------------------------------
library(tempodisco)

## -----------------------------------------------------------------------------
# Load the single-participant binary-choice data and fit an exponential
# discounting model to it.
data("td_bc_single_ptpt")
mod <- td_bcnm(td_bc_single_ptpt, discount_function = "exponential")
# Plot the discount curve.
# `FALSE` is spelled out because `F` is an ordinary variable that can be
# reassigned, whereas `FALSE` is a reserved word.
plot(mod, type = "summary", verbose = FALSE)

## -----------------------------------------------------------------------------
# Plot the choice rule for a delay of 50 and a delayed amount of 200.
# `FALSE` is used instead of the reassignable alias `F`.
plot(mod, type = "endpoints", del = 50, val_del = 200, verbose = FALSE)

## -----------------------------------------------------------------------------
# Compare the three choice rules at a single delay: overlay each model's
# predicted probability curve on the observed choices and report its
# log-likelihood in the legend.
vis_del <- sort(unique(td_bc_single_ptpt$del))[2]  # second-smallest delay
vis_val_del <- 200                                 # delayed amount to plot
newdata <- data.frame(del = vis_del,
                      val_del = vis_val_del,
                      val_imm = seq(0, vis_val_del, length.out = 1000))
plot(imm_chosen ~ val_imm, data = subset(td_bc_single_ptpt, del == vis_del),
     xlim = c(0, vis_val_del), ylim = c(0, 1))
# Names are the line colours; values are the choice rules to fit.
plot_legend <- c("red" = "logistic",
                 "blue" = "probit",
                 "forestgreen" = "power")
# Preallocate one named slot per choice rule instead of growing the vector.
logLiks <- setNames(numeric(length(plot_legend)), names(plot_legend))
for (entry in names(plot_legend)) {
  # `[[` extracts the bare string, without the colour name attached.
  choice_rule <- plot_legend[[entry]]
  mod <- td_bcnm(td_bc_single_ptpt,
                 discount_function = "exponential",
                 choice_rule = choice_rule)
  # Store the plain numeric value rather than the classed logLik object.
  logLiks[[entry]] <- as.numeric(logLik(mod))
  preds <- predict(mod, type = "response", newdata = newdata)
  lines(preds ~ newdata$val_imm, col = entry)
}
legend(0, 1,
       fill = names(plot_legend),
       legend = paste0(plot_legend, "; log lik. = ", round(logLiks, 1)))

## -----------------------------------------------------------------------------
# Repeat the choice-rule comparison with `fixed_ends = TRUE`: overlay each
# model's predicted probability curve on the observed choices and report its
# log-likelihood in the legend.
vis_del <- sort(unique(td_bc_single_ptpt$del))[2]  # second-smallest delay
vis_val_del <- 200                                 # delayed amount to plot
newdata <- data.frame(del = vis_del,
                      val_del = vis_val_del,
                      val_imm = seq(0, vis_val_del, length.out = 1000))
plot(imm_chosen ~ val_imm, data = subset(td_bc_single_ptpt, del == vis_del),
     xlim = c(0, vis_val_del), ylim = c(0, 1))
# Names are the line colours; values are the choice rules to fit.
plot_legend <- c("red" = "logistic",
                 "blue" = "probit",
                 "forestgreen" = "power")
# Preallocate one named slot per choice rule instead of growing the vector.
logLiks <- setNames(numeric(length(plot_legend)), names(plot_legend))
for (entry in names(plot_legend)) {
  # `[[` extracts the bare string, without the colour name attached.
  choice_rule <- plot_legend[[entry]]
  mod <- td_bcnm(td_bc_single_ptpt,
                 discount_function = "exponential",
                 fixed_ends = TRUE,                 # Fixed endpoints
                 choice_rule = choice_rule)
  # Store the plain numeric value rather than the classed logLik object.
  logLiks[[entry]] <- as.numeric(logLik(mod))
  preds <- predict(mod, type = "response", newdata = newdata)
  lines(preds ~ newdata$val_imm, col = entry)
}
legend(0, 1,
       fill = names(plot_legend),
       legend = paste0(plot_legend, "; log lik. = ", round(logLiks, 1)))

## -----------------------------------------------------------------------------
# Same fixed-endpoint comparison, but predicting for a delayed amount of 1000.
# No observed choices are drawn here; only the three predicted curves.
vis_val_del <- 1000
vis_del <- sort(unique(td_bc_single_ptpt$del))[2]  # second-smallest delay
newdata <- data.frame(
  del = vis_del,
  val_del = vis_val_del,
  val_imm = seq(0, vis_val_del, length.out = 1000)
)
# Empty canvas spanning [0, vis_val_del] x [0, 1]; curves are added below.
plot(
  c(0, 1) ~ c(0, vis_val_del),
  type = "n",
  xlab = "val_imm",
  ylab = "prob_imm"
)
# Names are the line colours; values are the choice rules to fit.
plot_legend <- c("red" = "logistic", "blue" = "probit", "forestgreen" = "power")
for (entry in names(plot_legend)) {
  mod <- td_bcnm(
    td_bc_single_ptpt,
    discount_function = "exponential",
    fixed_ends = TRUE,
    choice_rule = plot_legend[entry]
  )
  preds <- predict(mod, type = "response", newdata = newdata)
  lines(newdata$val_imm, preds, col = entry)
}

## -----------------------------------------------------------------------------
# Fit the hyperbolic model twice, once with free and once with fixed
# endpoints, then print the log-likelihood of each fit.
mod_free <- td_bcnm(
  td_bc_single_ptpt,
  discount_function = "hyperbolic",
  fixed_ends = FALSE
)
mod_fixed <- td_bcnm(
  td_bc_single_ptpt,
  discount_function = "hyperbolic",
  fixed_ends = TRUE
)
cat(sprintf("Log lik. with free endpoints: %.2f\n", logLik(mod_free)))
cat(sprintf("Log lik. with fixed endpoints: %.2f\n", logLik(mod_fixed)))

## -----------------------------------------------------------------------------
# Custom choice rule: Cauchy noise, linear gamma scaling, identity transform.
# NOTE(review): presumably equivalent to a cauchit link with free endpoints;
# confirm against the td_bcnm documentation.
mod_cauchit <- td_bcnm(
  td_bc_single_ptpt,
  discount_function = "exponential",
  noise_dist = "cauchy",
  gamma_scale = "linear",
  transform = "identity"
)
# Show the resulting choice rule for a delay of 50 and a delayed amount of 200.
plot(
  mod_cauchit,
  type = "endpoints",
  del = 50,
  val_del = 200
)

## -----------------------------------------------------------------------------
# Same Cauchy-noise model, but with the "noise_dist_quantile" transform in
# place of the identity transform.
# NOTE(review): the variable name suggests this yields fixed endpoints;
# confirm against the td_bcnm documentation.
mod_cauchit_fixed <- td_bcnm(
  td_bc_single_ptpt,
  discount_function = "exponential",
  noise_dist = "cauchy",
  gamma_scale = "linear",
  transform = "noise_dist_quantile"
)
# Show the resulting choice rule for a delay of 50 and a delayed amount of 200.
plot(
  mod_cauchit_fixed,
  type = "endpoints",
  del = 50,
  val_del = 200
)

## -----------------------------------------------------------------------------
# Custom choice rule: normal noise, no gamma scaling, log transform.
mod_lognormal <- td_bcnm(
  td_bc_single_ptpt,
  discount_function = "exponential",
  noise_dist = "norm",
  gamma_scale = "none",
  transform = "log"
)
# Show the resulting choice rule for a delay of 50 and a delayed amount of 200.
plot(
  mod_lognormal,
  type = "endpoints",
  del = 50,
  val_del = 200
)

