Skip to content

Commit fdf2936

Browse files
committed
Introduce construct.model parameter
construct.model is a function that takes modeling data (NA in post.period) and returns a bsts model. This invocation: go <- function(data) { y <- data[, 1] sdy <- sd(y, na.rm = TRUE) sd.prior <- SdPrior(sigma.guess = 0.01 * sdy, upper.limit = sdy, sample.size = 32) ss <- AddLocalLevel(list(), y, sigma.prior = sd.prior) bsts.model <- bsts(y, state.specification = ss, niter = 100, seed = 1, ping = 0) } impact <- CausalImpact(data, pre.period, post.period, construct.model = go) is equivalent to: model.args <- list(niter = 100) impact <- CausalImpact(data, pre.period, post.period, model.args) . This change provides an interface that has RunWithData's flexibility with pre- and post-period, and RunWithBstsModel's full configurability of the model. For example, the caller can now specify a custom model where the post-period is before the end of the data, which isn't possible with the previous interface.
1 parent 7e0f59f commit fdf2936

2 files changed

Lines changed: 49 additions & 6 deletions

File tree

R/impact_analysis.R

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ FormatInputPrePostPeriod <- function(pre.period, post.period, data) {
135135

136136
FormatInputForCausalImpact <- function(data, pre.period, post.period,
137137
model.args, bsts.model,
138-
post.period.response, alpha) {
138+
post.period.response, alpha,
139+
construct.model) {
139140
# Checks and formats all input arguments supplied to CausalImpact(). See the
140141
# documentation of CausalImpact() for details.
141142
#
@@ -147,6 +148,7 @@ FormatInputForCausalImpact <- function(data, pre.period, post.period,
147148
# bsts.model: fitted bsts model (instead of data)
148149
# post.period.response: observed response in the post-period
149150
# alpha: tail-area for posterior intervals
151+
# construct.model: custom model constructor
150152
#
151153
# Returns:
152154
# list of checked (and possibly reformatted) input arguments
@@ -213,7 +215,8 @@ CausalImpact <- function(data = NULL,
213215
model.args = NULL,
214216
bsts.model = NULL,
215217
post.period.response = NULL,
216-
alpha = 0.05) {
218+
alpha = 0.05,
219+
construct.model = NULL) {
217220
# CausalImpact() performs causal inference through counterfactual
218221
# predictions using a Bayesian structural time-series model.
219222
#
@@ -278,6 +281,8 @@ CausalImpact <- function(data = NULL,
278281
# alpha: Desired tail-area probability for posterior intervals.
279282
# Defaults to 0.05, which will produce central 95\% intervals.
280283
#
284+
# construct.model: Custom model constructor.
285+
#
281286
# Returns:
282287
# A CausalImpact object. This is a list of:
283288
# series: observed data, counterfactual, pointwise and cumulative impact
@@ -341,7 +346,8 @@ CausalImpact <- function(data = NULL,
341346
# Check input
342347
checked <- FormatInputForCausalImpact(data, pre.period, post.period,
343348
model.args, bsts.model,
344-
post.period.response, alpha)
349+
post.period.response, alpha,
350+
construct.model)
345351
data <- checked$data
346352
pre.period <- checked$pre.period
347353
post.period <- checked$post.period
@@ -352,7 +358,7 @@ CausalImpact <- function(data = NULL,
352358

353359
# Depending on input, dispatch to the appropriate Run* method()
354360
if (!is.null(data)) {
355-
impact <- RunWithData(data, pre.period, post.period, model.args, alpha)
361+
impact <- RunWithData(data, pre.period, post.period, model.args, alpha, construct.model)
356362
# Return pre- and post-period in the time unit of the time series.
357363
times <- time(data)
358364
impact$model$pre.period <- times[pre.period]
@@ -364,7 +370,7 @@ CausalImpact <- function(data = NULL,
364370
return(impact)
365371
}
366372

367-
RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
373+
RunWithData <- function(data, pre.period, post.period, model.args, alpha, construct.model) {
368374
# Runs an impact analysis on top of a fitted bsts model.
369375
#
370376
# Args:
@@ -375,6 +381,7 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
375381
# limits.
376382
# model.args: list of model arguments
377383
# alpha: tail-probabilities of posterior intervals
384+
# construct.model: custom model constructor
378385
#
379386
# Returns:
380387
# See CausalImpact().
@@ -409,7 +416,16 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
409416
window(data.modeling[, 1], start = pre.period[2] + 1) <- NA
410417

411418
# Construct model and perform inference
412-
bsts.model <- ConstructModel(data.modeling, model.args)
419+
if (!is.null(construct.model)) {
420+
checked <- FormatInputForConstructModel(data.modeling, model.args)
421+
y <- checked$data[, 1]
422+
# If the series is ill-conditioned, abort inference and return NULL
423+
bsts.model <- if (ObservationsAreIllConditioned(y)) NULL else {
424+
construct.model(checked$data)
425+
}
426+
} else {
427+
bsts.model <- ConstructModel(data.modeling, model.args)
428+
}
413429

414430
# Compile posterior inferences
415431
if (!is.null(bsts.model)) {

tests/testthat/test-impact-analysis.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,33 @@ test_that("CausalImpact.RunWithData.MissingTimePoint", {
593593
expect_equal(indices, time(series)[-17])
594594
})
595595

596+
test_that("CausalImpact.RunWithData.CustomConstructModel", {
597+
# Test daily data (zoo object)
598+
data <- zoo(cbind(rnorm(200), rnorm(200), rnorm(200)),
599+
seq.Date(as.Date("2014-01-01"), as.Date("2014-01-01") + 199,
600+
by = 1))
601+
602+
pre.period <- as.Date(c("2014-01-01", "2014-04-10")) # 100 days
603+
post.period <- as.Date(c("2014-04-11", "2014-07-09")) # 90 days
604+
605+
go <- function(data) {
606+
y <- data[, 1]
607+
sdy <- sd(y, na.rm = TRUE)
608+
sd.prior <- SdPrior(sigma.guess = 0.01 * sdy,
609+
upper.limit = sdy,
610+
sample.size = 32)
611+
ss <- AddLocalLevel(list(), y, sigma.prior = sd.prior)
612+
bsts.model <- bsts(y, state.specification = ss, niter = 100,
613+
seed = 1, ping = 0)
614+
}
615+
616+
suppressWarnings(impact <- CausalImpact(data, pre.period, post.period,
617+
construct.model = go))
618+
expect_equal(time(impact$model$bsts.model$original.series), 1:200)
619+
expect_equal(time(impact$series), time(data))
620+
CallAllS3Methods(impact)
621+
})
622+
596623
test_that("CausalImpact.RunWithBstsModel", {
597624

598625
# Test on a healthy bsts object

0 commit comments

Comments
 (0)