# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modality base class - defines the bottom and top of the model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

# Dependency imports

from tensor2tensor.layers import common_layers

import tensorflow as tf


class Modality(object):
  """Abstract Modality class for data transformations.

  An abstract class representing modalities for transforming data to a space
  interpretable by T2T models. It has 4 functions:
  * bottom: called on inputs entering the model.
  * targets_bottom: called on targets entering the model (e.g., the decoder).
  * top: called on model outputs to generate predictions (e.g., logits).
  * loss: called on predictions (outputs of top) and targets.

  For example, think about a modality for images:
  * `bottom` represents the part of the model applied to an incoming image,
    e.g., an entry flow of a convolutional network.
  * `top` represents the top part of a model that is generating images, e.g.,
    a PixelCNN network.
  * `targets_bottom` represents the auto-regressive part of the network. It is
    applied to the already-generated part of an image, which is given to the
    decoder to generate the next part. In some cases, e.g., for text, it is
    the same as the `bottom` function, and that is the default we use. But,
    e.g., for images, a different function might be needed to regress
    properly.
  * `loss` would compare the generated image to the target image and score it.

  All the functions have simple and sharded versions. A sub-class only needs
  to implement the simple version; the default sharding will then be used.
  """

  def __init__(self, model_hparams, vocab_size=None):
    self._model_hparams = model_hparams
    self._vocab_size = vocab_size

  @property
  def name(self):
    camelcase_name = type(self).__name__  # DeCamelCase for TF readability.
    return re.sub("([A-Z]+)", r"_\1", camelcase_name).lower()[1:]

  @property
  def top_dimensionality(self):
    """Integer, the last dimension of the predictions (vocab size)."""
    raise NotImplementedError("Abstract Method")

  @property
  def _body_input_depth(self):
    return self._model_hparams.hidden_size

  def bottom(self, x):
    """Transform one shard of input.

    Args:
      x: An int32 Tensor with shape [batch, p0, p1, input_channels]
    Returns:
      A float32 Tensor with shape [batch, p0, p1, body_input_depth]
    """
    raise NotImplementedError("Abstract Method")

  def bottom_sharded(self, xs, data_parallelism):
    """Transform the inputs.

    Args:
      xs: A list of num_datashards Tensors (one per shard)
        each with shape [batch, p0, p1, depth]
      data_parallelism: an expert_utils.Parallelism object
    Returns:
      sharded_body_input: A list of num_datashards Tensors, each with shape
        [batch, p0, p1, body_input_depth].
    """
    return data_parallelism(self.bottom, xs)
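
  # Each sharded variant is a thin dispatch wrapper: given xs =
  # [x_0, ..., x_{n-1}], one Tensor per datashard,
  # data_parallelism(fn, xs) roughly computes [fn(x_0), ..., fn(x_{n-1})],
  # running each call on that shard's device. Subclasses therefore normally
  # override only the single-shard methods.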

  def targets_bottom(self, x):
    """Transform one shard of targets.

    Args:
      x: An int32 Tensor with shape [batch, p0, p1, target_channels]
    Returns:
      A float32 Tensor with shape [batch, p0, p1, body_input_depth]
    """
    with tf.variable_scope("targets_bottom"):
      return self.bottom(x)

  def targets_bottom_sharded(self, xs, data_parallelism):
    """Transform the targets.

    Args:
      xs: A list of num_datashards Tensors (one per shard)
        each with shape [batch, p0, p1, target_channels]
      data_parallelism: an expert_utils.Parallelism object
    Returns:
      sharded_body_input: A list of num_datashards Tensors, each with shape
        [batch, p0, p1, body_input_depth].
    """
    return data_parallelism(self.targets_bottom, xs)

  def top(self, body_output, targets):
    """Generate predictions/logits for one shard of output.

    Most classes will override this function.

    Args:
      body_output: A Tensor with shape [batch, p0, p1, body_output_depth]
      targets: A Tensor with shape
        [batch, p0, p1, targets_channels, top_dimensionality]
    Returns:
      A Tensor of class logits.
    """
    raise NotImplementedError("Abstract Method")

  def top_sharded(self, sharded_body_output, sharded_targets,
                  data_parallelism):
    """Generate predictions/logits for all shards.

    Classes with cross-shard interaction will override this function.

    Args:
      sharded_body_output: A list of Tensors.
      sharded_targets: A list of Tensors.
      data_parallelism: an expert_utils.Parallelism object.
    Returns:
      sharded_logits: A list of Tensors.
    """
    return data_parallelism(self.top, sharded_body_output, sharded_targets)

  def loss(self, top_out, targets, weights_fn=common_layers.weights_nonzero):
    """Compute loss numerator and denominator for one shard of output."""
    logits = top_out
    return common_layers.padded_cross_entropy(
        logits,
        targets,
        self._model_hparams.label_smoothing,
        weights_fn=weights_fn)

  def loss_sharded(self, sharded_top_out, sharded_targets, data_parallelism):
    """Compute loss for all shards."""
    sharded_loss_num, sharded_loss_den = data_parallelism(
        self.loss, sharded_top_out, sharded_targets)
    loss = tf.add_n(sharded_loss_num) / tf.maximum(
        1.0, tf.add_n(sharded_loss_den))
    return loss
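

# What follows is an illustrative sketch, not part of the library: a minimal
# concrete subclass for symbol (token) data, assuming a plain embedding bottom
# and a dense projection top. Real implementations (e.g. SymbolModality in
# tensor2tensor/layers/modalities.py) additionally handle weight sharing,
# padding, and factored embeddings, which this toy version omits.
class _ExampleSymbolModality(Modality):
  """Toy modality: embeds int token ids, projects body output to logits."""

  @property
  def top_dimensionality(self):
    return self._vocab_size

  def bottom(self, x):
    # x: int32 ids with shape [batch, p0, p1, 1]; returns float32
    # [batch, p0, p1, body_input_depth].
    with tf.variable_scope(self.name):
      embedding = tf.get_variable(
          "embedding", [self._vocab_size, self._body_input_depth])
      return tf.gather(embedding, tf.squeeze(x, axis=3))

  def top(self, body_output, _):
    # body_output: [batch, p0, p1, body_output_depth]; returns logits with
    # shape [batch, p0, p1, 1, top_dimensionality], the layout loss() expects.
    with tf.variable_scope(self.name):
      logits = tf.layers.dense(body_output, self._vocab_size, name="logits")
      return tf.expand_dims(logits, axis=3)

# Usage sketch, assuming an hparams object with hidden_size and
# label_smoothing set:
#   modality = _ExampleSymbolModality(hparams, vocab_size=32000)
#   embedded = modality.bottom(token_ids)          # feed into the model body
#   logits = modality.top(body_output, targets)
#   loss_num, loss_den = modality.loss(logits, targets)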