Skip to content

Commit 0b2b712

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1eedbe4 commit 0b2b712

File tree

1 file changed

+50
-35
lines changed

1 file changed

+50
-35
lines changed
Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
21
import sympy.parsing.sympy_parser as parser
32
import sympy
43
from pyhf.parameters import ParamViewer
54
import jax.numpy as jnp
65
import jax
76

7+
88
def create_modifiers():
99

1010
class PureFunctionModifierBuilder:
1111
is_shared = True
12+
1213
def __init__(self, pdfconfig):
1314
self.config = pdfconfig
1415
self.required_parsets = {}
15-
self.builder_data = {'local': {},'global': {'symbols': set()}}
16+
self.builder_data = {'local': {}, 'global': {'symbols': set()}}
1617

1718
def collect(self, thismod, nom):
1819
maskval = True if thismod else False
@@ -21,23 +22,25 @@ def collect(self, thismod, nom):
2122

2223
def require_synbols_as_scalars(self, symbols):
2324
param_spec = {
24-
p:
25-
[{
26-
'paramset_type': 'unconstrained',
27-
'n_parameters': 1,
28-
'is_shared': True,
29-
'inits': (1.0,),
30-
'bounds': ((0,10),),
31-
'is_scalar': True,
32-
'fixed': False,
33-
}]
25+
p: [
26+
{
27+
'paramset_type': 'unconstrained',
28+
'n_parameters': 1,
29+
'is_shared': True,
30+
'inits': (1.0,),
31+
'bounds': ((0, 10),),
32+
'is_scalar': True,
33+
'fixed': False,
34+
}
35+
]
3436
for p in symbols
3537
}
3638
return param_spec
3739

38-
3940
def append(self, key, channel, sample, thismod, defined_samp):
40-
self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []})
41+
self.builder_data['local'].setdefault(key, {}).setdefault(
42+
sample, {}
43+
).setdefault('data', {'mask': []})
4144

4245
nom = (
4346
defined_samp['data']
@@ -52,10 +55,12 @@ def append(self, key, channel, sample, thismod, defined_samp):
5255
parsed = parser.parse_expr(formula)
5356
free_symbols = parsed.free_symbols
5457
for x in free_symbols:
55-
self.builder_data['global'].setdefault('symbols',set()).add(x)
58+
self.builder_data['global'].setdefault('symbols', set()).add(x)
5659
else:
5760
parsed = None
58-
self.builder_data['local'].setdefault(key,{}).setdefault(sample,{}).setdefault('channels',{}).setdefault(channel,{})['parsed'] = parsed
61+
self.builder_data['local'].setdefault(key, {}).setdefault(
62+
sample, {}
63+
).setdefault('channels', {}).setdefault(channel, {})['parsed'] = parsed
5964

6065
def finalize(self):
6166
list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']]
@@ -67,7 +72,9 @@ def finalize(self):
6772
for sample, samplespec in modspec.items():
6873
for channel, channelspec in samplespec['channels'].items():
6974
if channelspec['parsed'] is not None:
70-
channelspec['jaxfunc'] = sympy.lambdify(list_of_symbols, channelspec['parsed'], 'jax')
75+
channelspec['jaxfunc'] = sympy.lambdify(
76+
list_of_symbols, channelspec['parsed'], 'jax'
77+
)
7178
else:
7279
channelspec['jaxfunc'] = lambda *args: 1.0
7380
return self.builder_data
@@ -93,28 +100,37 @@ def __init__(
93100
else (pdfconfig.npars,)
94101
)
95102

96-
self.param_viewer = ParamViewer(parfield_shape, pdfconfig.par_map, self.inputs)
103+
self.param_viewer = ParamViewer(
104+
parfield_shape, pdfconfig.par_map, self.inputs
105+
)
97106
self.create_jax_eval()
98107

99108
def create_jax_eval(self):
100109
def eval_func(pars):
101-
return jnp.array([
110+
return jnp.array(
102111
[
103-
jnp.concatenate([
104-
self.builder_data['local'][m][s]['channels'][c]['jaxfunc'](*pars)*jnp.ones(self.pdfconfig.channel_nbins[c])
105-
for c in self.pdfconfig.channels
106-
])
107-
for s in self.pdfconfig.samples
112+
[
113+
jnp.concatenate(
114+
[
115+
self.builder_data['local'][m][s]['channels'][c][
116+
'jaxfunc'
117+
](*pars)
118+
* jnp.ones(self.pdfconfig.channel_nbins[c])
119+
for c in self.pdfconfig.channels
120+
]
121+
)
122+
for s in self.pdfconfig.samples
123+
]
124+
for m in self.keys
108125
]
109-
for m in self.keys
126+
)
110127

111-
])
112128
self.jaxeval = eval_func
113-
114-
def apply_nonbatched(self,pars):
115-
return jnp.expand_dims(self.jaxeval(pars),2)
116129

117-
def apply_batched(self,pars):
130+
def apply_nonbatched(self, pars):
131+
return jnp.expand_dims(self.jaxeval(pars), 2)
132+
133+
def apply_batched(self, pars):
118134
return jax.vmap(self.jaxeval, in_axes=(1,), out_axes=2)(pars)
119135

120136
def apply(self, pars):
@@ -127,19 +143,18 @@ def apply(self, pars):
127143
par_selection = self.param_viewer.get(pars)
128144
results_purefunc = self.apply_batched(par_selection)
129145
return results_purefunc
130-
146+
131147
return PureFunctionModifierBuilder, PureFunctionModifierApplicator
132148

133149

134150
from pyhf.modifiers import histfactory_set
135151

152+
136153
def enable():
137154
modifier_set = {}
138155
modifier_set.update(**histfactory_set)
139156

140157
builder, applicator = create_modifiers()
141158

142-
modifier_set.update(**{
143-
applicator.name: (builder, applicator)}
144-
)
145-
return modifier_set
159+
modifier_set.update(**{applicator.name: (builder, applicator)})
160+
return modifier_set

0 commit comments

Comments
 (0)