Skip to content

Commit 424eda9

Browse files
committed
closes #180.
mcmc_recover_scatter() no longer determines x,y range peeking at a prebuilt plot
1 parent b7885a5 commit 424eda9

File tree

2 files changed

+94
-59
lines changed

2 files changed

+94
-59
lines changed

R/mcmc-recover.R

Lines changed: 59 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -225,65 +225,73 @@ mcmc_recover_scatter <-
225225
size = 3,
226226
alpha = 1) {
227227

228-
check_ignored_arguments(...)
229-
x <- merge_chains(prepare_mcmc_array(x))
228+
check_ignored_arguments(...)
229+
x <- merge_chains(prepare_mcmc_array(x))
230230

231-
stopifnot(
232-
is.numeric(true),
233-
ncol(x) == length(true),
234-
length(batch) == length(true)
235-
)
236-
all_separate <- length(unique(batch)) == length(true)
237-
point_est <- match.arg(point_est)
238-
plot_data <- data.frame(
239-
Parameter = colnames(x),
240-
Point = apply(x, 2, point_est),
241-
True = true
242-
)
243-
if (!all_separate) {
244-
plot_data$Batch <- factor(batch, levels = unique(batch))
245-
} else {
246-
plot_data$Batch <-
247-
factor(colnames(x), levels = colnames(x)[as.integer(as.factor(batch))])
248-
}
231+
stopifnot(
232+
is.numeric(true),
233+
ncol(x) == length(true),
234+
length(batch) == length(true)
235+
)
249236

250-
facet_args[["facets"]] <- ~ Batch
251-
if (is.null(facet_args[["strip.position"]]))
252-
facet_args[["strip.position"]] <- "top"
253-
if (is.null(facet_args[["scales"]]))
254-
facet_args[["scales"]] <- "free"
237+
one_true_per_batch <- length(unique(batch)) == length(true)
238+
one_batch <- length(unique(batch)) == 1
255239

256-
graph <- ggplot(plot_data, aes_(x = ~ True, y = ~ Point)) +
257-
geom_abline(
258-
slope = 1,
259-
intercept = 0,
260-
linetype = 2,
261-
color = "black"
262-
) +
263-
geom_point(
264-
shape = 21,
265-
color = get_color("mh"),
266-
fill = get_color("m"),
267-
size = size,
268-
alpha = alpha
269-
) +
270-
do.call("facet_wrap", facet_args) +
271-
labs(y = "Estimated", x = "True") +
272-
bayesplot_theme_get()
240+
point_est <- match.arg(point_est)
241+
plot_data <- data.frame(
242+
Parameter = colnames(x),
243+
Point = apply(x, 2, point_est),
244+
True = true
245+
)
273246

274-
if (length(unique(batch)) == 1) {
275-
g <- ggplot_build(graph)
276-
xylim <- g$layout$panel_ranges[[1]]
277-
xylim <- range(xylim$y.range, xylim$x.range)
278-
graph <- graph + coord_fixed(xlim = xylim, ylim = xylim)
279-
}
247+
if (!one_true_per_batch) {
248+
plot_data$Batch <- factor(batch, levels = unique(batch))
249+
} else {
250+
plot_data$Batch <-
251+
factor(colnames(x), levels = colnames(x)[as.integer(as.factor(batch))])
252+
}
280253

281-
if (all_separate)
282-
return(graph)
254+
facet_args[["facets"]] <- "Batch"
255+
facet_args[["strip.position"]] <- facet_args[["strip.position"]] %||% "top"
256+
facet_args[["scales"]] <- facet_args[["scales"]] %||% "free"
283257

284-
graph + facet_text(FALSE)
258+
# To ensure that the x and y scales have the same range, find the min and max
259+
# value on each coordinate. plot them invisibly with geom_blank() later on.
260+
corners <- plot_data %>%
261+
group_by(.data$Batch) %>%
262+
summarise(
263+
min = min(pmin(.data$Point, .data$True)),
264+
max = max(pmax(.data$Point, .data$True))
265+
)
266+
267+
graph <-
268+
ggplot(plot_data, aes_(x = ~ True, y = ~ Point)) +
269+
geom_abline(
270+
slope = 1,
271+
intercept = 0,
272+
linetype = 2,
273+
color = "black"
274+
) +
275+
geom_point(
276+
shape = 21,
277+
color = get_color("mh"),
278+
fill = get_color("m"),
279+
size = size,
280+
alpha = alpha
281+
) +
282+
geom_blank(aes(x = min, y = min), data = corners) +
283+
geom_blank(aes(x = max, y = max), data = corners) +
284+
do.call("facet_wrap", facet_args) +
285+
labs(x = "True", y = "Estimated") +
286+
bayesplot_theme_get()
287+
288+
if (one_batch) {
289+
graph <- graph + facet_text(FALSE)
285290
}
286291

292+
graph
293+
}
294+
287295

288296
#' @rdname MCMC-recover
289297
#' @export

tests/testthat/test-mcmc-recover.R

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,41 @@ test_that("mcmc_recover_intervals works when point_est = 'none'", {
6161

6262

6363
test_that("mcmc_recover_scatter returns a ggplot object", {
64-
expect_gg(mcmc_recover_scatter(draws, true))
65-
expect_gg(mcmc_recover_scatter(draws, true, batch = 1:4,
66-
point_est = "mean"))
67-
expect_gg(mcmc_recover_scatter(draws, true, batch = c(1, 2, 2, 1),
68-
point_est = "mean"))
69-
expect_gg(mcmc_recover_scatter(draws, true, batch = grepl("X", colnames(draws))))
70-
expect_gg(mcmc_recover_scatter(draws, true, batch = grepl("X", colnames(draws)),
71-
facet_args = list(ncol = 1)))
64+
expect_gg(
65+
mcmc_recover_scatter(draws, true)
66+
)
67+
expect_gg(
68+
mcmc_recover_scatter(
69+
draws,
70+
true,
71+
batch = 1:4,
72+
point_est = "mean",
73+
facet_args = list(scales = "fixed")
74+
)
75+
)
76+
expect_gg(
77+
mcmc_recover_scatter(
78+
draws,
79+
true,
80+
batch = c(1, 2, 2, 1),
81+
point_est = "mean"
82+
)
83+
)
84+
expect_gg(
85+
mcmc_recover_scatter(
86+
draws,
87+
true,
88+
batch = grepl("X", colnames(draws))
89+
)
90+
)
91+
expect_gg(
92+
mcmc_recover_scatter(
93+
draws,
94+
true,
95+
batch = grepl("X", colnames(draws)),
96+
facet_args = list(ncol = 1)
97+
)
98+
)
7299
})
73100

74101

0 commit comments

Comments
 (0)