Skip to content

Commit c61c97a

Browse files
committed
Add choiceproduct utility function.
1 parent 6be73d3 commit c61c97a

File tree

5 files changed

+68
-0
lines changed

5 files changed

+68
-0
lines changed

docs/src/initialize.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,5 @@ To reduce variance, particle filters can also be initialized via stratified samp
2323
pf_initialize(model::GenerativeFunction, model_args::Tuple,
2424
observations::ChoiceMap, strata, n_particles::Int)
2525
```
26+
27+
For convenience, strata can be generated using the [`choiceproduct`](@ref) function.

docs/src/update.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,5 @@ All of the above methods can also be combined with stratified sampling. This can
5454
pf_update!(state::ParticleFilterView, new_args::Tuple, argdiffs::Tuple,
5555
observations::ChoiceMap, strata)
5656
```
57+
58+
For convenience, strata can be generated using the [`choiceproduct`](@ref) function.

docs/src/utils.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,12 @@ var(::ParticleFilterView, addr)
5555
var(f::Function, ::ParticleFilterView, addr, addrs...)
5656
proportionmap(::ParticleFilterView, addr)
5757
proportionmap(f::Function, ::ParticleFilterView, addr, addrs...)
58+
```
59+
60+
## Stratification
61+
62+
To support the use of stratified sampling, the [`choiceproduct`](@ref) method can be used to conveniently generate a list of choicemap strata:
63+
64+
```@docs
65+
choiceproduct
5866
```

src/utils.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## Various utility functions ##
2+
export choiceproduct
23
export get_log_norm_weights, get_norm_weights
34
export effective_sample_size, get_ess
45
export log_ml_estimate, get_lml_est
@@ -55,6 +56,49 @@ function stratified_map!(f::Function, n_total::Int, strata, args...;
5556
return nothing
5657
end
5758

59+
"""
60+
choiceproduct((addr, vals))
61+
choiceproduct((addr, vals), choices::Tuple...)
62+
choiceproduct(choices::Dict)
63+
64+
Returns an iterator over `ChoiceMap`s given a tuple or sequence of tuples of
65+
the form `(addr, vals)`, where `addr` specifies a choice address, and
66+
`vals` specifies a list of values for that address.
67+
68+
If multiple tuples are provided, the iterator will be a Cartesian product over
69+
the `(addr, vals)` pairs, where each resulting `ChoiceMap` contains all
70+
specified addresses. Instead of specifying multiple tuples, a dictionary mapping
71+
addresses to values can also be provided.
72+
73+
# Examples
74+
75+
This function can be used to conveniently generate `ChoiceMap`s for stratified
76+
sampling. For example, we can use `choiceproduct` instead of manually
77+
constructing a list of strata:
78+
79+
```julia
80+
# Manual construction
81+
strata = [choicemap((:a, 1), (:b, 3)), choicemap((:a, 2), (:b, 3))]
82+
# Using choiceproduct
83+
strata = choiceproduct((:a, [1, 2]), (:b, [3]))
84+
```
85+
"""
86+
function choiceproduct(choices::Tuple...)
87+
prod_iter = Iterators.product((((addr, v) for v in vals) for
88+
(addr, vals) in choices)...)
89+
return (choicemap(cs...) for cs in prod_iter)
90+
end
91+
92+
function choiceproduct(choices::Dict)
93+
prod_iter = Iterators.product((((addr, v) for v in vals) for
94+
(addr, vals) in choices)...)
95+
return (choicemap(cs...) for cs in prod_iter)
96+
end
97+
98+
function choiceproduct((addr, vals)::Tuple)
99+
return (choicemap((addr, v)) for v in vals)
100+
end
101+
58102
lognorm(vs::AbstractVector) = vs .- logsumexp(vs)
59103

60104
function softmax(vs::AbstractVector{T}) where {T <: Real}

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,16 @@
88
@test sum(weights) 1.0
99
ess = get_ess(state)
1010
@test ess sum(weights)^2 / sum(weights .^ 2)
11+
12+
strata = choiceproduct((:a, [1, 2])) |> collect
13+
@test choicemap((:a , 1)) in strata
14+
@test choicemap((:a , 2)) in strata
15+
16+
strata = choiceproduct((:a, [1, 2]), (:b, [3])) |> collect
17+
@test choicemap((:a , 1), (:b, 3)) in strata
18+
@test choicemap((:a , 2), (:b, 3)) in strata
19+
20+
strata = choiceproduct(Dict(:a => [1, 2], :b => [3])) |> collect
21+
@test choicemap((:a , 1), (:b, 3)) in strata
22+
@test choicemap((:a , 2), (:b, 3)) in strata
1123
end

0 commit comments

Comments
 (0)