
Commit a228685: Upgrade tensorflow to 2 (#73)
Parent: 47f441d

10 files changed: +64/-61 lines

Dockerfile (+1/-1)

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-FROM quay.io/codait/max-base:v1.4.0
+FROM quay.io/codait/max-base:v1.5.0
 
 ARG model_bucket=https://max-cdn.cdn.appdomain.cloud/max-image-caption-generator/1.0.0
 ARG model_file=assets.tar.gz
```

api/predict.py (+1/-1)

```diff
@@ -18,7 +18,7 @@
 from core.model import ModelWrapper
 
 from flask import abort
-from flask_restplus import fields
+from flask_restx import fields
 from werkzeug.datastructures import FileStorage
```
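
flask-restplus is no longer maintained; flask-restx is its community fork and keeps the same API, so only the import changes. A minimal sketch of how the imported `fields` module is typically used with flask_restx (the app, model name, and field names below are hypothetical, not taken from this repo):

```python
# Hypothetical flask_restx usage; `fields` is a drop-in for flask_restplus.
from flask import Flask
from flask_restx import Api, fields

app = Flask(__name__)
api = Api(app)

# A response schema built from the same `fields` classes the old import exposed.
caption = api.model("Caption", {
    "caption": fields.String(description="Generated image caption"),
    "probability": fields.Float(description="Probability of the caption"),
})
```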

core/inference_utils/inference_wrapper_base.py (+15/-15)

```diff
@@ -68,7 +68,7 @@ def build_model(self, model_config):
     Returns:
       model: The model object.
     """
-    tf.logging.fatal("Please implement build_model in subclass")
+    tf.compat.v1.logging.fatal("Please implement build_model in subclass")
 
   def _create_restore_fn(self, checkpoint_path, saver):
     """Creates a function that restores a model from checkpoint.
@@ -86,16 +86,16 @@ def _create_restore_fn(self, checkpoint_path, saver):
       ValueError: If checkpoint_path does not refer to a checkpoint file or a
         directory containing a checkpoint file.
     """
-    if tf.gfile.IsDirectory(checkpoint_path):
+    if tf.compat.v1.gfile.IsDirectory(checkpoint_path):
       checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
       if not checkpoint_path:
         raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
 
     def _restore_fn(sess):
-      tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
+      tf.compat.v1.logging.info("Loading model from checkpoint: %s", checkpoint_path)
       saver.restore(sess, checkpoint_path)
-      tf.logging.info("Successfully loaded checkpoint: %s",
-                      os.path.basename(checkpoint_path))
+      tf.compat.v1.logging.info("Successfully loaded checkpoint: %s",
+                                os.path.basename(checkpoint_path))
 
     return _restore_fn
 
@@ -111,9 +111,9 @@ def build_graph_from_config(self, model_config, checkpoint_path):
       restore_fn: A function such that restore_fn(sess) loads model variables
         from the checkpoint file.
     """
-    tf.logging.info("Building model.")
+    tf.compat.v1.logging.info("Building model.")
     self.build_model(model_config)
-    saver = tf.train.Saver()
+    saver = tf.compat.v1.train.Saver()
 
     return self._create_restore_fn(checkpoint_path, saver)
 
@@ -132,18 +132,18 @@ def build_graph_from_proto(self, graph_def_file, saver_def_file,
         from the checkpoint file.
     """
     # Load the Graph.
-    tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
+    tf.compat.v1.logging.info("Loading GraphDef from file: %s", graph_def_file)
     graph_def = tf.GraphDef()
-    with tf.gfile.FastGFile(graph_def_file, "rb") as f:
+    with tf.compat.v1.gfile.FastGFile(graph_def_file, "rb") as f:
       graph_def.ParseFromString(f.read())
     tf.import_graph_def(graph_def, name="")
 
     # Load the Saver.
-    tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
-    saver_def = tf.train.SaverDef()
-    with tf.gfile.FastGFile(saver_def_file, "rb") as f:
+    tf.compat.v1.logging.info("Loading SaverDef from file: %s", saver_def_file)
+    saver_def = tf.compat.v1.train.SaverDef()
+    with tf.compat.v1.gfile.FastGFile(saver_def_file, "rb") as f:
       saver_def.ParseFromString(f.read())
-    saver = tf.train.Saver(saver_def=saver_def)
+    saver = tf.compat.v1.train.Saver(saver_def=saver_def)
 
     return self._create_restore_fn(checkpoint_path, saver)
 
@@ -159,7 +159,7 @@ def feed_image(self, sess, encoded_image):
     Returns:
       state: A numpy array of shape [1, state_size].
     """
-    tf.logging.fatal("Please implement feed_image in subclass")
+    tf.compat.v1.logging.fatal("Please implement feed_image in subclass")
 
   def inference_step(self, sess, input_feed, state_feed):
     """Runs one step of inference.
@@ -176,6 +176,6 @@ def inference_step(self, sess, input_feed, state_feed):
         current inference step (e.g. serialized numpy array containing
         activations from a particular model layer.).
     """
-    tf.logging.fatal("Please implement inference_step in subclass")
+    tf.compat.v1.logging.fatal("Please implement inference_step in subclass")
 
 # pylint: enable=unused-argument
```
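
Every `tf.logging`, `tf.gfile`, and `tf.train.Saver` call above moves under `tf.compat.v1`, the namespace that preserves the TF1 API surface inside TF2. A standalone sketch of the renamed entry points (not project code):

```python
import tensorflow as tf

# TF1 name              -> compat shim used in this commit
# tf.logging.info       -> tf.compat.v1.logging.info
# tf.gfile.IsDirectory  -> tf.compat.v1.gfile.IsDirectory
# tf.train.Saver        -> tf.compat.v1.train.Saver

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
tf.compat.v1.logging.info("Running under TensorFlow %s", tf.__version__)
print(tf.compat.v1.gfile.IsDirectory("/tmp"))  # gfile shim works on local paths
```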

core/inference_utils/vocabulary.py (+5/-5)

```diff
@@ -39,11 +39,11 @@ def __init__(self,
       end_word: Special word denoting sentence end.
       unk_word: Special word denoting unknown words.
     """
-    if not tf.gfile.Exists(vocab_file):
-      tf.logging.fatal("Vocab file %s not found.", vocab_file)
-    tf.logging.info("Initializing vocabulary from file: %s", vocab_file)
+    if not tf.compat.v1.gfile.Exists(vocab_file):
+      tf.compat.v1.logging.fatal("Vocab file %s not found.", vocab_file)
+    tf.compat.v1.logging.info("Initializing vocabulary from file: %s", vocab_file)
 
-    with tf.gfile.GFile(vocab_file, mode="r") as f:
+    with tf.compat.v1.gfile.GFile(vocab_file, mode="r") as f:
       reverse_vocab = list(f.readlines())
     reverse_vocab = [line.split()[0] for line in reverse_vocab]
     if start_word not in reverse_vocab:
@@ -54,7 +54,7 @@ def __init__(self,
       reverse_vocab.append(unk_word)
     vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
 
-    tf.logging.info("Created vocabulary with %d words" % len(vocab))
+    tf.compat.v1.logging.info("Created vocabulary with %d words" % len(vocab))
 
     self.vocab = vocab  # vocab[word] = id
     self.reverse_vocab = reverse_vocab  # reverse_vocab[id] = word
```
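
The same shim pattern applies here; TF2 also ships a native equivalent, `tf.io.gfile`, which could replace the compat calls later. A sketch under the assumption that the vocab file is a plain word-count list (the path below is hypothetical):

```python
import tensorflow as tf

vocab_file = "assets/word_counts.txt"  # hypothetical path, for illustration only

# tf.io.gfile is the TF2-native counterpart of tf.compat.v1.gfile:
if tf.io.gfile.exists(vocab_file):
    with tf.io.gfile.GFile(vocab_file, mode="r") as f:
        reverse_vocab = [line.split()[0] for line in f.readlines()]
```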

core/model.py (+3/-1)

```diff
@@ -30,6 +30,8 @@
 
 logger = logging.getLogger()
 
+tf.compat.v1.disable_eager_execution()
+
 
 class ModelWrapper(MAXModelWrapper):
 
@@ -42,7 +44,7 @@ def __init__(self, path=DEFAULT_MODEL_PATH):
                                                    path)
       g.finalize()
       self.model = model
-      sess = tf.Session(graph=g)
+      sess = tf.compat.v1.Session(graph=g)
       # Load the model from checkpoint.
       restore_fn(sess)
       self.sess = sess
```
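
The serialized model is a frozen TF1 graph, so the wrapper opts out of TF2's eager execution up front and drives the graph through a v1 `Session`. A self-contained sketch of that pattern (the toy graph below stands in for the real captioning graph):

```python
import tensorflow as tf

# Must run before any graph is built, as core/model.py does at import time.
tf.compat.v1.disable_eager_execution()

g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name="x")
    y = tf.identity(x * 2.0, name="y")
g.finalize()  # freeze the graph so later code cannot mutate it

sess = tf.compat.v1.Session(graph=g)
print(sess.run(y, feed_dict={x: 21.0}))  # 42.0
```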

core/ops/image_embedding.py (+6/-7)

```diff
@@ -21,9 +21,8 @@
 
 import tensorflow as tf
 
-from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
-
-slim = tf.contrib.slim
+import tf_slim as slim
+from tf_slim.nets.inception_v3 import inception_v3_base
 
 
 def inception_v3(images,
@@ -83,19 +82,19 @@ def inception_v3(images,
   else:
     weights_regularizer = None
 
-  with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
+  with tf.compat.v1.variable_scope(scope, "InceptionV3", [images]) as scope:
     with slim.arg_scope(
         [slim.conv2d, slim.fully_connected],
        weights_regularizer=weights_regularizer,
        trainable=trainable):
      with slim.arg_scope(
          [slim.conv2d],
-          weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
+          weights_initializer=tf.compat.v1.truncated_normal_initializer(stddev=stddev),
          activation_fn=tf.nn.relu,
          normalizer_fn=slim.batch_norm,
          normalizer_params=batch_norm_params):
        net, end_points = inception_v3_base(images, scope=scope)
-        with tf.variable_scope("logits"):
+        with tf.compat.v1.variable_scope("logits"):
          shape = net.get_shape()
          net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
          net = slim.dropout(
@@ -108,6 +107,6 @@ def inception_v3(images,
   # Add summaries.
   if add_summaries:
     for v in end_points.values():
-      tf.contrib.layers.summaries.summarize_activation(v)
+      slim.summarize_activation(v)
 
   return net
```
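
`tf.contrib` was removed in TF2; the standalone `tf_slim` package carries the slim layers, `arg_scope`, the summaries helpers, and the Inception v3 definition forward. A quick import check under the pinned `tf_slim==1.1.0` (assumption: the file needs nothing beyond the names used above):

```python
import tf_slim as slim
from tf_slim.nets.inception_v3 import inception_v3_base

# The names this file relies on, re-exported at the tf_slim top level:
for name in ("conv2d", "fully_connected", "arg_scope", "batch_norm",
             "avg_pool2d", "dropout", "summarize_activation"):
    assert hasattr(slim, name), name
```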

core/ops/image_processing.py (+7/-7)

```diff
@@ -35,12 +35,12 @@ def distort_image(image, thread_id):
     [0, 1].
   """
   # Randomly flip horizontally.
-  with tf.name_scope("flip_horizontal", values=[image]):
+  with tf.name_scope("flip_horizontal"):
     image = tf.image.random_flip_left_right(image)
 
   # Randomly distort the colors based on thread id.
   color_ordering = thread_id % 2
-  with tf.name_scope("distort_color", values=[image]):
+  with tf.name_scope("distort_color"):
     if color_ordering == 0:
       image = tf.image.random_brightness(image, max_delta=32. / 255.)
       image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
@@ -95,7 +95,7 @@ def image_summary(name, image):
     tf.summary.image(name, tf.expand_dims(image, 0))
 
   # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
-  with tf.name_scope("decode", values=[encoded_image]):
+  with tf.name_scope("decode"):
     if image_format == "jpeg":
       image = tf.image.decode_jpeg(encoded_image, channels=3)
     elif image_format == "png":
@@ -110,16 +110,16 @@ def image_summary(name, image):
     raise ValueError("Invalid resize parameters height: '{0}' width: '{1}'".format(resize_height, resize_width))
 
   if resize_height:
-    image = tf.image.resize_images(image,
-                                   size=[resize_height, resize_width],
-                                   method=tf.image.ResizeMethod.BILINEAR)
+    image = tf.image.resize(image,
+                            size=[resize_height, resize_width],
+                            method=tf.image.ResizeMethod.BILINEAR)
 
   # Crop to final dimensions.
   if is_training:
     image = tf.random_crop(image, [height, width, 3])
   else:
     # Central crop, assuming resize_height > height, resize_width > width.
-    image = tf.image.resize_image_with_crop_or_pad(image, height, width)
+    image = tf.image.resize_with_crop_or_pad(image, height, width)
 
   image_summary("resized_image", image)
```
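
Three TF2 API changes drive this file: `tf.name_scope` lost its `values=` argument, `tf.image.resize_images` became `tf.image.resize`, and `tf.image.resize_image_with_crop_or_pad` dropped the `image_` infix. A standalone sketch exercising the renamed ops on a dummy tensor:

```python
import tensorflow as tf

image = tf.zeros([300, 400, 3], dtype=tf.float32)  # dummy image, not project data

with tf.name_scope("resize"):  # TF2 signature: name only, no values= argument
    image = tf.image.resize(image,
                            size=[346, 346],
                            method=tf.image.ResizeMethod.BILINEAR)

image = tf.image.resize_with_crop_or_pad(image, 299, 299)  # central crop
print(image.shape)  # (299, 299, 3)
```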

core/ops/inputs.py (+1/-1)

```diff
@@ -85,7 +85,7 @@ def prefetch_input_data(reader,
   """
   data_files = []
   for pattern in file_pattern.split(","):
-    data_files.extend(tf.gfile.Glob(pattern))
+    data_files.extend(tf.compat.v1.gfile.Glob(pattern))
   if not data_files:
     tf.logging.fatal("Found no input files matching %s", file_pattern)
   else:
```
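
As elsewhere, the file-system call moves under the compat shim; `tf.io.gfile.glob` is the TF2-native spelling of the same operation. A sketch (the patterns below are hypothetical):

```python
import tensorflow as tf

file_pattern = "data/train-*.tfrecord,data/val-*.tfrecord"  # hypothetical
data_files = []
for pattern in file_pattern.split(","):
    # tf.compat.v1.gfile.Glob and tf.io.gfile.glob return the same matches here.
    data_files.extend(tf.io.gfile.glob(pattern))
```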

core/show_and_tell_model.py (+21/-20)

```diff
@@ -24,6 +24,7 @@
 from __future__ import print_function
 
 import tensorflow as tf
+import tf_slim as slim
 
 from core.ops import image_embedding
 from core.ops import image_processing
@@ -52,7 +53,7 @@ def __init__(self, config, mode, train_inception=False):
     self.train_inception = train_inception
 
     # Reader for the input data.
-    self.reader = tf.TFRecordReader()
+    self.reader = tf.compat.v1.TFRecordReader()
 
     # To match the "Show and Tell" paper we initialize all variables with a
     # random uniform initializer.
@@ -129,10 +130,10 @@ def build_inputs(self):
     """
     if self.mode == "inference":
       # In inference mode, images and inputs are fed via placeholders.
-      image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
-      input_feed = tf.placeholder(dtype=tf.int64,
-                                  shape=[None],  # batch_size
-                                  name="input_feed")
+      image_feed = tf.compat.v1.placeholder(dtype=tf.string, shape=[], name="image_feed")
+      input_feed = tf.compat.v1.placeholder(dtype=tf.int64,
+                                            shape=[None],  # batch_size
+                                            name="input_feed")
 
       # Process image and insert batch dimensions.
       images = tf.expand_dims(self.process_image(image_feed), 0)
@@ -192,12 +193,12 @@ def build_image_embeddings(self):
         self.images,
        trainable=self.train_inception,
        is_training=self.is_training())
-    self.inception_variables = tf.get_collection(
-        tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")
+    self.inception_variables = tf.compat.v1.get_collection(
+        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")
 
     # Map inception output into embedding space.
-    with tf.variable_scope("image_embedding") as scope:
-      image_embeddings = tf.contrib.layers.fully_connected(
+    with tf.compat.v1.variable_scope("image_embedding") as scope:
+      image_embeddings = slim.layers.fully_connected(
          inputs=inception_output,
          num_outputs=self.config.embedding_size,
          activation_fn=None,
@@ -219,8 +220,8 @@ def build_seq_embeddings(self):
     Outputs:
       self.seq_embeddings
     """
-    with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
-      embedding_map = tf.get_variable(
+    with tf.compat.v1.variable_scope("seq_embedding"), tf.device("/cpu:0"):
+      embedding_map = tf.compat.v1.get_variable(
          name="map",
          shape=[self.config.vocab_size, self.config.embedding_size],
          initializer=self.initializer)
@@ -245,15 +246,15 @@ def build_model(self):
     # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
     # modified LSTM in the "Show and Tell" paper has no biases and outputs
     # new_c * sigmoid(o).
-    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
+    lstm_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(
        num_units=self.config.num_lstm_units, state_is_tuple=True)
     if self.mode == "train":
       lstm_cell = tf.contrib.rnn.DropoutWrapper(
          lstm_cell,
          input_keep_prob=self.config.lstm_dropout_keep_prob,
          output_keep_prob=self.config.lstm_dropout_keep_prob)
 
-    with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
+    with tf.compat.v1.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
       # Feed the image embeddings to set the initial LSTM state.
       zero_state = lstm_cell.zero_state(
          batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
@@ -268,9 +269,9 @@ def build_model(self):
        tf.concat(axis=1, values=initial_state, name="initial_state")
 
        # Placeholder for feeding a batch of concatenated states.
-        state_feed = tf.placeholder(dtype=tf.float32,
-                                    shape=[None, sum(lstm_cell.state_size)],
-                                    name="state_feed")
+        state_feed = tf.compat.v1.placeholder(dtype=tf.float32,
+                                              shape=[None, sum(lstm_cell.state_size)],
+                                              name="state_feed")
        state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
 
        # Run a single LSTM step.
@@ -293,8 +294,8 @@ def build_model(self):
     # Stack batches vertically.
     lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
 
-    with tf.variable_scope("logits") as logits_scope:
-      logits = tf.contrib.layers.fully_connected(
+    with tf.compat.v1.variable_scope("logits") as logits_scope:
+      logits = slim.layers.fully_connected(
          inputs=lstm_outputs,
          num_outputs=self.config.vocab_size,
          activation_fn=None,
@@ -341,11 +342,11 @@ def restore_fn(sess):
 
   def setup_global_step(self):
     """Sets up the global step Tensor."""
-    global_step = tf.Variable(
+    global_step = tf.compat.v1.Variable(
        initial_value=0,
        name="global_step",
        trainable=False,
-        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])
+        collections=[tf.compat.v1.GraphKeys.GLOBAL_STEP, tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])
 
     self.global_step = global_step
```
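
The graph-building code keeps the TF1 RNN API alive through `tf.compat.v1.nn.rnn_cell`, and placeholders still feed the concatenated LSTM state. A minimal, self-contained sketch of that inference-step wiring (unit sizes are illustrative, not the repo's config):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # graph mode, as in core/model.py

num_units = 512  # illustrative; the real value comes from the model config
lstm_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units=num_units,
                                                   state_is_tuple=True)

# Concatenated (c, h) state fed through a single placeholder, then split back.
state_feed = tf.compat.v1.placeholder(
    dtype=tf.float32, shape=[None, sum(lstm_cell.state_size)], name="state_feed")
state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)

inputs = tf.compat.v1.placeholder(tf.float32, [None, num_units], name="input_feed")
outputs, new_state = lstm_cell(inputs, state=state_tuple)
```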

requirements.txt (+4/-3)

```diff
@@ -1,3 +1,4 @@
-tensorflow==1.15.4
-Pillow==8.2.0
-numpy==1.17.4
+tensorflow==2.6.0
+tf_slim==1.1.0
+Pillow==8.3.1
+# NumPy version is dictated by tensorflow
```
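
NumPy is intentionally left unpinned: TF 2.6 declares its own compatible range (around numpy~=1.19.2), so pinning it separately invites resolver conflicts. A quick post-install sanity check (package names per PyPI; assumes Python 3.8+ for importlib.metadata):

```python
from importlib.metadata import version  # Python 3.8+

for pkg in ("tensorflow", "tf-slim", "Pillow", "numpy"):
    print(pkg, version(pkg))
# Expected: tensorflow 2.6.0, tf-slim 1.1.0, Pillow 8.3.1,
# and whatever NumPy tensorflow's constraint resolved to.
```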
