Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

summary.estimate_slopes() no longer working. #371

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ S3method(format,estimate_smooth)
S3method(format,marginaleffects_contrasts)
S3method(format,marginaleffects_means)
S3method(format,marginaleffects_slopes)
S3method(format,summary_estimate_slopes)
S3method(format,visualisation_matrix)
S3method(plot,estimate_contrasts)
S3method(plot,estimate_grouplevel)
Expand All @@ -26,6 +27,7 @@ S3method(print,estimate_means)
S3method(print,estimate_predicted)
S3method(print,estimate_slopes)
S3method(print,estimate_smooth)
S3method(print,summary_estimate_slopes)
S3method(print,visualisation_matrix)
S3method(print_html,estimate_contrasts)
S3method(print_html,estimate_grouplevel)
Expand Down
6 changes: 6 additions & 0 deletions R/format.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
}

# arrange columns (not for contrast now)
by <- rev(attr(x, "focal_terms", exact = TRUE))

Check warning on line 13 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=13,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 13 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=13,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
# add "Level" columns from contrasts
if (all(c("Level1", "Level2") %in% colnames(x))) {
by <- unique(by, c("Level1", "Level2"))

Check warning on line 16 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=16,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 16 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=16,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}
# check which columns actually exist
if (!is.null(by)) {
by <- intersect(by, colnames(x))

Check warning on line 20 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=20,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 20 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=20,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}
# sort
if (length(by)) {
Expand Down Expand Up @@ -70,6 +70,12 @@
}


#' @export
format.summary_estimate_slopes <- function(x, ...) {
insight::format_table(x, ...)
}


#' @export
format.marginaleffects_means <- function(x, model, ci = 0.95, ...) {
# model information
Expand Down Expand Up @@ -148,8 +154,8 @@

#' @export
format.marginaleffects_contrasts <- function(x, model = NULL, p_adjust = NULL, comparison = NULL, ...) {
predict <- attributes(x)$predict

Check warning on line 157 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=157,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.

Check warning on line 157 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=157,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.
by <- attributes(x)$by

Check warning on line 158 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=158,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 158 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=158,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
contrast <- attributes(x)$contrast
focal_terms <- attributes(x)$focal_terms
dgrid <- attributes(x)$datagrid
Expand Down Expand Up @@ -214,7 +220,7 @@
# in the second example, `contrast = c("vs", "am"), by = "gear='5'"`, the
# `by` column is the one with one unique value only, we thus have to update
# `by` as well, and also `contrast` (the latter not(!) for numerics)...
by <- by[lengths(lapply(dgrid[by], unique)) > 1]

Check warning on line 223 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=223,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 223 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=223,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

# for contrasts, we also filter variables with one unique value, but we
# keep numeric variables. When these are hold constant in the data grid,
Expand All @@ -224,7 +230,7 @@
contrast <- contrast[keep_contrasts]

# set to NULL, if all by-values have been removed here
if (!length(by)) by <- NULL

Check warning on line 233 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=233,col=22,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 233 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=233,col=22,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

# if we have no contrasts left, e.g. due to `contrast = "time = factor(2)"`,
# we error here - we have no contrasts to show
Expand Down Expand Up @@ -298,7 +304,7 @@
# unite back columns with focal contrasts - only needed when not slopes
if (inherits(x, "estimate_slopes")) {
contrast <- by
by <- NULL

Check warning on line 307 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=307,col=9,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

Check warning on line 307 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=307,col=9,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}

# if we have more than one contrast term, we unite the levels from
Expand Down Expand Up @@ -463,7 +469,7 @@
params <- params[c(setdiff(colnames(params), relocate_columns), relocate_columns)]

# relocate focal terms to the beginning
by <- attr(x, "focal_terms", exact = TRUE)

Check warning on line 472 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=472,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
if (!is.null(by) && all(by %in% colnames(params))) {
params <- datawizard::data_reorder(params, by, verbose = FALSE)
}
Expand Down Expand Up @@ -583,7 +589,7 @@
if (substring(input_string, match_positions[i], match_positions[i]) == "-") {
inside_parentheses <- FALSE
for (j in seq_along(match_positions)) {
if (i != j && match_positions[i] > match_positions[j] && match_positions[i] < (match_positions[j] + match_lengths[j])) {

Check warning on line 592 in R/format.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/format.R,line=592,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 130 characters.
inside_parentheses <- TRUE
break
}
Expand Down
3 changes: 3 additions & 0 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ print.estimate_means <- print.estimate_contrasts
#' @export
print.estimate_slopes <- print.estimate_contrasts

#' @export
print.summary_estimate_slopes <- print.estimate_contrasts

#' @export
print.estimate_smooth <- print.estimate_contrasts

Expand Down
101 changes: 49 additions & 52 deletions R/summary.R
Original file line number Diff line number Diff line change
@@ -1,35 +1,31 @@
#' @export
summary.estimate_slopes <- function(object, ...) {
my_data <- as.data.frame(object)
trend <- attributes(object)$trend
summary.estimate_slopes <- function(object, verbose = TRUE, ...) {
out <- as.data.frame(object)
by <- attributes(object)$by

# Add "Confidence" col based on the sig index present in the data
my_data$Confidence <- .estimate_slopes_sig(my_data, ...)
if (verbose && nrow(out) < 50) {
insight::format_alert("There might be too few data to accurately determine intervals. Consider setting `length = 100` (or larger) in your call to `estimate_slopes()`.") # nolint
}

# Grouping variables
vars <- attributes(object)$at
vars <- vars[!vars %in% trend]
# Add "Confidence" col based on the sig index present in the data
out$Confidence <- .estimate_slopes_significance(out, ...)
out$Direction <- .estimate_slopes_direction(out, ...)

# If no grouping variables, summarize all
if (length(vars) == 0) {
out <- .estimate_slopes_summarize(my_data, trend = trend)
# if we have more than one variable in `by`, group result table and
# add group name as separate column
if (length(by) > 1) {
parts <- split(out, out[[by[2]]])
out <- do.call(rbind, lapply(parts, .estimate_slope_parts, by = by[1]))
out <- datawizard::rownames_as_column(out, "Group")
out$Group <- gsub("\\.\\d+$", "", out$Group)
} else {
out <- data.frame()
# Create vizmatrix of grouping variables
groups <- as.data.frame(insight::get_datagrid(my_data[vars], factors = "all", numerics = "all"))
# Summarize all of the chunks
for (i in seq_len(nrow(groups))) {
g <- datawizard::data_match(my_data, groups[i, , drop = FALSE])
out <- rbind(out, .estimate_slopes_summarize(g, trend = trend))
}
out <- datawizard::data_relocate(out, vars)
out <- .estimate_slope_parts(out, by)
}

# Clean and sanitize
out$Confidence <- NULL # Drop significance col
attributes(out) <- utils::modifyList(attributes(object), attributes(out))
class(out) <- c("estimate_slopes", class(out))
class(out) <- c("summary_estimate_slopes", "data.frame")
attr(out, "table_title") <- c("Average Marginal Effects", "blue")

out
}

Expand All @@ -45,43 +41,44 @@ summary.reshape_grouplevel <- function(object, ...) {
# Utilities ===============================================================


.estimate_slopes_summarize <- function(data, trend, ...) {
# Find beginnings and ends -----------------------
# First row - starting point
.estimate_slope_parts <- function(out, by) {
# mark all "changes" from negative to positive and vice versa
index <- 1
out$switch <- index
index <- index + 1

for (i in 2:nrow(out)) {
if (out$Direction[i] != out$Direction[i - 1] || out$Confidence[i] != out$Confidence[i - 1]) {
out$switch[i:nrow(out)] <- index
index <- index + 1
}
}

# split into "switches"
parts <- split(out, out$switch)

do.call(rbind, lapply(parts, function(i) {
data.frame(
Start = i[[by]][1],
End = i[[by]][nrow(i)],
Direction = i$Direction[1],
Confidence = i$Confidence[1]
)
}))
}


.estimate_slopes_direction <- function(data, ...) {
centrality_columns <- datawizard::extract_column_names(
data,
c("Coefficient", "Slope", "Median", "Mean", "MAP_Estimate"),
verbose = FALSE
)
centrality_signs <- sign(data[[centrality_columns]])
centrality_sign <- centrality_signs[1]
sig <- data$Confidence[1]
starts <- 1
ends <- nrow(data)
# Iterate through all rows to find blocks
for (i in 2:nrow(data)) {
if ((data$Confidence[i] != sig) || ((centrality_signs[i] != centrality_sign) && data$Confidence[i] == "Uncertain")) {
centrality_sign <- centrality_signs[i]
sig <- data$Confidence[i]
starts <- c(starts, i)
ends <- c(ends, i - 1)
}
}
ends <- sort(ends)

# Summarize these groups -----------------------
out <- data.frame()
for (g in seq_len(length(starts))) {
dat <- data[starts[g]:ends[g], ]
dat <- as.data.frame(insight::get_datagrid(dat, by = NULL, factors = "mode"))
dat <- cbind(data.frame(Start = data[starts[g], trend], End = data[ends[g], trend]), dat)
out <- rbind(out, dat)
}
out
ifelse(data[[centrality_columns]] < 0, "negative", "positive")
}


.estimate_slopes_sig <- function(x, confidence = "auto", ...) {
.estimate_slopes_significance <- function(x, confidence = "auto", ...) {
insight::check_if_installed("effectsize")

if (confidence == "auto") {
Expand Down
Loading