## ----message=FALSE, warning=FALSE------------------------------------------
library(SIAMCAT)
# Resolve the example input files that ship in the "extdata" folder of the
# installed SIAMCAT package: one feature table, one label file and one
# numeric metadata table. Each path is built relative to the package root.
fn.in.feat <- system.file(
    file.path(
        "extdata",
        "feat_crc_study-pop-I_N141_tax_profile_mocat_bn_specI_clusters.tsv"
    ),
    package = "SIAMCAT"
)
fn.in.label <- system.file(
    file.path(
        "extdata",
        "label_crc_study-pop-I_N141_tax_profile_mocat_bn_specI_clusters.tsv"
    ),
    package = "SIAMCAT"
)
fn.in.meta <- system.file(
    file.path(
        "extdata",
        "num_metadata_crc_study-pop-I_N141_tax_profile_mocat_bn_specI_clusters.tsv"
    ),
    package = "SIAMCAT"
)

## ----results="hide"--------------------------------------------------------
# Read the three input tables from the paths resolved above and combine them
# into a single object with the siamcat() constructor.
feat  <- read.features(fn.in.feat)
label <- read.labels(fn.in.label)
meta  <- read.meta(fn.in.meta)
# NOTE: from here on the variable `siamcat` shadows the constructor of the
# same name; every later step reassigns this one object.
siamcat <- siamcat(feat, label, meta)

## --------------------------------------------------------------------------
# Print the summary of the combined object.
show(siamcat)

## --------------------------------------------------------------------------
# Extract a component through the physeq() accessor and print it as well
# (presumably the phyloseq object wrapped by `siamcat` -- see ?physeq).
phyloseq <- physeq(siamcat)
show(phyloseq)

## --------------------------------------------------------------------------
# Pass the object through validate.data() and keep the returned object for
# all downstream steps.
siamcat <- validate.data(siamcat, verbose=1)

## --------------------------------------------------------------------------
# Sample selection driven by the 'age' metadata column. Only a numeric range
# is supplied (allowed.set is left NULL); presumably samples with age outside
# 20..90 are dropped -- see ?select.samples for the exact rule.
siamcat <- select.samples(
    siamcat,
    filter = 'age',
    allowed.set = NULL,
    allowed.range = c(20, 90),
    verbose = 2
)

## --------------------------------------------------------------------------
# Feature filtering with the 'abundance' method and a cutoff of 0.001.
# recomp.prop = FALSE and rm.unmapped = TRUE are passed through unchanged;
# their exact semantics are defined by filter.features() (see its help page).
siamcat <- filter.features(
    siamcat,
    filter.method = 'abundance',
    cutoff = 0.001,
    recomp.prop = FALSE,
    rm.unmapped = TRUE,
    verbose = 2
)

## ----eval=FALSE------------------------------------------------------------
#  ## Not run here, since the function produces a pdf-file as output
#  check.confounders(siamcat,
#      fn.plot = 'conf_check.pdf')

## ----fig.height = 6, fig.width = 6, fig.align="center", echo=FALSE---------
# Confounder check for the first metadata column, drawn as a 2 x 2 figure:
#   panel 1: Q-Q plot of the variable in the n.idx group vs. the p.idx group
#            with a Wilcoxon rank-sum p-value,
#   panel 2: histogram for the n.idx group,
#   panel 3: box plot + jittered strip chart for both groups,
#   panel 4: histogram for the p.idx group.
# The margins changed for panels 3 and 4 are restored at the end so that
# figures drawn later in the same session are not affected.
old.mar <- par("mar")

label      <- label(siamcat)
case.count <- length(label$label[label$p.idx])
ctrl.count <- length(label$label[label$n.idx])
# Put the larger group first; the smaller one is padded with NA further down
# so both fit into one data frame for the strip chart.
if (case.count > ctrl.count) {
    lgr <- label$p.idx
    smlr <- label$n.idx
    bp.labs <- c(label$p.lab, label$n.lab)
} else {
    lgr <- label$n.idx
    smlr <- label$p.idx
    bp.labs <- c(label$n.lab, label$p.lab)
}
len.diff <- abs(case.count - ctrl.count)
hmap <- data.frame()
m <- 1  # index of the metadata column that is inspected
phyloseq <- physeq(siamcat)
sam_data <- sample_data(phyloseq)
# Human-readable variable name: separators become blanks, first letter is
# capitalised.
mname <- gsub('[_.-]', ' ', colnames(sam_data)[m])
mname <- paste0(toupper(substring(mname, 1, 1)), substring(mname, 2))

mvar  <- as.numeric(unlist(sam_data[, m]))
u.val <- unique(mvar)
u.val <- u.val[!is.na(u.val)]
colors <- RColorBrewer::brewer.pal(5, "Spectral")
histcolors <- RColorBrewer::brewer.pal(9, "YlGnBu")

# Values per group and summary statistics that are reused by several panels.
mvar.ctrl <- mvar[label$n.idx]
mvar.case <- mvar[label$p.idx]
mvar.med  <- median(mvar, na.rm = TRUE)
ax.int <- c(min(mvar, na.rm = TRUE), max(mvar, na.rm = TRUE))
hist.breaks <- seq(ax.int[1], ax.int[2], length.out = 10)

# 2 x 2 contingency table: rows = at-or-below / above the overall median,
# columns = n.idx group / p.idx group.
dct <- rbind(
    c(sum(mvar.ctrl <= mvar.med, na.rm = TRUE),
        sum(mvar.case <= mvar.med, na.rm = TRUE)),
    c(sum(mvar.ctrl > mvar.med, na.rm = TRUE),
        sum(mvar.case > mvar.med, na.rm = TRUE))
)
rownames(dct) <- c(paste(mname, "<= med"), paste(mname, "> med"))
hmap <- rbind(hmap, dct)
layout(rbind(c(1, 2), c(3, 4)))

# Panel 1: Q-Q plot (drawn with the margins that were active on entry).
qqplot(
    mvar.ctrl,
    mvar.case,
    xlim = ax.int,
    ylim = ax.int,
    pch = 16,
    cex = 0.6,
    xlab = label$n.lab,
    ylab = label$p.lab,
    main = paste('Q-Q plot for', mname)
)
abline(0, 1, lty = 3)
p.val  <- wilcox.test(mvar.ctrl, mvar.case, exact = FALSE)$p.value
# Annotation is placed in the lower right corner of the plotting region.
text(
    ax.int[1] + 0.9 * (ax.int[2] - ax.int[1]),
    ax.int[1] + 0.1 * (ax.int[2] - ax.int[1]),
    cex = 0.8,
    paste('MWW test p-value:', format(p.val, digits = 4)),
    pos = 2
)

# Panel 2: histogram for the n.idx group (same breaks as panel 4).
hist(
    mvar.ctrl,
    main = label$n.lab,
    xlab = mname,
    col = histcolors,
    breaks = hist.breaks
)
mtext(
    paste('N =', length(mvar.ctrl)),
    cex = 0.6,
    side = 3,
    adj = 1,
    line = 1
)

# Panel 3: box plot with the individual observations overlaid.
par(mar = c(2.5, 4.5, 2.5, 1.5))
combine <- data.frame(mvar[lgr], c(mvar[smlr], rep(NA, len.diff)))
boxplot(
    combine[, 1],
    na.omit(combine[, 2]),
    use.cols = TRUE,
    names = bp.labs,
    ylab = mname,
    main = paste('Boxplot for', mname),
    col = histcolors
)
stripchart(
    combine,
    vertical = TRUE,
    add = TRUE,
    method = "jitter",
    pch = 20
)

# Panel 4: histogram for the p.idx group.
par(mar = c(4.5, 2.5, 3.5, 1.5))
hist(
    mvar.case,
    main = label$p.lab,
    xlab = mname,
    col = histcolors,
    breaks = hist.breaks
)
mtext(
    paste('N =', length(mvar.case)),
    cex = 0.6,
    side = 3,
    adj = 1,
    line = 1
)
# Back to a single-panel layout and to the margins that were active on entry.
par(mfrow = c(1, 1), mar = old.mar)

## ----eval=FALSE------------------------------------------------------------
#  ## Not run here, since the function produces a pdf-file as output
#  check.associations(
#      siamcat,
#      sort.by = 'fc',
#      fn.plot = 'assoc.pdf',
#      alpha = 0.05,
#      mult.corr = "fdr",
#      detect.lim = 10 ^ -6,
#      max.show = 50,
#      plot.type = "quantile.box",
#      panels = c("fc", "prevalence", "auroc"),
#      verbose = 2
#  )

## --------------------------------------------------------------------------
# Feature normalization with the "log.unit" method. The norm.param list
# carries the method's settings (log.n0, n.p, norm.margin); their meaning is
# defined by normalize.features() -- see its help page.
siamcat <- normalize.features(
    siamcat,
    norm.method = "log.unit",
    norm.param = list(
        log.n0 = 1e-06,
        n.p = 2,
        norm.margin = 1
    ),
    verbose = 2
)

## --------------------------------------------------------------------------
# Set up the data split used for training: 5 folds, 2 resampling rounds,
# stratified, with no `inseparable` grouping variable supplied.
siamcat <-  create.data.split(
    siamcat,
    num.folds = 5,
    num.resample = 2,
    stratify = TRUE,
    inseparable = NULL,
    verbose = 2
)

## --------------------------------------------------------------------------
# Train models with the "lasso" method on the split created above.
# modsel.crit = list("pr") and min.nonzero.coeff = 5 steer model selection
# (presumably: precision-recall based criterion, at least five non-zero
# coefficients -- confirm in ?train.model); param.set is left at NULL.
siamcat <- train.model(
    siamcat,
    method = "lasso",
    stratify = TRUE,
    modsel.crit = list("pr"),
    min.nonzero.coeff = 5,
    param.set = NULL,
    verbose = 3
)

## --------------------------------------------------------------------------
# Accessor: fetch the model_list component of the trained object. The result
# is stored under the same name as the accessor function.
model_list <- model_list(siamcat)

## --------------------------------------------------------------------------
# Accessor: query the model type recorded in the object.
model_type(siamcat)

## --------------------------------------------------------------------------
# Accessor: fetch the trained models and display the first one.
models <- models(siamcat)
models[[1]]

## --------------------------------------------------------------------------
# Apply the trained models, then pull out the prediction matrix; the
# `pred_matrix` variable is reused by the heatmap code further down.
siamcat <- make.predictions(siamcat, verbose=0)
pred_matrix <- pred_matrix(siamcat)
head(pred_matrix)

## --------------------------------------------------------------------------
# Evaluate the predictions; the results are stored in the returned object and
# read back later through eval_data().
siamcat <-  evaluate.predictions(siamcat, verbose=2)

## ----eval=FALSE------------------------------------------------------------
#  ## Not run here, since the function produces a pdf-file as output
#  model.evaluation.plot(siamcat,'eval_plot.pdf',verbose = 2)

## ----fig.width = 6, fig.asp=1, fig.align="center"--------------------------
# Plot ROC curves: one thin grey curve per entry of roc.all, the averaged
# curve in black, and a shaded band taken from the averaged curve's `ci`
# matrix.

# Empty unit-square canvas with the diagonal as reference line.
plot(
    NULL,
    xlim = c(0, 1),
    ylim = c(0, 1),
    xlab = 'False positive rate',
    ylab = 'True positive rate',
    type = 'n'
)
title('ROC curve for the model')
abline(a = 0, b = 1, lty = 3)

# One curve per element of roc.all, each in a randomly chosen grey shade.
# seq_along() makes the loop a no-op when roc.all is empty.
eval_data <- eval_data(siamcat)
for (r in seq_along(eval_data$roc.all)) {
    roc.c <- eval_data$roc.all[[r]]
    lines(1 - roc.c$specificities, roc.c$sensitivities,
        col = gray(runif(1, 0.2, 0.8)))
}

# Averaged ROC curve.
roc.summ <- eval_data$roc.average[[1]]
lines(1 - roc.summ$specificities,
    roc.summ$sensitivities,
    col = 'black',
    lwd = 2)

# Shaded band: the row names of `ci` are parsed as numbers and plotted at
# 1 - value on the x-axis; columns 1 and 3 supply the two edges of the band
# (presumably lower and upper confidence bounds -- confirm against the
# structure of roc.average).
x <- as.numeric(rownames(roc.summ$ci))
yl <- roc.summ$ci[, 1]
yu <- roc.summ$ci[, 3]
polygon(1 - c(x, rev(x)), c(yl, rev(yu)), col = '#88888844', border = NA)

## ----eval=FALSE------------------------------------------------------------
#  ## Not run here, since the function produces a pdf-file as output
#  model.interpretation.plot(
#      siamcat,
#      fn.plot = 'interpretation.pdf',
#      consens.thres = 0.5,
#      norm.models = TRUE,
#      limits = c(-3, 3),
#      heatmap.type = 'zscore',
#      verbose = 2
#  )

## ----fig.width=7, fig.height=5, fig.align="center", echo=FALSE-------------


# Model interpretation heatmap: a narrow header strip with one bracket per
# class (with sample counts) on top of the feature heatmap drawn by SIAMCAT's
# internal plotting helper. Uses `label` and `pred_matrix` created earlier in
# the script.

# 100-step diverging palette interpolated from the reversed 'BrBG' colours.
color.scheme <- rev(colorRampPalette(
    RColorBrewer::brewer.pal(RColorBrewer::brewer.pal.info['BrBG', 'maxcolors'],
        'BrBG')
)(100))

# Model weights (one column per model), restricted to the feature rows plus
# any rows whose name contains 'META'. drop = FALSE keeps the matrix shape
# even if only a single row or column remains.
W.mat       <-
    SIAMCAT:::get.weights.matrix(models(siamcat), verbose = 0)
feat        <- get.features.matrix(siamcat)
all.weights <-
    W.mat[union(row.names(feat), grep('META', row.names(W.mat),
        value = TRUE)), , drop = FALSE]
# Per-model relative weights: each column is scaled by its absolute sum.
rel.weights <- apply(all.weights, 2, function(x) {
    x / sum(abs(x))
})

sel.idx <-
    SIAMCAT:::model.interpretation.select.features(
        weights = all.weights,
        model.type = model_type(siamcat),
        consens.thres = 0.5,
        label = label(siamcat),
        norm.models = TRUE,
        max.show = 50,
        verbose = 0
    )

# Column order: samples sorted by label value plus mean prediction
# (presumably this groups the classes and orders samples within each class by
# prediction score -- depends on the label coding).
mean.agg.pred <- rowMeans(pred_matrix)
srt.idx <- sort(label$label + mean.agg.pred, index.return = TRUE)$ix

# drop = FALSE: a single selected feature must still be passed as a matrix.
img.data <- SIAMCAT:::model.interpretation.prepare.heatmap.zscore(
    heatmap.data = feat[sel.idx, srt.idx, drop = FALSE],
    limits = c(-3, 3),
    verbose = 0)


# Header strip (10% of the height) above the heatmap (90%).
old.mar <- par("mar")  # restored at the end of this chunk
layout(c(1, 2), heights = c(0.1, 0.9))
par(mar = c(0, 2, 1, 10))
hm.label <- label$label[srt.idx]
plot(
    NULL,
    type = 'n',
    xlim = c(0, length(hm.label)),
    xaxs = 'i',
    xaxt = 'n',
    ylim = c(-0.5, 0.5),
    yaxs = 'i',
    yaxt = 'n',
    xlab = '',
    ylab = ''
    ,
    bty = 'n'
)
# One bracket per class spanning from its first to its last sample position,
# annotated with the class name and the number of samples.
ul <- unique(hm.label)
for (l in seq_along(ul)) {
    idx <- which(ul[l] == hm.label)
    first.pos <- idx[1]
    last.pos <- idx[length(idx)]
    lines(c(first.pos - 0.8, last.pos - 0.2), c(0, 0))
    lines(c(first.pos - 0.8, first.pos - 0.8), c(-0.2, 0))
    lines(c(last.pos - 0.2, last.pos - 0.2), c(-0.2, 0))
    bracket.mid <- (first.pos + last.pos) / 2
    class.name <- gsub('_', ' ',
        names(label$info$class.descr)[label$info$class.descr == ul[l]])
    bracket.txt <- paste0(class.name, ' (n=', length(idx), ')')
    mtext(
        bracket.txt,
        side = 3,
        line = -0.5,
        at = bracket.mid,
        cex = 0.7,
        adj = 0.5
    )
}
mtext(
    'Metagenomic Features',
    side = 3,
    line = 2,
    at = length(hm.label) / 2,
    cex = 1,
    adj = 0.5
)


# Heatmap panel; the effect size per selected feature is the median of its
# relative weight across models.
SIAMCAT:::model.interpretation.heatmap.plot(
    image.data = img.data,
    limits = c(-3, 3),
    color.scheme = color.scheme,
    effect.size = apply(rel.weights[sel.idx, , drop = FALSE],
        1, median),
    verbose = 0
)

# Leave the device with a single-panel layout and the margins found on entry.
par(mfrow = c(1, 1), mar = old.mar)

## --------------------------------------------------------------------------
# Record the R version and the attached/loaded package versions used for
# this run.
sessionInfo()