Skip to content

Commit

Permalink
Merge pull request #27 from stephenslab/eweine/add_subset_option
Browse files Browse the repository at this point in the history
Eweine/add subset option
  • Loading branch information
pcarbo authored Aug 23, 2024
2 parents ed33367 + 49ebe86 commit 840a016
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 5 deletions.
107 changes: 102 additions & 5 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@
#' control argument for \code{\link[daarem]{daarem}}. This setting
#' determines to what extent the monotonicity condition can be
#' violated.}
#'
#' \item{\code{training_frac}}{Fraction of the columns of input data \code{Y}
#' to fit initial model on. If set to \code{1} (default), the model is fit
#' by optimizing the parameters on the entire dataset. If set between \code{0}
#' and \code{1}, the model is optimized by first fitting a model on a randomly
#' selected fraction of the columns of \code{Y}, and then projecting the
#' remaining columns of \code{Y} onto the solution. Setting this to a smaller
#' value will increase speed but decrease accuracy.
#' }
#'
#' \item{\code{num_projection_ccd_iter}}{Number of co-ordinate descent updates
#' be made to elements of \code{V} if and when a subset of \code{Y} is
#' projected onto \code{U}. Only used if \code{training_frac} is less than
#' \code{1}.
#' }
#'
#' \item{\code{num_ccd_iter}}{Number of co-ordinate descent updates to
#' be made to parameters at each iteration of the algorithm.}
Expand Down Expand Up @@ -196,7 +211,7 @@ fit_glmpca_pois <- function(
# Check and process input argument "control".
control <- modifyList(fit_glmpca_pois_control_default(),
control,keep.null = TRUE)

# Set up the internal fit.
D <- sqrt(fit0$d)
if (K == 1)
Expand All @@ -205,7 +220,7 @@ fit_glmpca_pois <- function(
D <- diag(D)
LL <- t(cbind(fit0$U %*% D,fit0$X,fit0$W))
FF <- t(cbind(fit0$V %*% D,fit0$B,fit0$Z))

# Determine which rows of LL and FF are "clamped".
fixed_l <- numeric(0)
fixed_f <- numeric(0)
Expand All @@ -217,9 +232,86 @@ fit_glmpca_pois <- function(
fixed_f <- c(fixed_f,K + fit0$fixed_b_cols)
if (nz > 0)
fixed_f <- c(fixed_f,K + nx + seq(1,nz))

# Perform the updates.
res <- fit_glmpca_pois_main_loop(LL,FF,Y,fixed_l,fixed_f,verbose,control)

if (control$training_frac == 1) {

# Perform the updates.
res <- fit_glmpca_pois_main_loop(LL,FF,Y,fixed_l,fixed_f,verbose,control)

} else {

if (control$training_frac <= 0 || control$training_frac > 1)
stop("control argument \"training_frac\" should be between 0 and 1")

train_idx <- sample(
1:ncol(Y),
size = ceiling(ncol(Y) * control$training_frac)
)

browser()
Y_train <- Y[, train_idx]

if (any(Matrix::rowSums(Y_train) == 0) || any(Matrix::colSums(Y_train) == 0)) {

stop(
"After subsetting, the remaining values of \"Y\" ",
"contain a row or a column where all counts are 0. This can cause ",
"problems with optimization. Please either remove rows / columns ",
"with few non-zero counts from \"Y\", or set \"training_frac\" to ",
"a larger value."
)

}

FF_train <- FF[, train_idx]
FF_test <- FF[, -train_idx]
Y_test <- Y[, -train_idx]

test_idx <- 1:ncol(Y)
test_idx <- test_idx[-train_idx]

# Perform the updates.
res <- fit_glmpca_pois_main_loop(
LL,
FF_train,
Y_train,
fixed_l,
fixed_f,
verbose,
control
)

update_indices_f <- sort(setdiff(1:K,fixed_f))

# now, I just need to project the results back
update_factors_faster_parallel(
L_T = t(res$fit$LL),
FF = FF_test,
M = as.matrix(res$fit$LL[update_indices_f,,drop = FALSE] %*% Y_test),
update_indices = update_indices_f - 1,
num_iter = control$num_projection_ccd_iter,
line_search = control$line_search,
alpha = control$ls_alpha,
beta = control$ls_beta
)

# now, I need to reconstruct FF, and hopefully compute the log-likelihood
FF[, train_idx] <- res$fit$FF
FF[, test_idx] <- FF_test
res$fit$FF <- FF

if (inherits(Y,"sparseMatrix")) {
test_loglik_const <- sum(mapSparse(Y_test,lfactorial))
loglik_func <- lik_glmpca_pois_log_sp
} else {
test_loglik_const <- sum(lfactorial(Y_test))
loglik_func <- lik_glmpca_pois_log
}

test_loglik <- loglik_func(Y_test,res$fit$LL,FF_test,test_loglik_const)
res$loglik <- res$loglik + test_loglik

}

# Prepare the final output.
res$progress$iter <- max(fit0$progress$iter) + res$progress$iter
Expand Down Expand Up @@ -258,9 +350,12 @@ fit_glmpca_pois <- function(
dimnames(fit$W) <- dimnames(fit0$W)
}
class(fit) <- c("glmpca_pois_fit","list")

return(fit)

}


# This implements the core part of fit_glmpca_pois.
#
#' @importFrom Matrix t
Expand Down Expand Up @@ -358,6 +453,8 @@ fit_glmpca_pois_control_default <- function()
list(use_daarem = FALSE,
maxiter = 100,
tol = 1e-4,
training_frac = 1,
num_projection_ccd_iter = 10,
mon.tol = 0.05,
convtype = "objfn",
line_search = TRUE,
Expand Down
Binary file added inst/.DS_Store
Binary file not shown.
48 changes: 48 additions & 0 deletions inst/scratch/test_projection_method.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
library(fastglmpca)

set.seed(1)
cc <- pbmc_facs$counts[Matrix::rowSums(pbmc_facs$counts) > 10, ]

fit1 <- fit_glmpca_pois(
Y = cc,
K = 2,
control = list(training_frac = 0.99, maxiter = 10)
)

# for some reason the calculated log-likelihood and the expected
# are not matching up
set.seed(1)
fit2 <- fit_glmpca_pois(
Y = pbmc_facs$counts,
K = 2,
control = list(training_frac = 0.25, maxiter = 10, num_projection_ccd_iter = 25)
)

set.seed(1)
fit3 <- fit_glmpca_pois(
Y = pbmc_facs$counts,
K = 2,
control = list(training_frac = 0.25, maxiter = 10, num_projection_ccd_iter = 5)
)
#
# df1 <- data.frame(
# celltype = pbmc_facs$samples$celltype,
# PC1 = fit1$V[,1],
# PC2 = fit1$V[,2]
# )
#
# library(ggplot2)
#
# ggplot(data = df1) +
# geom_point(aes(x = PC1, y = PC2, color = celltype))
#
# df2 <- data.frame(
# celltype = pbmc_facs$samples$celltype,
# PC1 = fit2$V[,1],
# PC2 = fit2$V[,2]
# )
#
# library(ggplot2)
#
# ggplot(data = df2) +
# geom_point(aes(x = PC1, y = PC2, color = celltype))
15 changes: 15 additions & 0 deletions man/fit_glmpca_pois.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 840a016

Please sign in to comment.