Skip to content

Commit 3837d0f

Browse files
Remove the estimator_export decorator.
All uses of `estimator_export` in `tensorflow_estimator` have been updated to use an alternative decorator defined inside the estimator package. As the estimator is a deprecated API, no new endpoints should be registered using the `estimator_export` decorator. This change removes the decorator from `tf_export`. `estimator_export` was also the last remaining usage of the `deprecated_inst` argument in the overall `api_export` decorator. This optional argument created a circular dependency between the `deprecation` module and `tf_export`. This argument and its functionality have also been removed. PiperOrigin-RevId: 503514596
1 parent 7b0d43e commit 3837d0f

File tree

2 files changed

+0
-59
lines changed

2 files changed

+0
-59
lines changed

tensorflow/python/util/tf_export.py

-17
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=g-doc-args
291291
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
292292
self._overrides = kwargs.get('overrides', [])
293293
self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)
294-
self._deprecation_inst = kwargs.get('deprecation_inst', None)
295294

296295
self._validate_symbol_names()
297296

@@ -349,15 +348,6 @@ def __call__(self, func):
349348
self.set_attr(undecorated_func, api_names_attr, self._names)
350349
self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
351350

352-
# TODO(b/263286841): Remove functionality
353-
if self._deprecation_inst is not None:
354-
# Inline import to avoid dependency cycle between deprecation
355-
# utility and tf_export
356-
from tensorflow.python.util import deprecation # pylint: disable=g-import-not-at-top
357-
deprecation_wrapper = deprecation.deprecated(
358-
None, self._deprecation_inst, warn_once=True)
359-
func = deprecation_wrapper(func)
360-
361351
for name in self._names:
362352
_NAME_TO_SYMBOL_MAPPING[name] = func
363353
for name_v1 in self._names_v1:
@@ -427,11 +417,4 @@ def wrapper(*args, **kwargs):
427417

428418

429419
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
430-
# TODO(b/263286841): Remove in favor of
431-
# `tensorflow_estimator.python.estimator.estimator_export`.
432-
estimator_export = functools.partial(
433-
api_export,
434-
api_name=ESTIMATOR_API_NAME,
435-
is_deprecated=True,
436-
deprecation_inst='Use tf.keras instead.')
437420
keras_export = functools.partial(api_export, api_name=KERAS_API_NAME)

tensorflow/python/util/tf_export_test.py

-42
Original file line numberDiff line numberDiff line change
@@ -141,38 +141,6 @@ def testExportClasses(self):
141141
self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
142142
self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB))
143143

144-
def testExportClassInEstimator(self):
145-
export_decorator_a = tf_export.tf_export('TestClassA1')
146-
export_decorator_a(TestClassA)
147-
self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
148-
149-
export_decorator_b = tf_export.estimator_export('estimator.TestClassB1')
150-
export_decorator_b(TestClassB)
151-
self.assertTrue('_tf_api_names' not in TestClassB.__dict__)
152-
self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
153-
self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
154-
self.assertEqual(['estimator.TestClassB1'],
155-
tf_export.get_v1_names(TestClassB))
156-
157-
@test.mock.patch.object(logging, 'warning', autospec=True)
158-
def testExportDeprecated(self, mock_warning):
159-
export_decorator = tf_export.estimator_export(
160-
'estimator.TestClassA', is_deprecated=True)
161-
export_decorator(TestClassA)
162-
163-
export_decorator2 = tf_export.tf_export('TestClassB1')
164-
export_decorator2(TestClassB)
165-
166-
# Deprecation should trigger a runtime warning
167-
TestClassA()
168-
self.assertEqual(1, mock_warning.call_count)
169-
# Deprecation should only warn once, upon first call
170-
TestClassA()
171-
self.assertEqual(1, mock_warning.call_count)
172-
# No warning should be triggered when inherting from a deprecated class
173-
TestClassB()
174-
self.assertEqual(1, mock_warning.call_count)
175-
176144
def testExportSingleConstant(self):
177145
module1 = self._CreateMockModule('module1')
178146

@@ -215,19 +183,9 @@ def testRaisesExceptionIfInvalidSymbolName(self):
215183
with self.assertRaises(tf_export.InvalidSymbolNameError):
216184
tf_export.tf_export('estimator.invalid')
217185

218-
# All symbols exported by Estimator must be under tf.estimator package.
219-
with self.assertRaises(tf_export.InvalidSymbolNameError):
220-
tf_export.estimator_export('invalid')
221-
with self.assertRaises(tf_export.InvalidSymbolNameError):
222-
tf_export.estimator_export('Estimator.invalid')
223-
with self.assertRaises(tf_export.InvalidSymbolNameError):
224-
tf_export.estimator_export('invalid.estimator')
225-
226186
def testRaisesExceptionIfInvalidV1SymbolName(self):
227187
with self.assertRaises(tf_export.InvalidSymbolNameError):
228188
tf_export.tf_export('valid', v1=['estimator.invalid'])
229-
with self.assertRaises(tf_export.InvalidSymbolNameError):
230-
tf_export.estimator_export('estimator.valid', v1=['invalid'])
231189

232190
def testOverridesFunction(self):
233191
_test_function2._tf_api_names = ['abc']

0 commit comments

Comments
 (0)