# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For training NMT models."""
from __future__ import print_function
import collections
import math
import os
import random
import time
import tensorflow as tf
from . import attention_model
from . import gnmt_model
from . import inference
from . import model as nmt_model
from . import model_helper
from .utils import iterator_utils
from .utils import misc_utils as utils
from .utils import nmt_utils
from .utils import vocab_utils
utils.check_tensorflow_version()
__all__ = [
"create_train_model", "create_eval_model", "run_sample_decode",
"run_internal_eval", "run_external_eval", "run_full_eval", "train"
]
class TrainModel(
collections.namedtuple("TrainModel", ("graph", "model", "iterator",
"skip_count_placeholder"))):
pass
def create_train_model(
model_creator, hparams, scope=None, single_cell_fn=None,
model_device_fn=None):
"""Create train graph, model, and iterator."""
src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
src_vocab_file = hparams.src_vocab_file
tgt_vocab_file = hparams.tgt_vocab_file
graph = tf.Graph()
with graph.as_default():
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
src_vocab_file, tgt_vocab_file, hparams.share_vocab)
src_dataset = tf.contrib.data.TextLineDataset(src_file)
tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
iterator = iterator_utils.get_iterator(
src_dataset,
tgt_dataset,
src_vocab_table,
tgt_vocab_table,
batch_size=hparams.batch_size,
sos=hparams.sos,
eos=hparams.eos,
source_reverse=hparams.source_reverse,
random_seed=hparams.random_seed,
num_buckets=hparams.num_buckets,
src_max_len=hparams.src_max_len,
tgt_max_len=hparams.tgt_max_len,
skip_count=skip_count_placeholder)
# Note: One can set model_device_fn to
# `tf.train.replica_device_setter(ps_tasks)` for distributed training.
with tf.device(model_device_fn):
model = model_creator(
hparams,
iterator=iterator,
mode=tf.contrib.learn.ModeKeys.TRAIN,
source_vocab_table=src_vocab_table,
target_vocab_table=tgt_vocab_table,
scope=scope,
single_cell_fn=single_cell_fn)
return TrainModel(
graph=graph,
model=model,
iterator=iterator,
skip_count_placeholder=skip_count_placeholder)
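# Example usage of create_train_model (a minimal sketch; assumes `hparams` was
# built by this package's standard flag parsing, so the train-file and vocab
# fields referenced above are populated):
#
#   train_model = create_train_model(nmt_model.Model, hparams)
#   with tf.Session(graph=train_model.graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(train_model.iterator.initializer,
#              feed_dict={train_model.skip_count_placeholder: 0})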
class EvalModel(
collections.namedtuple("EvalModel",
("graph", "model", "src_file_placeholder",
"tgt_file_placeholder", "iterator"))):
pass
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
"""Create train graph, model, src/tgt file holders, and iterator."""
src_vocab_file = hparams.src_vocab_file
tgt_vocab_file = hparams.tgt_vocab_file
graph = tf.Graph()
with graph.as_default():
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
src_vocab_file, tgt_vocab_file, hparams.share_vocab)
src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
iterator = iterator_utils.get_iterator(
src_dataset,
tgt_dataset,
src_vocab_table,
tgt_vocab_table,
hparams.batch_size,
sos=hparams.sos,
eos=hparams.eos,
source_reverse=hparams.source_reverse,
random_seed=hparams.random_seed,
num_buckets=hparams.num_buckets,
src_max_len=hparams.src_max_len_infer,
tgt_max_len=hparams.tgt_max_len_infer)
model = model_creator(
hparams,
iterator=iterator,
mode=tf.contrib.learn.ModeKeys.EVAL,
source_vocab_table=src_vocab_table,
target_vocab_table=tgt_vocab_table,
scope=scope,
single_cell_fn=single_cell_fn)
return EvalModel(
graph=graph,
model=model,
src_file_placeholder=src_file_placeholder,
tgt_file_placeholder=tgt_file_placeholder,
iterator=iterator)
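# Example usage of create_eval_model (a sketch; the file names below are
# hypothetical and would normally be derived from hparams.dev_prefix and
# hparams.src/tgt, as in run_internal_eval):
#
#   eval_model = create_eval_model(nmt_model.Model, hparams)
#   with tf.Session(graph=eval_model.graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(eval_model.iterator.initializer,
#              feed_dict={eval_model.src_file_placeholder: "dev.src",
#                         eval_model.tgt_file_placeholder: "dev.tgt"})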
def run_sample_decode(infer_model, infer_sess, model_dir, hparams,
summary_writer, src_data, tgt_data):
"""Sample decode a random sentence from src_data."""
with infer_model.graph.as_default():
loaded_infer_model, global_step = model_helper.create_or_load_model(
infer_model.model, model_dir, infer_sess, "infer")
_sample_decode(loaded_infer_model, global_step, infer_sess, hparams,
infer_model.iterator, src_data, tgt_data,
infer_model.src_placeholder,
infer_model.batch_size_placeholder, summary_writer)
def run_internal_eval(
eval_model, eval_sess, model_dir, hparams, summary_writer):
"""Compute internal evaluation (perplexity) for both dev / test."""
with eval_model.graph.as_default():
loaded_eval_model, global_step = model_helper.create_or_load_model(
eval_model.model, model_dir, eval_sess, "eval")
dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
dev_eval_iterator_feed_dict = {
eval_model.src_file_placeholder: dev_src_file,
eval_model.tgt_file_placeholder: dev_tgt_file
}
dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
eval_model.iterator, dev_eval_iterator_feed_dict,
summary_writer, "dev")
test_ppl = None
if hparams.test_prefix:
test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
test_eval_iterator_feed_dict = {
eval_model.src_file_placeholder: test_src_file,
eval_model.tgt_file_placeholder: test_tgt_file
}
test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
eval_model.iterator, test_eval_iterator_feed_dict,
summary_writer, "test")
return dev_ppl, test_ppl
def run_external_eval(infer_model, infer_sess, model_dir, hparams,
summary_writer, save_best_dev=True):
"""Compute external evaluation (bleu, rouge, etc.) for both dev / test."""
with infer_model.graph.as_default():
loaded_infer_model, global_step = model_helper.create_or_load_model(
infer_model.model, model_dir, infer_sess, "infer")
dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
dev_infer_iterator_feed_dict = {
infer_model.src_placeholder: inference.load_data(dev_src_file),
infer_model.batch_size_placeholder: hparams.infer_batch_size,
}
dev_scores = _external_eval(
loaded_infer_model,
global_step,
infer_sess,
hparams,
infer_model.iterator,
dev_infer_iterator_feed_dict,
dev_tgt_file,
"dev",
summary_writer,
save_on_best=save_best_dev)
test_scores = None
if hparams.test_prefix:
test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
test_infer_iterator_feed_dict = {
infer_model.src_placeholder: inference.load_data(test_src_file),
infer_model.batch_size_placeholder: hparams.infer_batch_size,
}
test_scores = _external_eval(
loaded_infer_model,
global_step,
infer_sess,
hparams,
infer_model.iterator,
test_infer_iterator_feed_dict,
test_tgt_file,
"test",
summary_writer,
save_on_best=False)
return dev_scores, test_scores, global_step
def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
hparams, summary_writer, sample_src_data, sample_tgt_data):
"""Wrapper for running sample_decode, internal_eval and external_eval."""
run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
sample_src_data, sample_tgt_data)
dev_ppl, test_ppl = run_internal_eval(
eval_model, eval_sess, model_dir, hparams, summary_writer)
dev_scores, test_scores, global_step = run_external_eval(
infer_model, infer_sess, model_dir, hparams, summary_writer)
result_summary = _format_results("dev", dev_ppl, dev_scores, hparams.metrics)
if hparams.test_prefix:
result_summary += ", " + _format_results("test", test_ppl, test_scores,
hparams.metrics)
return result_summary, global_step, dev_scores, test_scores, dev_ppl, test_ppl
def train(hparams, scope=None, target_session="", single_cell_fn=None):
"""Train a translation model."""
log_device_placement = hparams.log_device_placement
out_dir = hparams.out_dir
num_train_steps = hparams.num_train_steps
steps_per_stats = hparams.steps_per_stats
steps_per_external_eval = hparams.steps_per_external_eval
steps_per_eval = 10 * steps_per_stats
if not steps_per_external_eval:
steps_per_external_eval = 5 * steps_per_eval
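  # With these defaults, internal eval runs every 10 stats intervals and,
  # unless steps_per_external_eval is set, external eval runs every 5 internal
  # evals (i.e. every 50 * steps_per_stats steps).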
if not hparams.attention:
model_creator = nmt_model.Model
elif hparams.attention_architecture == "standard":
model_creator = attention_model.AttentionModel
elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
model_creator = gnmt_model.GNMTModel
  else:
    raise ValueError("Unknown attention architecture %s" %
                     hparams.attention_architecture)
train_model = create_train_model(model_creator, hparams, scope,
single_cell_fn)
eval_model = create_eval_model(model_creator, hparams, scope,
single_cell_fn)
infer_model = inference.create_infer_model(model_creator, hparams,
scope, single_cell_fn)
# Preload data for sample decoding.
dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
sample_src_data = inference.load_data(dev_src_file)
sample_tgt_data = inference.load_data(dev_tgt_file)
summary_name = "train_log"
model_dir = hparams.out_dir
# Log and output files
log_file = os.path.join(out_dir, "log_%d" % time.time())
log_f = tf.gfile.GFile(log_file, mode="a")
utils.print_out("# log_file=%s" % log_file, log_f)
avg_step_time = 0.0
# TensorFlow model
config_proto = utils.get_config_proto(
log_device_placement=log_device_placement)
train_sess = tf.Session(
target=target_session, config=config_proto, graph=train_model.graph)
eval_sess = tf.Session(
target=target_session, config=config_proto, graph=eval_model.graph)
infer_sess = tf.Session(
target=target_session, config=config_proto, graph=infer_model.graph)
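  # The train, eval, and infer models each live in their own graph and
  # session; they share parameters only through checkpoints written to
  # model_dir.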
with train_model.graph.as_default():
loaded_train_model, global_step = model_helper.create_or_load_model(
train_model.model, model_dir, train_sess, "train")
# Summary writer
summary_writer = tf.summary.FileWriter(
os.path.join(out_dir, summary_name), train_model.graph)
# First evaluation
run_full_eval(
model_dir, infer_model, infer_sess,
eval_model, eval_sess, hparams,
summary_writer, sample_src_data,
sample_tgt_data)
last_stats_step = global_step
last_eval_step = global_step
last_external_eval_step = global_step
  # Statistics for the training loop, accumulated over each stats window.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  # Word and sample totals, used for the words/sec and samples/sec statistics.
  checkpoint_total_count, checkpoint_total_samples = 0.0, 0.0
  speed, train_ppl = 0.0, 0.0
start_train_time = time.time()
utils.print_out(
"# Start step %d, lr %g, %s" %
(global_step, loaded_train_model.learning_rate.eval(session=train_sess),
time.ctime()),
log_f)
  # Initialize the train iterator; when resuming mid-epoch, skip the examples
  # already consumed so training continues where the checkpoint left off.
skip_count = hparams.batch_size * hparams.epoch_step
utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
train_sess.run(
train_model.iterator.initializer,
feed_dict={train_model.skip_count_placeholder: skip_count})
  # numba.cuda wraps cudaProfilerStart/Stop; import once here rather than on
  # every loop iteration. Steps 501-510 are profiled so an external tool such
  # as nvprof can capture a steady-state window of training.
  import numba.cuda as cuda
  while global_step < num_train_steps:
    if global_step == 501:
      cuda.profile_start()
    if global_step == 511:
      cuda.profile_stop()
### Run a step ###
start_time = time.time()
try:
step_result = loaded_train_model.train(train_sess)
(_, step_loss, step_predict_count, step_summary, global_step,
step_word_count, batch_size) = step_result
hparams.epoch_step += 1
except tf.errors.OutOfRangeError:
# Finished going through the training dataset. Go to next epoch.
hparams.epoch_step = 0
utils.print_out(
"# Finished an epoch, step %d. Perform external evaluation" %
global_step)
run_sample_decode(infer_model, infer_sess,
model_dir, hparams, summary_writer, sample_src_data,
sample_tgt_data)
dev_scores, test_scores, _ = run_external_eval(
infer_model, infer_sess, model_dir,
hparams, summary_writer)
train_sess.run(
train_model.iterator.initializer,
feed_dict={train_model.skip_count_placeholder: 0})
continue
# Write step summary.
summary_writer.add_summary(step_summary, global_step)
    # Update statistics.
    step_time += (time.time() - start_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)
    # One sample per sentence pair; used for the samples/sec statistic.
    checkpoint_total_samples += float(batch_size)
# Once in a while, we print statistics.
if global_step - last_stats_step >= steps_per_stats:
last_stats_step = global_step
      # Print statistics for the previous steps_per_stats steps.
avg_step_time = step_time / steps_per_stats
train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      # Words per second, in thousands (logged as "wps %.2fK").
      speed = checkpoint_total_count / (1000 * step_time)
      # Samples (sentence pairs) per second over the stats window.
      speed_samples_per_sec = checkpoint_total_samples / step_time
utils.print_out(
" global step %d lr %g "
"step-time %.2fs wps %.2fK sps %5.2f ppl %.2f %s" %
(global_step,
loaded_train_model.learning_rate.eval(session=train_sess),
avg_step_time, speed, speed_samples_per_sec,
train_ppl, _get_best_results(hparams)), log_f)
      # Stop if training has diverged.
      if math.isnan(train_ppl):
        break
      # Reset timer and counters for the next stats window. Note that
      # checkpoint_total_samples must be reset too, or samples/sec would be
      # overstated in later windows.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count, checkpoint_total_samples = 0.0, 0.0
if global_step - last_eval_step >= steps_per_eval:
last_eval_step = global_step
utils.print_out("# Save eval, global step %d" % global_step)
utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
# Save checkpoint
loaded_train_model.saver.save(
train_sess,
os.path.join(out_dir, "translate.ckpt"),
global_step=global_step)
# Evaluate on dev/test
run_sample_decode(infer_model, infer_sess,
model_dir, hparams, summary_writer, sample_src_data,
sample_tgt_data)
dev_ppl, test_ppl = run_internal_eval(
eval_model, eval_sess, model_dir, hparams, summary_writer)
if global_step - last_external_eval_step >= steps_per_external_eval:
last_external_eval_step = global_step
# Save checkpoint
loaded_train_model.saver.save(
train_sess,
os.path.join(out_dir, "translate.ckpt"),
global_step=global_step)
run_sample_decode(infer_model, infer_sess,
model_dir, hparams, summary_writer, sample_src_data,
sample_tgt_data)
dev_scores, test_scores, _ = run_external_eval(
infer_model, infer_sess, model_dir,
hparams, summary_writer)
# Done training
loaded_train_model.saver.save(
train_sess,
os.path.join(out_dir, "translate.ckpt"),
global_step=global_step)
result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
model_dir, infer_model, infer_sess,
eval_model, eval_sess, hparams,
summary_writer, sample_src_data,
sample_tgt_data)
utils.print_out(
"# Final, step %d lr %g "
"step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
(global_step, loaded_train_model.learning_rate.eval(session=train_sess),
avg_step_time, speed, train_ppl, result_summary, time.ctime()),
log_f)
utils.print_time("# Done training!", start_train_time)
utils.print_out("# Start evaluating saved best models.")
for metric in hparams.metrics:
best_model_dir = getattr(hparams, "best_" + metric + "_dir")
result_summary, best_global_step, _, _, _, _ = run_full_eval(
best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
summary_writer, sample_src_data, sample_tgt_data)
utils.print_out("# Best %s, step %d "
"step-time %.2f wps %.2fK, %s, %s" %
(metric, best_global_step, avg_step_time, speed,
result_summary, time.ctime()), log_f)
summary_writer.close()
return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def _format_results(name, ppl, scores, metrics):
"""Format results."""
result_str = "%s ppl %.2f" % (name, ppl)
if scores:
for metric in metrics:
result_str += ", %s %s %.1f" % (name, metric, scores[metric])
return result_str
def _get_best_results(hparams):
"""Summary of the current best results."""
tokens = []
for metric in hparams.metrics:
tokens.append("%s %.2f" % (metric, getattr(hparams, "best_" + metric)))
return ", ".join(tokens)
def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict,
summary_writer, label):
"""Computing perplexity."""
sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
ppl = model_helper.compute_perplexity(model, sess, label)
utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
return ppl
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
tgt_data, iterator_src_placeholder,
iterator_batch_size_placeholder, summary_writer):
"""Pick a sentence and decode."""
decode_id = random.randint(0, len(src_data) - 1)
utils.print_out(" # %d" % decode_id)
iterator_feed_dict = {
iterator_src_placeholder: [src_data[decode_id]],
iterator_batch_size_placeholder: 1,
}
sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
nmt_outputs, attention_summary = model.decode(sess)
if hparams.beam_width > 0:
    # Keep only the translation from the top beam.
nmt_outputs = nmt_outputs[0]
translation = nmt_utils.get_translation(
nmt_outputs,
sent_id=0,
tgt_eos=hparams.eos,
bpe_delimiter=hparams.bpe_delimiter)
utils.print_out(" src: %s" % src_data[decode_id])
utils.print_out(" ref: %s" % tgt_data[decode_id])
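  # `translation` is a bytes object, hence the bytes format string below.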
utils.print_out(b" nmt: %s" % translation)
# Summary
if attention_summary is not None:
summary_writer.add_summary(attention_summary, global_step)
def _external_eval(model, global_step, sess, hparams, iterator,
iterator_feed_dict, tgt_file, label, summary_writer,
save_on_best):
"""External evaluation such as BLEU and ROUGE scores."""
out_dir = hparams.out_dir
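  # Only run the (expensive) decode when the model has trained for at least
  # one step; a freshly initialized model would produce garbage translations.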
decode = global_step > 0
if decode:
utils.print_out("# External evaluation, global step %d" % global_step)
sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
output = os.path.join(out_dir, "output_%s" % label)
scores = nmt_utils.decode_and_evaluate(
label,
model,
sess,
output,
ref_file=tgt_file,
metrics=hparams.metrics,
bpe_delimiter=hparams.bpe_delimiter,
beam_width=hparams.beam_width,
tgt_eos=hparams.eos,
decode=decode)
# Save on best metrics
if decode:
for metric in hparams.metrics:
utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
scores[metric])
# metric: larger is better
if save_on_best and scores[metric] > getattr(hparams, "best_" + metric):
setattr(hparams, "best_" + metric, scores[metric])
model.saver.save(
sess,
os.path.join(
getattr(hparams, "best_" + metric + "_dir"), "translate.ckpt"),
global_step=model.global_step)
utils.save_hparams(out_dir, hparams)
return scores