1
-
2
1
import sympy .parsing .sympy_parser as parser
3
2
import sympy
4
3
from pyhf .parameters import ParamViewer
5
4
import jax .numpy as jnp
6
5
import jax
7
6
7
+
8
8
def create_modifiers ():
9
9
10
10
class PureFunctionModifierBuilder :
11
11
is_shared = True
12
+
12
13
def __init__ (self , pdfconfig ):
13
14
self .config = pdfconfig
14
15
self .required_parsets = {}
15
- self .builder_data = {'local' : {},'global' : {'symbols' : set ()}}
16
+ self .builder_data = {'local' : {}, 'global' : {'symbols' : set ()}}
16
17
17
18
def collect (self , thismod , nom ):
18
19
maskval = True if thismod else False
@@ -21,23 +22,25 @@ def collect(self, thismod, nom):
21
22
22
23
def require_synbols_as_scalars (self , symbols ):
23
24
param_spec = {
24
- p :
25
- [{
26
- 'paramset_type' : 'unconstrained' ,
27
- 'n_parameters' : 1 ,
28
- 'is_shared' : True ,
29
- 'inits' : (1.0 ,),
30
- 'bounds' : ((0 ,10 ),),
31
- 'is_scalar' : True ,
32
- 'fixed' : False ,
33
- }]
25
+ p : [
26
+ {
27
+ 'paramset_type' : 'unconstrained' ,
28
+ 'n_parameters' : 1 ,
29
+ 'is_shared' : True ,
30
+ 'inits' : (1.0 ,),
31
+ 'bounds' : ((0 , 10 ),),
32
+ 'is_scalar' : True ,
33
+ 'fixed' : False ,
34
+ }
35
+ ]
34
36
for p in symbols
35
37
}
36
38
return param_spec
37
39
38
-
39
40
def append (self , key , channel , sample , thismod , defined_samp ):
40
- self .builder_data ['local' ].setdefault (key , {}).setdefault (sample , {}).setdefault ('data' , {'mask' : []})
41
+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
42
+ sample , {}
43
+ ).setdefault ('data' , {'mask' : []})
41
44
42
45
nom = (
43
46
defined_samp ['data' ]
@@ -52,10 +55,12 @@ def append(self, key, channel, sample, thismod, defined_samp):
52
55
parsed = parser .parse_expr (formula )
53
56
free_symbols = parsed .free_symbols
54
57
for x in free_symbols :
55
- self .builder_data ['global' ].setdefault ('symbols' ,set ()).add (x )
58
+ self .builder_data ['global' ].setdefault ('symbols' , set ()).add (x )
56
59
else :
57
60
parsed = None
58
- self .builder_data ['local' ].setdefault (key ,{}).setdefault (sample ,{}).setdefault ('channels' ,{}).setdefault (channel ,{})['parsed' ] = parsed
61
+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
62
+ sample , {}
63
+ ).setdefault ('channels' , {}).setdefault (channel , {})['parsed' ] = parsed
59
64
60
65
def finalize (self ):
61
66
list_of_symbols = [str (x ) for x in self .builder_data ['global' ]['symbols' ]]
@@ -67,7 +72,9 @@ def finalize(self):
67
72
for sample , samplespec in modspec .items ():
68
73
for channel , channelspec in samplespec ['channels' ].items ():
69
74
if channelspec ['parsed' ] is not None :
70
- channelspec ['jaxfunc' ] = sympy .lambdify (list_of_symbols , channelspec ['parsed' ], 'jax' )
75
+ channelspec ['jaxfunc' ] = sympy .lambdify (
76
+ list_of_symbols , channelspec ['parsed' ], 'jax'
77
+ )
71
78
else :
72
79
channelspec ['jaxfunc' ] = lambda * args : 1.0
73
80
return self .builder_data
@@ -93,28 +100,37 @@ def __init__(
93
100
else (pdfconfig .npars ,)
94
101
)
95
102
96
- self .param_viewer = ParamViewer (parfield_shape , pdfconfig .par_map , self .inputs )
103
+ self .param_viewer = ParamViewer (
104
+ parfield_shape , pdfconfig .par_map , self .inputs
105
+ )
97
106
self .create_jax_eval ()
98
107
99
108
def create_jax_eval (self ):
100
109
def eval_func (pars ):
101
- return jnp .array ([
110
+ return jnp .array (
102
111
[
103
- jnp .concatenate ([
104
- self .builder_data ['local' ][m ][s ]['channels' ][c ]['jaxfunc' ](* pars )* jnp .ones (self .pdfconfig .channel_nbins [c ])
105
- for c in self .pdfconfig .channels
106
- ])
107
- for s in self .pdfconfig .samples
112
+ [
113
+ jnp .concatenate (
114
+ [
115
+ self .builder_data ['local' ][m ][s ]['channels' ][c ][
116
+ 'jaxfunc'
117
+ ](* pars )
118
+ * jnp .ones (self .pdfconfig .channel_nbins [c ])
119
+ for c in self .pdfconfig .channels
120
+ ]
121
+ )
122
+ for s in self .pdfconfig .samples
123
+ ]
124
+ for m in self .keys
108
125
]
109
- for m in self . keys
126
+ )
110
127
111
- ])
112
128
self .jaxeval = eval_func
113
-
114
- def apply_nonbatched (self ,pars ):
115
- return jnp .expand_dims (self .jaxeval (pars ),2 )
116
129
117
- def apply_batched (self ,pars ):
130
+ def apply_nonbatched (self , pars ):
131
+ return jnp .expand_dims (self .jaxeval (pars ), 2 )
132
+
133
+ def apply_batched (self , pars ):
118
134
return jax .vmap (self .jaxeval , in_axes = (1 ,), out_axes = 2 )(pars )
119
135
120
136
def apply (self , pars ):
@@ -127,19 +143,18 @@ def apply(self, pars):
127
143
par_selection = self .param_viewer .get (pars )
128
144
results_purefunc = self .apply_batched (par_selection )
129
145
return results_purefunc
130
-
146
+
131
147
return PureFunctionModifierBuilder , PureFunctionModifierApplicator
132
148
133
149
134
150
from pyhf .modifiers import histfactory_set
135
151
152
+
136
153
def enable ():
137
154
modifier_set = {}
138
155
modifier_set .update (** histfactory_set )
139
156
140
157
builder , applicator = create_modifiers ()
141
158
142
- modifier_set .update (** {
143
- applicator .name : (builder , applicator )}
144
- )
145
- return modifier_set
159
+ modifier_set .update (** {applicator .name : (builder , applicator )})
160
+ return modifier_set
0 commit comments