From b088b019bf215f2af3e51b01ff1eaff1416c826f Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Sat, 11 Mar 2023 19:18:18 +0530 Subject: [PATCH 1/6] Added functionality that can use name of variable in eval --- pytensor/graph/basic.py | 37 +++++++++++++++++++++++++++++++++++++ tests/graph/test_basic.py | 8 ++++++++ 2 files changed, 45 insertions(+) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 091b045109..1869f2b411 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -558,6 +558,39 @@ def get_parents(self): return [self.owner] return [] + def convert_string_keys_to_pytensor_variables(self, inputs_to_values): + r"""Convert the string keys to corresponding `Variable` with nearest name. + + Parameters + ---------- + inputs_to_values : + A dictionary mapping PyTensor `Variable`\s to values. + + Examples + -------- + + >>> import numpy as np + >>> import pytensor.tensor as at + >>> x = at.dscalar('x') + >>> y = at.dscalar('y') + >>> z = x + y + >>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4) + True + """ + process_input_to_values = {} + for i in inputs_to_values: + if isinstance(i, str): + nodes_with_matching_names = get_var_by_name([self], i) + if len(nodes_with_matching_names) == 0: + raise Exception(f"{i} not found in graph") + else: + process_input_to_values[ + nodes_with_matching_names[0] + ] = inputs_to_values[i] + else: + process_input_to_values[i] = inputs_to_values[i] + return process_input_to_values + def eval(self, inputs_to_values=None): r"""Evaluate the `Variable`. @@ -597,6 +630,10 @@ def eval(self, inputs_to_values=None): if inputs_to_values is None: inputs_to_values = {} + inputs_to_values = self.convert_string_keys_to_pytensor_variables( + inputs_to_values + ) + if not hasattr(self, "_fn_cache"): self._fn_cache = dict() diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index a4779c8299..345f816e44 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -302,6 +302,14 @@ def test_eval(self): pickle.loads(pickle.dumps(self.w)), "_fn_cache" ), "temporary functions must not be serialized" + def test_eval_with_strings(self): + assert self.w.eval({"x": 1.0, "y": 2.0}) == 6.0 + assert self.w.eval({self.z: 3}) == 6.0 + assert hasattr(self.w, "_fn_cache"), "variable must have cache after eval" + assert not hasattr( + pickle.loads(pickle.dumps(self.w)), "_fn_cache" + ), "temporary functions must not be serialized" + class TestAutoName: def test_auto_name(self): From 8976c94b347092636eb5a2ba977a131358987cfa Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Sun, 12 Mar 2023 00:00:52 +0530 Subject: [PATCH 2/6] Updated function with warning when multiple variables with same name are defined --- pytensor/graph/basic.py | 11 +++++++++-- tests/graph/test_basic.py | 13 +++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 1869f2b411..5b52735ba6 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -581,11 +581,18 @@ def convert_string_keys_to_pytensor_variables(self, inputs_to_values): for i in inputs_to_values: if isinstance(i, str): nodes_with_matching_names = get_var_by_name([self], i) - if len(nodes_with_matching_names) == 0: + length_of_nodes_with_matching_names = len(nodes_with_matching_names) + if length_of_nodes_with_matching_names == 0: raise Exception(f"{i} not found in graph") else: + if length_of_nodes_with_matching_names > 1: + warnings.warn( + f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i} taking the first declared named variable for computation" + ) process_input_to_values[ - nodes_with_matching_names[0] + nodes_with_matching_names[ + length_of_nodes_with_matching_names - 1 + ] ] = inputs_to_values[i] else: process_input_to_values[i] = inputs_to_values[i] diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 345f816e44..11d6f792f4 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -290,9 +290,11 @@ def test_outputs_clients(self): class TestEval: def setup_method(self): - self.x, self.y = scalars("x", "y") + self.x, self.y, self.e = scalars("x", "y", "e") self.z = self.x + self.y self.w = 2 * self.z + self.t = self.e + 1 + self.t.name = "e" def test_eval(self): assert self.w.eval({self.x: 1.0, self.y: 2.0}) == 6.0 @@ -303,12 +305,11 @@ def test_eval(self): ), "temporary functions must not be serialized" def test_eval_with_strings(self): - assert self.w.eval({"x": 1.0, "y": 2.0}) == 6.0 + assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0 assert self.w.eval({self.z: 3}) == 6.0 - assert hasattr(self.w, "_fn_cache"), "variable must have cache after eval" - assert not hasattr( - pickle.loads(pickle.dumps(self.w)), "_fn_cache" - ), "temporary functions must not be serialized" + + def test_eval_with_strings_with_mulitple_same_name(self): + assert self.t.eval({"e": 1.0}) == 2.0 class TestAutoName: From b7f745ea14d2d3acb4eeb49b9caf9bcce03ede0a Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Tue, 14 Mar 2023 18:12:48 +0530 Subject: [PATCH 3/6] Making convert_string_keys_to_variables internal to eval --- pytensor/graph/basic.py | 66 ++++++++++++++------------------------- tests/graph/test_basic.py | 12 ++++--- 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 5b52735ba6..4a61bbdeb9 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -558,46 +558,6 @@ def get_parents(self): return [self.owner] return [] - def convert_string_keys_to_pytensor_variables(self, inputs_to_values): - r"""Convert the string keys to corresponding `Variable` with nearest name. - - Parameters - ---------- - inputs_to_values : - A dictionary mapping PyTensor `Variable`\s to values. - - Examples - -------- - - >>> import numpy as np - >>> import pytensor.tensor as at - >>> x = at.dscalar('x') - >>> y = at.dscalar('y') - >>> z = x + y - >>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4) - True - """ - process_input_to_values = {} - for i in inputs_to_values: - if isinstance(i, str): - nodes_with_matching_names = get_var_by_name([self], i) - length_of_nodes_with_matching_names = len(nodes_with_matching_names) - if length_of_nodes_with_matching_names == 0: - raise Exception(f"{i} not found in graph") - else: - if length_of_nodes_with_matching_names > 1: - warnings.warn( - f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i} taking the first declared named variable for computation" - ) - process_input_to_values[ - nodes_with_matching_names[ - length_of_nodes_with_matching_names - 1 - ] - ] = inputs_to_values[i] - else: - process_input_to_values[i] = inputs_to_values[i] - return process_input_to_values - def eval(self, inputs_to_values=None): r"""Evaluate the `Variable`. @@ -637,9 +597,29 @@ def eval(self, inputs_to_values=None): if inputs_to_values is None: inputs_to_values = {} - inputs_to_values = self.convert_string_keys_to_pytensor_variables( - inputs_to_values - ) + def convert_string_keys_to_variables(): + process_input_to_values = {} + for i in inputs_to_values: + if isinstance(i, str): + nodes_with_matching_names = get_var_by_name([self], i) + length_of_nodes_with_matching_names = len(nodes_with_matching_names) + if length_of_nodes_with_matching_names == 0: + raise Exception(f"{i} not found in graph") + else: + if length_of_nodes_with_matching_names > 1: + raise Exception( + f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}" + ) + process_input_to_values[ + nodes_with_matching_names[ + length_of_nodes_with_matching_names - 1 + ] + ] = inputs_to_values[i] + else: + process_input_to_values[i] = inputs_to_values[i] + return process_input_to_values + + inputs_to_values = convert_string_keys_to_variables() if not hasattr(self, "_fn_cache"): self._fn_cache = dict() diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 11d6f792f4..96eb7997c3 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -290,11 +290,9 @@ def test_outputs_clients(self): class TestEval: def setup_method(self): - self.x, self.y, self.e = scalars("x", "y", "e") + self.x, self.y = scalars("x", "y") self.z = self.x + self.y self.w = 2 * self.z - self.t = self.e + 1 - self.t.name = "e" def test_eval(self): assert self.w.eval({self.x: 1.0, self.y: 2.0}) == 6.0 @@ -308,8 +306,12 @@ def test_eval_with_strings(self): assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0 assert self.w.eval({self.z: 3}) == 6.0 - def test_eval_with_strings_with_mulitple_same_name(self): - assert self.t.eval({"e": 1.0}) == 2.0 + def test_eval_errors_having_mulitple_variables_same_name(self): + e = scalars("e") + t = e + 1 + t.name = "e" + with pytest.raises(Exception, match="Found 2 pytensor variables with name e"): + t.eval({"e": 1}) class TestAutoName: From 8869a9d45fa3543fdb80949f05f5dbba2a0d7116 Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Sat, 18 Mar 2023 20:05:50 +0530 Subject: [PATCH 4/6] Removed last variable selection and added test case for no name available --- pytensor/graph/basic.py | 4 +--- tests/graph/test_basic.py | 7 +++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 4a61bbdeb9..b5e315ab2f 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -611,9 +611,7 @@ def convert_string_keys_to_variables(): f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}" ) process_input_to_values[ - nodes_with_matching_names[ - length_of_nodes_with_matching_names - 1 - ] + nodes_with_matching_names[0] ] = inputs_to_values[i] else: process_input_to_values[i] = inputs_to_values[i] diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 96eb7997c3..bcd2f92095 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -313,6 +313,13 @@ def test_eval_errors_having_mulitple_variables_same_name(self): with pytest.raises(Exception, match="Found 2 pytensor variables with name e"): t.eval({"e": 1}) + def test_eval_errors_with_no_name_exists(self): + e = scalars("e") + t = e + 1 + t.name = "p" + with pytest.raises(Exception, match="o not found in graph"): + t.eval({"o": 1}) + class TestAutoName: def test_auto_name(self): From 6e2efaa45fb3ce62138772743a476e75b9118561 Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Mon, 20 Mar 2023 21:00:25 +0530 Subject: [PATCH 5/6] Modified the code structure based on suggestions --- pytensor/graph/basic.py | 31 +++++++++++++------------------ tests/graph/test_basic.py | 2 +- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index b5e315ab2f..df846b7567 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -597,27 +597,22 @@ def eval(self, inputs_to_values=None): if inputs_to_values is None: inputs_to_values = {} - def convert_string_keys_to_variables(): - process_input_to_values = {} - for i in inputs_to_values: - if isinstance(i, str): - nodes_with_matching_names = get_var_by_name([self], i) - length_of_nodes_with_matching_names = len(nodes_with_matching_names) - if length_of_nodes_with_matching_names == 0: - raise Exception(f"{i} not found in graph") - else: - if length_of_nodes_with_matching_names > 1: + def convert_string_keys_to_variables(input_to_values): + new_input_to_values = {} + for key, value in inputs_to_values.items(): + if isinstance(key, str): + matching_vars = get_var_by_name([self], key) + if not matching_vars: + raise Exception(f"{key} not found in graph") + elif len(matching_vars) > 1: raise Exception( - f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}" + f"Found multiple variables with name {key}" ) - process_input_to_values[ - nodes_with_matching_names[0] - ] = inputs_to_values[i] + new_input_to_values[matching_vars[0]] = value else: - process_input_to_values[i] = inputs_to_values[i] - return process_input_to_values - - inputs_to_values = convert_string_keys_to_variables() + new_input_to_values[key] = value + return new_input_to_values + inputs_to_values = convert_string_keys_to_variables(inputs_to_values) if not hasattr(self, "_fn_cache"): self._fn_cache = dict() diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index bcd2f92095..df772efa6f 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -310,7 +310,7 @@ def test_eval_errors_having_mulitple_variables_same_name(self): e = scalars("e") t = e + 1 t.name = "e" - with pytest.raises(Exception, match="Found 2 pytensor variables with name e"): + with pytest.raises(Exception, match="Found multiple variables with name e"): t.eval({"e": 1}) def test_eval_errors_with_no_name_exists(self): From 9008a32e1cc89d296b1d397532c2a60e21eebbbb Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Mon, 20 Mar 2023 21:01:45 +0530 Subject: [PATCH 6/6] Modified the code structure based on suggestions --- pytensor/graph/basic.py | 5 ++--- tests/graph/test_basic.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index df846b7567..236ec93ed0 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -605,13 +605,12 @@ def convert_string_keys_to_variables(input_to_values): if not matching_vars: raise Exception(f"{key} not found in graph") elif len(matching_vars) > 1: - raise Exception( - f"Found multiple variables with name {key}" - ) + raise Exception(f"Found multiple variables with name {key}") new_input_to_values[matching_vars[0]] = value else: new_input_to_values[key] = value return new_input_to_values + inputs_to_values = convert_string_keys_to_variables(inputs_to_values) if not hasattr(self, "_fn_cache"): diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index df772efa6f..935301be05 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -306,14 +306,14 @@ def test_eval_with_strings(self): assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0 assert self.w.eval({self.z: 3}) == 6.0 - def test_eval_errors_having_mulitple_variables_same_name(self): + def test_eval_with_strings_multiple_matches(self): e = scalars("e") t = e + 1 t.name = "e" with pytest.raises(Exception, match="Found multiple variables with name e"): t.eval({"e": 1}) - def test_eval_errors_with_no_name_exists(self): + def test_eval_with_strings_no_match(self): e = scalars("e") t = e + 1 t.name = "p"