Skip to content

Commit

Permalink
Fix #377
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed Feb 4, 2025
1 parent e2a2107 commit 7aa93cb
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 25 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
- Argument `fixed` has been removed, as you can fix predictor at certain values
using the `by` argument.

- Argument `transform` is deprecated. Please use `predict` instead.
- Argument `transform` is no longer used to determin the scale of the predictions.
Please use `predict` instead.

- Argument `transform` is now used to (back-) transform predictions and confidence
intervals.

- Argument `method` in `estimate_contrasts()` was renamed into `comparison`.

Expand Down
7 changes: 1 addition & 6 deletions R/estimate_contrasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,6 @@ estimate_contrasts.default <- function(model,
transform = NULL,
verbose = TRUE,
...) {
## TODO: remove deprecation warning later
if (!is.null(transform)) {
insight::format_warning("Argument `transform` is deprecated. Please use `predict` instead.")
predict <- transform
}

if (backend == "emmeans") {
# Emmeans ------------------------------------------------------------------
estimated <- get_emcontrasts(model,
Expand All @@ -139,6 +133,7 @@ estimate_contrasts.default <- function(model,
p_adjust = p_adjust,
ci = ci,
estimate = estimate,
transform = transform,
verbose = verbose,
...
)
Expand Down
33 changes: 24 additions & 9 deletions R/estimate_means.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@
#' `options(modelbased_backend = "emmeans")` to use the **emmeans** package or
#' `options(modelbased_backend = "marginaleffects")` to set **marginaleffects**
#' as default backend.
#' @param transform Deprecated, please use `predict` instead.
#' @param transform A function applied to predictions and confidence intervals
#' to (back-) transform results, which can be useful in case the regression
#' model has a transformed response variable (e.g., `lm(log(y) ~ x)`). For
#' Bayesian models, this function is applied to individual draws from the
#' posterior distribution, before computing summaries. Can also be `TRUE`, in
#' which case `insight::get_transformation()` is called to determine the
#' appropriate transformation-function.
#' @param verbose Use `FALSE` to silence messages and warnings.
#' @param ... Other arguments passed, for instance, to [insight::get_datagrid()],
#' to functions from the **emmeans** or **marginaleffects** package, or to process
Expand Down Expand Up @@ -179,12 +185,6 @@ estimate_means <- function(model,
transform = NULL,
verbose = TRUE,
...) {
## TODO: remove deprecation warning later
if (!is.null(transform)) {
insight::format_warning("Argument `transform` is deprecated. Please use `predict` instead.")
predict <- transform
}

# validate input
estimate <- insight::validate_argument(
estimate,
Expand All @@ -193,11 +193,26 @@ estimate_means <- function(model,

if (backend == "emmeans") {
# Emmeans ------------------------------------------------------------------
estimated <- get_emmeans(model, by = by, predict = predict, verbose = verbose, ...)
estimated <- get_emmeans(
model,
by = by,
predict = predict,
verbose = verbose,
...
)
means <- .format_emmeans_means(estimated, model, ci = ci, verbose = verbose, ...)
} else {
# Marginalmeans ------------------------------------------------------------
estimated <- get_marginalmeans(model, by = by, predict = predict, ci = ci, estimate = estimate, verbose = verbose, ...) # nolint
estimated <- get_marginalmeans(
model,
by = by,
predict = predict,
ci = ci,
estimate = estimate,
transform = transform,
verbose = verbose,
...
)
means <- format(estimated, model, ...)
}

Expand Down
25 changes: 25 additions & 0 deletions R/estimate_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@
#' you are directly predicting the value of some distributional parameter), and
#' the corresponding functions will then only differ in the default value of
#' their `data` argument.
#' @param transform A function applied to predictions and confidence intervals
#' to (back-) transform results, which can be useful in case the regression
#' model has a transformed response variable (e.g., `lm(log(y) ~ x)`). Can also
#' be `TRUE`, in which case `insight::get_transformation()` is called to
#' determine the appropriate transformation-function. **Note:** Standard errors
#' are not (back-) transformed!
#' @param ... You can add all the additional control arguments from
#' [insight::get_datagrid()] (used when `data = "grid"`) and
#' [insight::get_predicted()].
Expand Down Expand Up @@ -228,6 +234,7 @@ estimate_expectation <- function(model,
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
keep_iterations = FALSE,
...) {
.estimate_predicted(
Expand All @@ -237,6 +244,7 @@ estimate_expectation <- function(model,
ci = ci,
keep_iterations = keep_iterations,
predict = predict,
transform = transform,
...
)
}
Expand All @@ -249,6 +257,7 @@ estimate_link <- function(model,
by = NULL,
predict = "link",
ci = 0.95,
transform = NULL,
keep_iterations = FALSE,
...) {
# reset to NULL if only "by" was specified
Expand All @@ -263,6 +272,7 @@ estimate_link <- function(model,
ci = ci,
keep_iterations = keep_iterations,
predict = predict,
transform = transform,
...
)
}
Expand All @@ -274,6 +284,7 @@ estimate_prediction <- function(model,
by = NULL,
predict = "prediction",
ci = 0.95,
transform = NULL,
keep_iterations = FALSE,
...) {
.estimate_predicted(
Expand All @@ -283,6 +294,7 @@ estimate_prediction <- function(model,
ci = ci,
keep_iterations = keep_iterations,
predict = predict,
transform = transform,
...
)
}
Expand All @@ -294,6 +306,7 @@ estimate_relation <- function(model,
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
keep_iterations = FALSE,
...) {
# reset to NULL if only "by" was specified
Expand All @@ -308,6 +321,7 @@ estimate_relation <- function(model,
ci = ci,
keep_iterations = keep_iterations,
predict = predict,
transform = transform,
...
)
}
Expand All @@ -321,6 +335,7 @@ estimate_relation <- function(model,
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
keep_iterations = FALSE,
...) {
# only "by" or "data", but not both
Expand Down Expand Up @@ -445,6 +460,16 @@ estimate_relation <- function(model,
out$Residuals <- response - out$Predicted
}

# transform reponse?
if (isTRUE(transform)) {
transform <- insight::get_transformation(model, verbose = FALSE)$inverse
}
if (!is.null(transform)) {
out$Predicted <- transform(out$Predicted)
out$CI_low <- transform(out$CI_low)
out$CI_high <- transform(out$CI_high)
}

# Store relevant information
attr(out, "ci") <- ci
attr(out, "keep_iterations") <- keep_iterations
Expand Down
2 changes: 2 additions & 0 deletions R/get_marginalcontrasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ get_marginalcontrasts <- function(model,
estimate = "average",
ci = 0.95,
p_adjust = "none",
transform = NULL,
verbose = TRUE,
...) {
# check if available
Expand Down Expand Up @@ -78,6 +79,7 @@ get_marginalcontrasts <- function(model,
predict = predict,
backend = "marginaleffects",
estimate = estimate,
transform = transform,
verbose = verbose,
...
)
Expand Down
14 changes: 8 additions & 6 deletions R/get_marginalmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ get_marginalmeans <- function(model,
# check if available
insight::check_if_installed("marginaleffects")

## TODO: remove deprecation warning later
if (!is.null(transform)) {
insight::format_warning("Argument `transform` is deprecated. Please use `predict` instead.")
predict <- transform
}

# First step: process arguments --------------------------------------------
# --------------------------------------------------------------------------

Expand Down Expand Up @@ -164,6 +158,14 @@ get_marginalmeans <- function(model,
fun_args$re.form <- NULL
}

# transform reponse?
if (isTRUE(transform)) {
transform <- insight::get_transformation(model, verbose = FALSE)$inverse
}
if (!is.null(transform)) {
fun_args$transform <- transform
}


# Fourth step: compute marginal means ---------------------------------------
# ---------------------------------------------------------------------------
Expand Down
8 changes: 7 additions & 1 deletion man/estimate_contrasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions man/estimate_expectation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/estimate_means.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion man/get_emmeans.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions tests/testthat/test-transform_response.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
test_that("estimate_means, transform", {
data(cars)
m <- lm(log(dist) ~ speed, data = cars)
out <- estimate_means(m, "speed")
expect_equal(
out$Mean,
c(
2.15918, 2.44097, 2.72276, 3.00454, 3.28633, 3.56811, 3.8499,
4.13168, 4.41347, 4.69525
),
tolerance = 1e-4
)
out1 <- estimate_means(m, "speed", transform = TRUE)
expect_equal(
out1$Mean,
c(
8.66407, 11.48417, 15.2222, 20.17694, 26.74442, 35.44958, 46.98822,
62.28261, 82.55525, 109.42651
),
tolerance = 1e-4
)
out2 <- estimate_means(m, "speed", transform = exp)
expect_equal(out1$Mean, out2$Mean, tolerance = 1e-4)
})


test_that("estimate_expectation, transform", {
data(cars)
m <- lm(log(dist) ~ speed, data = cars)
out <- estimate_expectation(m, by = "speed")
expect_equal(
out$Predicted,
c(
2.15918, 2.44097, 2.72276, 3.00454, 3.28633, 3.56811, 3.8499,
4.13168, 4.41347, 4.69525
),
tolerance = 1e-4
)
out1 <- estimate_expectation(m, by = "speed", transform = TRUE)
expect_equal(
out$Predicted,
c(
8.66407, 11.48417, 15.2222, 20.17694, 26.74442, 35.44958, 46.98822,
62.28261, 82.55525, 109.42651
),
tolerance = 1e-4
)
out2 <- estimate_expectation(m, by ="speed", transform = exp)
expect_equal(out1$Predicted, out2$Predicted, tolerance = 1e-4)
})

0 comments on commit 7aa93cb

Please sign in to comment.