Source code for machine.dataset.fields

import logging

import torchtext


[docs]class SourceField(torchtext.data.Field): """Wrapper class of torchtext.data.Field that forces batch_first and include_lengths to be True. Attributes: eos_id: index of the end of sentence symbol. """ def __init__(self, **kwargs): """Initialize the datafield, but force batch_first and include_lengths to be True, which is required for correct functionality of machine. Also allow to include SOS and EOS symbols for the source sequence. Args: **kwargs: Description """ logger = logging.getLogger(__name__) if kwargs.get('batch_first') is False: logger.warning( "Option batch_first has to be set to use machine. Changed to True.") kwargs['batch_first'] = True if kwargs.get('include_lengths') is False: logger.warning( "Option include_lengths has to be set to use machine. Changed to True.") kwargs['include_lengths'] = True super(SourceField, self).__init__(**kwargs)
[docs] def build_vocab(self, *args, **kwargs): super(SourceField, self).build_vocab(*args, **kwargs)
[docs]class TargetField(torchtext.data.Field): """ Wrapper class of torchtext.data.Field that forces batch_first to be True and prepend <sos> and append <eos> to sequences in preprocessing step. Attributes: sos_id: index of the start of sentence symbol eos_id: index of the end of sentence symbol """ SYM_SOS = '<sos>' SYM_EOS = '<eos>' include_eos = True def __init__(self, include_eos=True, **kwargs): logger = logging.getLogger(__name__) self.include_eos = include_eos if not kwargs.get('batch_first'): logger.warning( "Option batch_first has to be set to use machine. Changed to True.") kwargs['batch_first'] = True if kwargs.get('preprocessing') is None: def func(seq): return seq else: func = kwargs['preprocessing'] if self.include_eos: app_eos = [self.SYM_EOS] else: app_eos = [] kwargs['preprocessing'] = lambda seq: [ self.SYM_SOS] + func(seq) + app_eos self.sos_id = None self.eos_id = None super(TargetField, self).__init__(**kwargs)
[docs] def build_vocab(self, *args, **kwargs): super(TargetField, self).build_vocab(*args, **kwargs) self.sos_id = self.vocab.stoi[self.SYM_SOS] self.eos_id = self.vocab.stoi[self.SYM_EOS]