# This function is copied from tensorflow.contrib.slim.learning. We then add
# cuda profiler start and stop calls so that nvprof is feasible.
import os
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import timeline
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging


def train_step(sess, train_op, global_step, train_step_kwargs):
  """Function that takes a gradient step and specifies whether to stop.

  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the
      total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments. Keys read here:
      'should_trace' (with 'logdir' and optional 'summary_writer'),
      'should_log', 'should_stop', and the profiling keys 'nvprof_on',
      'nvprof_start_step' and 'nvprof_stop_step'.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.

  Raises:
    ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
  """
  start_time = time.time()

  trace_run_options = None
  run_metadata = None
  if 'should_trace' in train_step_kwargs:
    if 'logdir' not in train_step_kwargs:
      raise ValueError('logdir must be present in train_step_kwargs when '
                       'should_trace is present')
    if sess.run(train_step_kwargs['should_trace']):
      trace_run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()

  total_loss, np_global_step = sess.run([train_op, global_step],
                                        options=trace_run_options,
                                        run_metadata=run_metadata)
  time_elapsed = time.time() - start_time

  # Honor the *value* of 'nvprof_on', not just its presence, so a caller that
  # always passes the flag (e.g. nvprof_on=False) does not start profiling.
  if train_step_kwargs.get('nvprof_on'):
    # Imported lazily so numba is only required when profiling is requested.
    import numba.cuda as cuda
    # np_global_step is fetched in the same run as train_op, so it may be the
    # pre- or post-increment value; the start/stop steps are compared against
    # whatever value was fetched. Presumably the process runs under nvprof
    # with profiling initially disabled -- confirm with the launcher.
    if np_global_step == train_step_kwargs['nvprof_start_step']:
      cuda.profile_start()
    if np_global_step == train_step_kwargs['nvprof_stop_step']:
      cuda.profile_stop()

  if run_metadata is not None:
    # Dump a Chrome-trace JSON of this step into logdir.
    tl = timeline.Timeline(run_metadata.step_stats)
    trace = tl.generate_chrome_trace_format()
    trace_filename = os.path.join(train_step_kwargs['logdir'],
                                  'tf_trace-%d.json' % np_global_step)
    logging.info('Writing trace to %s', trace_filename)
    file_io.write_string_to_file(trace_filename, trace)
    if 'summary_writer' in train_step_kwargs:
      train_step_kwargs['summary_writer'].add_run_metadata(
          run_metadata, 'run_metadata-%d' % np_global_step)

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                   np_global_step, total_loss, time_elapsed)

  # TODO(nsilberman): figure out why we can't put this into sess.run. The
  # issue right now is that the stop check depends on the global step. The
  # increment of global step often happens via the train op, which is
  # created using optimizer.apply_gradients.
  #
  # Since running `train_op` causes the global step to be incremented, one
  # would expect that using a control dependency would allow the
  # should_stop check to be run in the same session.run call:
  #
  #   with ops.control_dependencies([train_op]):
  #     should_stop_op = ...
  #
  # However, this actually seems not to work on certain platforms.
  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False

  return total_loss, should_stop