# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoding utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
import os
# Dependency imports
import numpy as np
import six
from six.moves import input # pylint: disable=redefined-builtin
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import devices
from tensor2tensor.utils import input_fn_builder
import tensorflow as tf
FLAGS = tf.flags.FLAGS
# Number of samples to draw for an image input (in such cases as captioning)
IMAGE_DECODE_LENGTH = 100
def decode_hparams(overrides=""):
  """Build default decoding hyperparameters, optionally overridden.

  Args:
    overrides: comma-separated "name=value" string understood by
      HParams.parse, applied on top of the defaults below.

  Returns:
    tf.contrib.training.HParams holding the decoding configuration.
  """
  defaults = tf.contrib.training.HParams(
      use_last_position_only=False,
      save_images=False,
      problem_idx=0,
      extra_length=50,
      batch_size=0,
      beam_size=4,
      alpha=0.6,
      return_beams=False,
      max_input_size=-1,
      identity_output=False,
      num_samples=-1,
      delimiter="\n")
  return defaults.parse(overrides)
def log_decode_results(inputs,
                       outputs,
                       problem_name,
                       prediction_idx,
                       inputs_vocab,
                       targets_vocab,
                       targets=None,
                       save_images=False,
                       model_dir=None,
                       identity_output=False):
  """Log one prediction's inputs/outputs (and targets when available).

  For image problems with save_images=True the input is written to disk
  instead of being decoded. Returns (decoded_outputs, decoded_targets),
  where decoded_targets is None when no targets were supplied.
  """
  is_image = "image" in problem_name

  # Input side: either dump the raw image or decode token ids to text.
  if is_image and save_images:
    image_path = os.path.join(model_dir, "%s_prediction_%d.jpg" %
                              (problem_name, prediction_idx))
    show_and_save_image(inputs / 255., image_path)
  elif inputs_vocab:
    if identity_output:
      decoded_inputs = " ".join(str(i) for i in inputs.flatten())
    else:
      decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs, is_image))
    tf.logging.info("Inference results INPUT: %s" % decoded_inputs)

  # Output/target side: identity mode emits raw ids, otherwise decode text.
  decoded_targets = None
  if identity_output:
    decoded_outputs = " ".join(str(o) for o in outputs.flatten())
    if targets is not None:
      decoded_targets = " ".join(str(t) for t in targets.flatten())
  else:
    decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
    if targets is not None:
      decoded_targets = targets_vocab.decode(_save_until_eos(targets, is_image))

  tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
  if targets is not None:
    tf.logging.info("Inference results TARGET: %s" % decoded_targets)
  return decoded_outputs, decoded_targets
def decode_from_dataset(estimator,
                        problem_names,
                        decode_hp,
                        decode_to_file=None,
                        dataset_split=None):
  """Run inference over each problem's dataset, logging and optionally saving.

  Args:
    estimator: tf.estimator.Estimator; its `params` attribute is read as the
      model hparams (expects `data_dir` and `problems`).
    problem_names: list of problem-name strings; one prediction pass is run
      per problem.
    decode_hp: decoding HParams (see decode_hparams()).
    decode_to_file: if set, base path for written decodes; a sibling file
      with its last extension replaced by "targets" is written too.
    dataset_split: optional dataset split passed through to the input fn.
  """
  tf.logging.info("Performing local inference from dataset for %s.",
                  str(problem_names))
  hparams = estimator.params
  # We assume that worker_id corresponds to shard number.
  # NOTE(review): decode_hp.shards / decode_hp.shard_id are not defined by
  # decode_hparams() above -- presumably set by the caller; confirm.
  shard = decode_hp.shard_id if decode_hp.shards > 1 else None

  for problem_idx, problem_name in enumerate(problem_names):
    # Build the inference input function
    infer_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.PREDICT,
        hparams=hparams,
        data_dir=hparams.data_dir,
        num_datashards=devices.data_parallelism().n,
        fixed_problem=problem_idx,
        batch_size=decode_hp.batch_size,
        dataset_split=dataset_split,
        shard=shard)

    # Get the predictions as an iterable
    predictions = estimator.predict(infer_input_fn)

    # Prepare output file writers if decode_to_file passed
    if decode_to_file:
      if decode_hp.shards > 1:
        decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
      else:
        decode_filename = decode_to_file
      output_filepath = _decode_filename(decode_filename, problem_name,
                                         decode_hp)
      # The targets file shares the decode path with the final extension
      # swapped for "targets".
      parts = output_filepath.split(".")
      parts[-1] = "targets"
      target_filepath = ".".join(parts)

      output_file = tf.gfile.Open(output_filepath, "w")
      target_file = tf.gfile.Open(target_filepath, "w")

    problem_hparams = hparams.problems[problem_idx]
    # Inputs vocabulary is set to targets if there are no inputs in the problem,
    # e.g., for language models where the inputs are just a prefix of targets.
    has_input = "inputs" in problem_hparams.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = problem_hparams.vocabulary[inputs_vocab_key]
    targets_vocab = problem_hparams.vocabulary["targets"]

    for num_predictions, prediction in enumerate(predictions):
      # Shift to a 1-based count so the num_samples cutoff below compares
      # against the number of predictions already completed.
      num_predictions += 1
      inputs = prediction["inputs"]
      targets = prediction["targets"]
      outputs = prediction["outputs"]

      # Log predictions
      decoded_outputs = []
      if decode_hp.return_beams:
        # Beams are returned concatenated along axis 0; split them apart.
        output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
        for i, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % i)
          decoded = log_decode_results(
              inputs,
              beam,
              problem_name,
              num_predictions,
              inputs_vocab,
              targets_vocab,
              save_images=decode_hp.save_images,
              model_dir=estimator.model_dir,
              identity_output=decode_hp.identity_output,
              targets=targets)
          decoded_outputs.append(decoded)
      else:
        decoded = log_decode_results(
            inputs,
            outputs,
            problem_name,
            num_predictions,
            inputs_vocab,
            targets_vocab,
            save_images=decode_hp.save_images,
            model_dir=estimator.model_dir,
            identity_output=decode_hp.identity_output,
            targets=targets)
        decoded_outputs.append(decoded)

      # Write out predictions if decode_to_file passed
      if decode_to_file:
        for decoded_output, decoded_target in decoded_outputs:
          output_file.write(str(decoded_output) + decode_hp.delimiter)
          target_file.write(str(decoded_target) + decode_hp.delimiter)

      # num_samples < 0 means decode everything.
      if (decode_hp.num_samples >= 0 and
          num_predictions >= decode_hp.num_samples):
        break

    if decode_to_file:
      output_file.close()
      target_file.close()

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable
def decode_from_file(estimator, filename, decode_hp, decode_to_file=None):
  """Compute predictions on entries in filename and write them out.

  Args:
    estimator: tf.estimator.Estimator; its `params` attribute is read as the
      model hparams.
    filename: path to a file of inputs separated by decode_hp.delimiter.
    decode_hp: decoding HParams (see decode_hparams()).
    decode_to_file: optional base path for the output file; when unset the
      output path is derived from `filename` via _decode_filename().
  """
  if not decode_hp.batch_size:
    decode_hp.batch_size = 32
    tf.logging.info(
        "decode_hp.batch_size not specified; default=%d" % decode_hp.batch_size)

  hparams = estimator.params
  problem_id = decode_hp.problem_idx
  # Inputs vocabulary is set to targets if there are no inputs in the problem,
  # e.g., for language models where the inputs are just a prefix of targets.
  has_input = "inputs" in hparams.problems[problem_id].vocabulary
  inputs_vocab_key = "inputs" if has_input else "targets"
  inputs_vocab = hparams.problems[problem_id].vocabulary[inputs_vocab_key]
  targets_vocab = hparams.problems[problem_id].vocabulary["targets"]
  problem_name = FLAGS.problems.split("-")[problem_id]
  tf.logging.info("Performing decoding from a file.")
  # Inputs come back sorted by length for efficient batching; sorted_keys
  # maps original position -> sorted position so order can be restored.
  sorted_inputs, sorted_keys = _get_sorted_inputs(filename, decode_hp.shards,
                                                  decode_hp.delimiter)
  num_decode_batches = (len(sorted_inputs) - 1) // decode_hp.batch_size + 1

  def input_fn():
    input_gen = _decode_batch_input_fn(
        problem_id, num_decode_batches, sorted_inputs, inputs_vocab,
        decode_hp.batch_size, decode_hp.max_input_size)
    gen_fn = make_input_fn_from_generator(input_gen)
    example = gen_fn()
    return _decode_input_tensor_to_features_dict(example, hparams)

  decodes = []
  result_iter = estimator.predict(input_fn)
  for result in result_iter:
    if decode_hp.return_beams:
      beam_decodes = []
      # Beams are concatenated along axis 0; split and log each one.
      output_beams = np.split(result["outputs"], decode_hp.beam_size, axis=0)
      for k, beam in enumerate(output_beams):
        tf.logging.info("BEAM %d:" % k)
        decoded_outputs, _ = log_decode_results(result["inputs"], beam,
                                                problem_name, None,
                                                inputs_vocab, targets_vocab)
        beam_decodes.append(decoded_outputs)
      decodes.append("\t".join(beam_decodes))
    else:
      decoded_outputs, _ = log_decode_results(result["inputs"],
                                              result["outputs"], problem_name,
                                              None, inputs_vocab, targets_vocab)
      decodes.append(decoded_outputs)

  # Reversing the decoded inputs and outputs because they were reversed in
  # _decode_batch_input_fn
  sorted_inputs.reverse()
  decodes.reverse()
  # Dumping inputs and outputs to file filename.decodes in
  # format result\tinput in the same order as original inputs
  if decode_to_file:
    output_filename = decode_to_file
  else:
    output_filename = filename
  if decode_hp.shards > 1:
    base_filename = output_filename + ("%.2d" % decode_hp.shard_id)
  else:
    base_filename = output_filename
  decode_filename = _decode_filename(base_filename, problem_name, decode_hp)
  tf.logging.info("Writing decodes into %s" % decode_filename)
  outfile = tf.gfile.Open(decode_filename, "w")
  for index in range(len(sorted_inputs)):
    outfile.write("%s%s" % (decodes[sorted_keys[index]], decode_hp.delimiter))
  # Bug fix: close the handle so buffered decodes are flushed to disk
  # (previously it was leaked and the file tail could be lost).
  outfile.close()
def _decode_filename(base_filename, problem_name, decode_hp):
  """Return the decode output path: base.model.hp.problem.beamB.alphaA.decodes."""
  pieces = [
      base_filename,
      FLAGS.model,
      FLAGS.hparams_set,
      problem_name,
      "beam%s" % decode_hp.beam_size,
      "alpha%s" % decode_hp.alpha,
      "decodes",
  ]
  return ".".join(pieces)
def make_input_fn_from_generator(gen):
  """Wrap a Python generator as a TF input_fn using tf.py_func.

  The first element is drawn eagerly to learn the nested structure and
  dtypes, then replayed as the first py_func output so no data is lost.
  """
  nest = tf.contrib.framework.nest
  peeked = six.next(gen)
  flat_peeked = nest.flatten(peeked)
  dtypes = [t.dtype for t in flat_peeked]
  # Shapes are left fully unknown (rank only) and set on the output tensors.
  shape_hints = [[None] * len(t.shape) for t in flat_peeked]
  pending = [peeked]

  def py_func():
    # Replay the peeked-at example exactly once, then keep drawing.
    example = pending.pop() if pending else six.next(gen)
    return nest.flatten(example)

  def input_fn():
    flat_example = tf.py_func(py_func, [], dtypes)
    for tensor, hint in zip(flat_example, shape_hints):
      tensor.set_shape(hint)
    return nest.pack_sequence_as(peeked, flat_example)

  return input_fn
def decode_interactively(estimator, decode_hp):
  """Interactive decoding: read from terminal, log decoded outputs."""
  hparams = estimator.params

  def input_fn():
    gen_fn = make_input_fn_from_generator(_interactive_input_fn(hparams))
    return _interactive_input_tensor_to_features_dict(gen_fn(), hparams)

  for result in estimator.predict(input_fn):
    problem_idx = result["problem_choice"]
    is_image = False  # TODO(lukaszkaiser): find out from problem id / class.
    targets_vocab = hparams.problems[problem_idx].vocabulary["targets"]

    if decode_hp.return_beams:
      beams = np.split(result["outputs"], decode_hp.beam_size, axis=0)
      scores = None
      if "scores" in result:
        scores = np.split(result["scores"], decode_hp.beam_size, axis=0)
      for k, beam in enumerate(beams):
        tf.logging.info("BEAM %d:" % k)
        beam_string = targets_vocab.decode(_save_until_eos(beam, is_image))
        if scores is None:
          tf.logging.info(beam_string)
        else:
          tf.logging.info("%s\tScore:%f" % (beam_string, scores[k]))
    elif decode_hp.identity_output:
      tf.logging.info(" ".join(map(str, result["outputs"].flatten())))
    else:
      tf.logging.info(
          targets_vocab.decode(_save_until_eos(result["outputs"], is_image)))
def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
                           vocabulary, batch_size, max_input_size):
  """Yield encoded, zero-padded batches of the sorted inputs."""
  tf.logging.info(" batch %d" % num_decode_batches)
  # First reverse all the input sentences so that if you're going to get OOMs,
  # you'll see it in the first batch
  sorted_inputs.reverse()
  for b in range(num_decode_batches):
    tf.logging.info("Decoding batch %d" % b)
    batch_length = 0
    batch_inputs = []
    for line in sorted_inputs[b * batch_size:(b + 1) * batch_size]:
      ids = vocabulary.encode(line)
      if max_input_size > 0:
        # Subtract 1 for the EOS_ID.
        ids = ids[:max_input_size - 1]
      ids.append(text_encoder.EOS_ID)
      batch_inputs.append(ids)
      batch_length = max(batch_length, len(ids))
    # Right-pad every example with zeros to the longest in the batch.
    padded = [ids + [0] * (batch_length - len(ids)) for ids in batch_inputs]
    yield {
        "inputs": np.array(padded).astype(np.int32),
        "problem_choice": np.array(problem_id).astype(np.int32),
    }
def _interactive_input_fn(hparams):
  """Generator that reads from the terminal and yields "interactive inputs".

  Due to temporary limitations in tf.learn, if we don't want to reload the
  whole graph, then we are stuck encoding all of the input as one fixed-size
  numpy array. We yield int32 arrays with shape [const_array_size]. The
  format is: [num_samples, decode_length, len(input ids), input ids...,
  zero padding].

  Args:
    hparams: model hparams
  Yields:
    numpy arrays
  Raises:
    Exception: when `input_type` is invalid.
  """
  # Decoding state, mutable via the pr=/ns=/dl=/it= commands below.
  num_samples = 1
  decode_length = 100
  input_type = "text"
  problem_id = 0
  p_hparams = hparams.problems[problem_id]
  has_input = "inputs" in p_hparams.input_modality
  vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
  # This should be longer than the longest input.
  const_array_size = 10000
  # Import readline if available for command line editing and recall.
  try:
    import readline  # pylint: disable=g-import-not-at-top,unused-variable
  except ImportError:
    pass
  while True:
    # NOTE(review): the prompt text below looks garbled (the upstream
    # "<input_type>"-style placeholders appear to be missing) -- confirm
    # against the original source before relying on it.
    prompt = ("INTERACTIVE MODE num_samples=%d decode_length=%d \n"
              " it= ('text' or 'image' or 'label', default: "
              "text)\n"
              " pr= (set the problem number, default: 0)\n"
              " in= (set the input problem number)\n"
              " ou= (set the output problem number)\n"
              " ns= (changes number of samples, default: 1)\n"
              " dl= (changes decode length, default: 100)\n"
              " <%s> (decode)\n"
              " q (quit)\n"
              ">" % (num_samples, decode_length, "source_string"
                     if has_input else "target_prefix"))
    input_string = input(prompt)
    if input_string == "q":
      return
    elif input_string[:3] == "pr=":
      # Switch to a different problem; refresh the derived vocabulary.
      problem_id = int(input_string[3:])
      p_hparams = hparams.problems[problem_id]
      has_input = "inputs" in p_hparams.input_modality
      vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
    elif input_string[:3] == "in=":
      # Borrow the input modality/space from another problem.
      problem = int(input_string[3:])
      p_hparams.input_modality = hparams.problems[problem].input_modality
      p_hparams.input_space_id = hparams.problems[problem].input_space_id
    elif input_string[:3] == "ou=":
      # Borrow the target modality/space from another problem.
      problem = int(input_string[3:])
      p_hparams.target_modality = hparams.problems[problem].target_modality
      p_hparams.target_space_id = hparams.problems[problem].target_space_id
    elif input_string[:3] == "ns=":
      num_samples = int(input_string[3:])
    elif input_string[:3] == "dl=":
      decode_length = int(input_string[3:])
    elif input_string[:3] == "it=":
      input_type = input_string[3:]
    else:
      # Anything that is not a command is treated as the input to decode.
      if input_type == "text":
        input_ids = vocabulary.encode(input_string)
        if has_input:
          input_ids.append(text_encoder.EOS_ID)
        # Pack [num_samples, decode_length, len, ids...] and zero-pad to the
        # fixed array size expected by the graph.
        x = [num_samples, decode_length, len(input_ids)] + input_ids
        assert len(x) < const_array_size
        x += [0] * (const_array_size - len(x))
        yield {
            "inputs": np.array(x).astype(np.int32),
            "problem_choice": np.array(problem_id).astype(np.int32)
        }
      elif input_type == "image":
        input_path = input_string
        img = read_image(input_path)
        yield {
            "inputs": img.astype(np.int32),
            "problem_choice": np.array(problem_id).astype(np.int32)
        }
      elif input_type == "label":
        input_ids = [int(input_string)]
        # Labels use the same packed format but are not padded.
        x = [num_samples, decode_length, len(input_ids)] + input_ids
        yield {
            "inputs": np.array(x).astype(np.int32),
            "problem_choice": np.array(problem_id).astype(np.int32)
        }
      else:
        raise Exception("Unsupported input type.")
def read_image(path):
  """Load the image at `path` via matplotlib; raise if matplotlib is absent."""
  try:
    import matplotlib.image as mpimg  # pylint: disable=g-import-not-at-top
  except ImportError as e:
    tf.logging.warning(
        "Reading an image requires matplotlib to be installed: %s", e)
    raise NotImplementedError("Image reading not implemented.")
  return mpimg.imread(path)
def show_and_save_image(img, save_path):
  """Render `img` with matplotlib and save the figure to `save_path`."""
  try:
    import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
  except ImportError as e:
    tf.logging.warning("Showing and saving an image requires matplotlib to be "
                       "installed: %s", e)
    raise NotImplementedError("Image display and save not implemented.")
  plt.imshow(img)
  plt.savefig(save_path)
def _get_sorted_inputs(filename, num_shards=1, delimiter="\n"):
  """Returning inputs sorted according to length.

  Args:
    filename: path to file with inputs, 1 per line.
    num_shards: number of input shards. If > 1, will read from file filename.XX,
      where XX is FLAGS.worker_id.
    delimiter: str, delimits records in the file.

  Returns:
    a sorted list of inputs

  """
  tf.logging.info("Getting sorted inputs")
  # Pick the shard file when sharded, otherwise the file itself.
  if num_shards > 1:
    decode_filename = filename + ("%.2d" % FLAGS.worker_id)
  else:
    decode_filename = filename
  with tf.gfile.Open(decode_filename) as f:
    records = f.read().split(delimiter)
  # The final split piece is the empty tail after the last delimiter.
  inputs = [record.strip() for record in records[:-1]]
  # Sort by whitespace-token count, remembering each line's original slot.
  indexed_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
  sorted_by_len = sorted(indexed_lens, key=lambda pair: pair[1])
  # We'll need the keys to rearrange the inputs back into their original order
  sorted_keys = {}
  sorted_inputs = []
  for new_pos, (orig_pos, _) in enumerate(sorted_by_len):
    sorted_inputs.append(inputs[orig_pos])
    sorted_keys[orig_pos] = new_pos
  return sorted_inputs, sorted_keys
def _save_until_eos(hyp, is_image):
  """Strips everything after the first EOS token, which is normally 1."""
  hyp = hyp.flatten()
  if is_image:
    return hyp
  tokens = list(hyp)
  if text_encoder.EOS_ID not in tokens:
    # No EOS_ID: return the array as-is.
    return hyp
  return hyp[:tokens.index(text_encoder.EOS_ID)]
def _interactive_input_tensor_to_features_dict(feature_map, hparams):
  """Convert the interactive input format (see above) to a dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs`
      containing Tensors.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder.
  """
  inputs = tf.convert_to_tensor(feature_map["inputs"])
  # Rank >= 3 is treated as an image input; rank < 3 as packed token ids
  # in the [num_samples, decode_length, len, ids..., padding] format.
  input_is_image = False if len(inputs.get_shape()) < 3 else True

  def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
    if input_is_image:
      # Resize to the fixed 299x299 input and add a batch dimension.
      x = tf.image.resize_images(x, [299, 299])
      x = tf.reshape(x, [1, 299, 299, -1])
      x = tf.to_int32(x)
    else:
      # Remove the batch dimension.
      num_samples = x[0]
      length = x[2]
      # Keep only the `length` actual input ids (skip the 3-element header).
      x = tf.slice(x, [3], tf.to_int32([length]))
      x = tf.reshape(x, [1, -1, 1, 1])
      # Transform into a batch of size num_samples to get that many random
      # decodes.
      x = tf.tile(x, tf.to_int32([num_samples, 1, 1, 1]))

    p_hparams = hparams.problems[problem_choice]
    return (tf.constant(p_hparams.input_space_id), tf.constant(
        p_hparams.target_space_id), x)

  # Dispatch on the runtime problem_choice value to build the right inputs.
  input_space_id, target_space_id, x = input_fn_builder.cond_on_index(
      input_fn, feature_map["problem_choice"], len(hparams.problems) - 1)

  features = {}
  features["problem_choice"] = tf.convert_to_tensor(
      feature_map["problem_choice"])
  features["input_space_id"] = input_space_id
  features["target_space_id"] = target_space_id
  # inputs[1] is the requested decode_length from the packed header.
  features["decode_length"] = (
      IMAGE_DECODE_LENGTH if input_is_image else inputs[1])
  features["inputs"] = x
  return features
def _decode_input_tensor_to_features_dict(feature_map, hparams):
  """Convert the file-decoding input format to a features dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs`
      containing Tensors.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder.
  """
  inputs = tf.convert_to_tensor(feature_map["inputs"])
  # File-based decoding always feeds token ids, never images.
  input_is_image = False

  def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
    p_hparams = hparams.problems[problem_choice]
    # Add a third empty dimension
    x = tf.expand_dims(x, axis=[2])
    x = tf.to_int32(x)
    return (tf.constant(p_hparams.input_space_id), tf.constant(
        p_hparams.target_space_id), x)

  # Dispatch on the runtime problem_choice value to build the right inputs.
  input_space_id, target_space_id, x = input_fn_builder.cond_on_index(
      input_fn, feature_map["problem_choice"], len(hparams.problems) - 1)

  features = {}
  features["problem_choice"] = feature_map["problem_choice"]
  features["input_space_id"] = input_space_id
  features["target_space_id"] = target_space_id
  # Allow generation up to 50 tokens beyond the input length.
  features["decode_length"] = (
      IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
  features["inputs"] = x
  return features