# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data generator for Wikipedia title to article dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

import bz2file
import numpy as np
import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry

import tensorflow as tf


# End-of-sentence marker.
EOS = text_encoder.EOS_ID


def _maybe_download_corpus(tmp_dir):
  """Download corpus if necessary.

  Args:
    tmp_dir: directory containing dataset.

  Returns:
    filepath of the downloaded corpus file.
  """
  corpus_url = ("https://dumps.wikimedia.org/enwiki/20170620/"
                "enwiki-20170620-pages-articles-multistream.xml.bz2")
  corpus_filename = os.path.basename(corpus_url)
  corpus_filepath = os.path.join(tmp_dir, corpus_filename)
  if not tf.gfile.Exists(corpus_filepath):
    generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url)
  return corpus_filepath


def page_generator(tmp_dir, max_docs=None):
  """Yield raw <page>...</page> blocks from the Wikipedia XML dump."""
  doc = u""
  count = 0
  corpus_filepath = _maybe_download_corpus(tmp_dir)
  for line in bz2file.BZ2File(corpus_filepath, "r", buffering=1000000):
    line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
    if not doc and line != u"  <page>\n":
      continue
    doc += line
    if line == u"  </page>\n":
      yield doc
      doc = u""
      count += 1
      if max_docs and count >= max_docs:
        break


def _page_title(page):
  """Extract the text between the <title> and </title> tags of a page."""
  start_pos = page.find(u"<title>")
  end_pos = page.find(u"</title>")
  assert start_pos != -1
  assert end_pos != -1
  start_pos += len(u"<title>")
  return page[start_pos:end_pos]


@registry.register_problem
class LanguagemodelWikiFull32k(problem.Text2TextProblem):
  """A language model on full English Wikipedia."""

  @property
  def is_character_level(self):
    return False

  @property
  def has_inputs(self):
    return True

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_TOK

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_TOK

  @property
  def num_shards(self):
    return 1000

  @property
  def vocab_name(self):
    return "vocab.wiki"

  @property
  def use_subword_tokenizer(self):
    return True

  @property
  def targeted_vocab_size(self):
    return 2**15  # 32768

  @property
  def use_train_shards_for_dev(self):
    return True

  def generator(self, data_dir, tmp_dir, _):
    encoder = generator_utils.get_or_generate_vocab_inner(
        data_dir, self.vocab_file, self.targeted_vocab_size,
        page_generator(tmp_dir, max_docs=10000))
    for page in page_generator(tmp_dir):
      title = _page_title(page)
      encoded = encoder.encode(page) + [EOS]
      encoded_title = encoder.encode(title) + [EOS]
      yield {"inputs": encoded_title, "targets": encoded}


class LanguagemodelWikiScramble(problem.Text2TextProblem):
  """Language modeling on English Wikipedia.

  "targets" is a sequence of sequence_length tokens - a fragment of an article.
"inputs" is a copy of "targets", but with a random scramble_fraction of the tokens randomly permuted. This dataset is intended to test parallel (non-autoregressive) prediction of the target sequence given the input sequence. """ @property def sequence_length(self): raise NotImplementedError() @property def scramble_fraction(self): raise NotImplementedError() @property def is_character_level(self): return False @property def has_inputs(self): return True @property def input_space_id(self): return problem.SpaceID.EN_TOK @property def target_space_id(self): return problem.SpaceID.EN_TOK @property def num_shards(self): return 1000 @property def vocab_name(self): return "vocab.wiki" @property def use_subword_tokenizer(self): return True @property def targeted_vocab_size(self): return 2**13 # 8192 @property def use_train_shards_for_dev(self): return True @property def max_cases(self): return (2 ** 30) / self.sequence_length def scramble(self, seq): seq = np.array(seq) num_permute = int(self.sequence_length * self.scramble_fraction) full_permutation = np.random.permutation(self.sequence_length) inverse_full_permutation = np.argsort(full_permutation) partial_permutation = np.random.permutation(num_permute) seq = seq[full_permutation] seq = np.concatenate( (seq[:num_permute][partial_permutation], seq[num_permute:])) seq = seq[inverse_full_permutation] seq = list(seq) return seq def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, page_generator(tmp_dir, max_docs=1000)) case_num = 0 for page in page_generator(tmp_dir): encoded = encoder.encode(page) for i in xrange(len(encoded) // self.sequence_length): case_num += 1 if self.max_cases and case_num > self.max_cases: return targets = encoded[ i * self.sequence_length:(i + 1) * self.sequence_length] inputs = self.scramble(targets) yield {"inputs": inputs, "targets": targets} def eval_metrics(self): return [ metrics.Metrics.ACC, metrics.Metrics.NEG_LOG_PERPLEXITY ] @registry.register_problem class LanguagemodelWikiScramble128(LanguagemodelWikiScramble): """Sequence length 128, 50% scrambed.""" @property def sequence_length(self): return 128 @property def scramble_fraction(self): return 0.5 @registry.register_problem class LanguagemodelWikiScramble1k50(LanguagemodelWikiScramble): """Sequence length 1024, 50% scrambed.""" @property def sequence_length(self): return 1024 @property def scramble_fraction(self): return 0.5 @registry.register_problem class LanguagemodelWikiScramble8k50(LanguagemodelWikiScramble): """Sequence length 8192, 50% scrambed.""" @property def sequence_length(self): return 8192 @property def scramble_fraction(self): return 0.5