Save, load, and retrain a TensorFlow model for machine translation


I've been trying to train a model for machine translation. It worked fine when I trained it for 10 epochs in a single run and then tested it. But when I try to train it for 1 epoch at a time, saving the model and loading it again to continue from where I left off, I get errors.

import tensorflow as tf
import einops

import numpy as np
import os
import tensorflow_text as tf_text
import pathlib

from keras.layers import TextVectorization


class ShapeChecker():
    def __init__(self):
        self.shapes = {}

    def __call__(self, tensor, names, broadcast=False):
        if not tf.executing_eagerly():
            return

        parsed = einops.parse_shape(tensor, names)

        for name, new_dim in parsed.items():
            old_dim = self.shapes.get(name, None)

            if broadcast and new_dim == 1:
                continue

            if old_dim is None:
                self.shapes[name] = new_dim
                continue

            if new_dim != old_dim:
                raise ValueError(f"Shape mismatch for dimension: '{name}'\n"
                                 f"    found: {new_dim}\n"
                                 f"    expected: {old_dim}\n")


class Encoder(tf.keras.layers.Layer):
    def __init__(self, text_processor, units):
        super(Encoder, self).__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(self.vocab_size, units,
                                                   mask_zero=True)

        self.rnn = tf.keras.layers.Bidirectional(
            merge_mode='sum',
            layer=tf.keras.layers.GRU(units,
                                      return_sequences=True,
                                      recurrent_initializer='glorot_uniform'))

    def call(self, x):
        shape_checker = ShapeChecker()
        shape_checker(x, 'batch s')

        x = self.embedding(x)
        shape_checker(x, 'batch s units')

        x = self.rnn(x)
        shape_checker(x, 'batch s units')

        return x

    def convert_input(self, texts):
        texts = tf.convert_to_tensor(texts)
        if len(texts.shape) == 0:
            texts = tf.convert_to_tensor(texts)[tf.newaxis]
        context = self.text_processor(texts).to_tensor()
        context = self(context)
        return context


class CrossAttention(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(key_dim=units, num_heads=1, **kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

    def call(self, x, context):
        shape_checker = ShapeChecker()

        shape_checker(x, 'batch t units')
        shape_checker(context, 'batch s units')

        attn_output, attn_scores = self.mha(
            query=x,
            value=context,
            return_attention_scores=True)

        shape_checker(x, 'batch t units')
        shape_checker(attn_scores, 'batch heads t s')

        attn_scores = tf.reduce_mean(attn_scores, axis=1)
        shape_checker(attn_scores, 'batch t s')
        self.last_attention_weights = attn_scores

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x


class Decoder(tf.keras.layers.Layer):
    @classmethod
    def add_method(cls, fun):
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, text_processor, units):
        super(Decoder, self).__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.word_to_id = tf.keras.layers.StringLookup(
            vocabulary=text_processor.get_vocabulary(),
            mask_token='', oov_token='[UNK]')
        self.id_to_word = tf.keras.layers.StringLookup(
            vocabulary=text_processor.get_vocabulary(),
            mask_token='', oov_token='[UNK]',
            invert=True)
        self.start_token = self.word_to_id('[START]')
        self.end_token = self.word_to_id('[END]')

        self.units = units

        self.embedding = tf.keras.layers.Embedding(self.vocab_size, units, mask_zero=True)

        self.rnn = tf.keras.layers.GRU(units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

        self.attention = CrossAttention(units)

        self.output_layer = tf.keras.layers.Dense(self.vocab_size)


@Decoder.add_method
def call(self,
         context, x,
         state=None,
         return_state=False):
    shape_checker = ShapeChecker()
    shape_checker(x, 'batch t')
    shape_checker(context, 'batch s units')

    x = self.embedding(x)
    shape_checker(x, 'batch t units')

    x, state = self.rnn(x, initial_state=state)
    shape_checker(x, 'batch t units')

    x = self.attention(x, context)
    self.last_attention_weights = self.attention.last_attention_weights
    shape_checker(x, 'batch t units')
    shape_checker(self.last_attention_weights, 'batch t s')

    logits = self.output_layer(x)
    shape_checker(logits, 'batch t target_vocab_size')

    if return_state:
        return logits, state
    else:
        return logits


@Decoder.add_method
def get_initial_state(self, context):
    batch_size = tf.shape(context)[0]
    start_tokens = tf.fill([batch_size, 1], self.start_token)
    done = tf.zeros([batch_size, 1], dtype=tf.bool)
    embedded = self.embedding(start_tokens)
    return start_tokens, done, self.rnn.get_initial_state(embedded)[0]


@Decoder.add_method
def tokens_to_text(self, tokens):
    words = self.id_to_word(tokens)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    result = tf.strings.regex_replace(result, r'^ *\[START\] *', '')
    result = tf.strings.regex_replace(result, r' *\[END\] *$', '')
    return result


@Decoder.add_method
def get_next_token(self, context, next_token, done, state, temperature=0.0):
    logits, state = self(
        context, next_token,
        state=state,
        return_state=True)

    if temperature == 0.0:
        next_token = tf.argmax(logits, axis=-1)
    else:
        logits = logits[:, -1, :] / temperature
        next_token = tf.random.categorical(logits, num_samples=1)

    done = done | (next_token == self.end_token)
    next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)

    return next_token, done, state


class Translator(tf.keras.Model):
    @classmethod
    def add_method(cls, fun):
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, units,
                 context_text_processor,
                 target_text_processor):
        super().__init__()

        self.ctp = context_text_processor
        self.ttp = target_text_processor

        encoder = Encoder(context_text_processor, units)
        decoder = Decoder(target_text_processor, units)

        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        context, x = inputs
        context = self.encoder(context)
        logits = self.decoder(context, x)

        return logits

    def get_config(self):
        return {
            "context": self.ctp,
            "target": self.ttp
        }



path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)

path_to_file = pathlib.Path(path_to_zip).parent / 'spa-eng/spa.txt'


def load_data(path):
    text = path.read_text(encoding='utf-8')
    lines = text.splitlines()

    pairs = [line.split('\t') for line in lines]

    context = np.array([context for target, context in pairs])
    target = np.array([target for target, context in pairs])

    return target, context


target_raw, context_raw = load_data(path_to_file)

buffer_size = len(context_raw)
batch_size = 64

is_train = np.random.uniform(size=(len(target_raw),)) < .8

train_raw = (
    tf.data.Dataset
    .from_tensor_slices((context_raw[is_train], target_raw[is_train]))
    .shuffle(buffer_size)
    .batch(batch_size)
)

val_raw = (
    tf.data.Dataset
    .from_tensor_slices((context_raw[~is_train], target_raw[~is_train]))
    .shuffle(buffer_size)
    .batch(batch_size)
)

# Grab one example batch of raw (context, target) strings.
for example_context_strings, example_target_strings in train_raw.take(1):
    break

example_text = tf.constant('¿Todavía está en casa?')


def tf_lower_and_split_punctuation(text):
    # Split accented characters.
    text = tf_text.normalize_utf8(text, 'NFKD')
    text = tf.strings.lower(text)
    # Keep space, a to z, and select punctuation.
    text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')
    # Add spaces around punctuation.
    text = tf.strings.regex_replace(text, '[.?!,¿]', r' \0 ')
    # Strip whitespace.
    text = tf.strings.strip(text)

    text = tf.strings.join(['[START]', text, '[END]'], separator=' ')
    return text


max_vocab_size = 5000
context_text_processor = TextVectorization(
    standardize=tf_lower_and_split_punctuation,
    max_tokens=max_vocab_size,
    ragged=True
)

context_text_processor.adapt(train_raw.map(lambda context, target: context))

target_text_processor = TextVectorization(
    standardize=tf_lower_and_split_punctuation,
    max_tokens=max_vocab_size,
    ragged=True
)

target_text_processor.adapt(train_raw.map(lambda context, target: target))

example_tokens = context_text_processor(example_context_strings)

context_vocab = np.array(context_text_processor.get_vocabulary())
tokens = context_vocab[example_tokens[0].numpy()]
' '.join(tokens)


def process_text(context, target):
    context = context_text_processor(context).to_tensor()
    target = target_text_processor(target)
    targ_in = target[:, :-1].to_tensor()
    targ_out = target[:, 1:].to_tensor()
    return (context, targ_in), targ_out


train_ds = train_raw.map(process_text, tf.data.AUTOTUNE)
val_ds = val_raw.map(process_text, tf.data.AUTOTUNE)

# Grab one processed example batch for inspection.
for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
    break

def masked_loss(y_true, y_pred):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    loss = loss_fn(y_true, y_pred)

    mask = tf.cast(y_true != 0, loss.dtype)
    loss *= mask

    return tf.reduce_sum(loss) / tf.reduce_sum(mask)


def masked_acc(y_true, y_pred):
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)

    match = tf.cast(y_true == y_pred, tf.float32)
    mask = tf.cast(y_true != 0, tf.float32)

    return tf.reduce_sum(match) / tf.reduce_sum(mask)


UNITS = 256

model = Translator(UNITS, context_text_processor, target_text_processor)
model.compile(
    optimizer='adam',
    loss=masked_loss,
    metrics=[masked_acc, masked_loss]
)
vocab_size = 1.0 * target_text_processor.vocabulary_size()

model_path = "./Saved Model"

initial_epoch = 0

os.makedirs(model_path, exist_ok=True)

# Resume from the most recent saved epoch directory, if one exists;
# otherwise compile the freshly built model.
for (dir_path, dir_names, filenames) in os.walk(model_path):

    if len(dir_names) != 0:
        dir_names.sort()

        initial_epoch = int(dir_names[-1])

        model = tf.keras.models.load_model(os.path.join(model_path, dir_names[-1]))
    else:
        model.compile(
            optimizer='adam',
            loss=masked_loss,
            metrics=[masked_acc, masked_loss]
        )
    break

history = model.fit(
    train_ds.repeat(),
    initial_epoch=initial_epoch,
    epochs=initial_epoch + 1,
    steps_per_epoch=100,
    validation_data=val_ds,
    validation_steps=20,
)

# Save under the epoch number so the next run can pick up from it.
model.save(os.path.join(model_path, f'{initial_epoch + 1:02d}'))

This is not even my code; it's from the TensorFlow documentation. I only modified it so I could train for multiple epochs in separate runs.

When I train, the first epoch runs smoothly. But when I load the saved model to train the 2nd epoch, the following error is shown:

RuntimeError: Unable to restore object of class 'TextVectorization'. One of several possible causes could be a missing custom object. Decorate your custom object with @keras.utils.register_keras_serializable and include that file in your program, or pass your class in a keras.utils.CustomObjectScope that wraps this load call.

Exception: Error when deserializing class 'TextVectorization' using config={'name': 'text_vectorization', 'trainable': True, 'dtype': 'string', 'batch_input_shape': (None,), 'max_tokens': 5000, 'standardize': 'tf_lower_and_split_punctuation', 'split': 'whitespace', 'ngrams': None, 'output_mode': 'int', 'output_sequence_length': None, 'pad_to_max_tokens': False, 'sparse': False, 'ragged': True, 'vocabulary': None, 'idf_weights': None, 'encoding': 'utf-8', 'vocabulary_size': 5000, 'has_input_vocabulary': False}.

Exception encountered: Unkown value for standardize argument of layer TextVectorization. If restoring a model and standardize is a custom callable, please ensure the callable is registered as a custom object. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details. Allowed values are: None, a Callable, or one of the following values: ('lower_and_strip_punctuation', 'lower', 'strip_punctuation'). Received: tf_lower_and_split_punctuation
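
Based on that message, I think the custom standardize callable (and probably the custom loss/metric too) has to be registered or passed in explicitly when loading. Below is a minimal sketch of how I read the suggestion; I haven't verified that it fixes the restore, and the registration decorator, the name I register it under, and the custom_objects dict are my assumptions:

import tensorflow as tf
import tensorflow_text as tf_text

# Assumption: register the standardize callable under the same name stored
# in the saved config so Keras can resolve it during deserialization.
@tf.keras.utils.register_keras_serializable(
    package='Custom', name='tf_lower_and_split_punctuation')
def tf_lower_and_split_punctuation(text):
    text = tf_text.normalize_utf8(text, 'NFKD')
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')
    text = tf.strings.regex_replace(text, '[.?!,¿]', r' \0 ')
    text = tf.strings.strip(text)
    return tf.strings.join(['[START]', text, '[END]'], separator=' ')

# Assumption: also pass the custom objects explicitly when loading
# (same saved-model directory as in the loop above).
model = tf.keras.models.load_model(
    os.path.join(model_path, dir_names[-1]),
    custom_objects={
        'tf_lower_and_split_punctuation': tf_lower_and_split_punctuation,
        'masked_loss': masked_loss,
        'masked_acc': masked_acc,
    })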

I've also already tried checkpointing for training (code below). It runs, but the accuracy of the model drops to ~45%, whereas training for 10 consecutive epochs (without saving and loading) gave ~70%.

initial_epoch = 0

checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer.get_config())
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10)

# Restore the latest checkpoint if it exists
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Latest checkpoint restored!")

    # Update the current_epoch variable
    initial_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10)

history = model.fit(
    train_ds.repeat(),
    initial_epoch=initial_epoch,
    epochs=initial_epoch + 1,
    steps_per_epoch=100,
    validation_data=val_ds,
    validation_steps=20,
)

ckpt_manager.save()
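
One thing I'm unsure about in the snippet above: I pass model.optimizer.get_config() (a plain dict) to tf.train.Checkpoint, so I suspect the optimizer's slot variables (the Adam moments) are never actually saved or restored, which could explain the accuracy drop. Here is a minimal sketch of tracking the optimizer object itself; this is just my assumption, not something I've verified:

# Assumption: track the optimizer object (not its config dict) so its slot
# variables are written to and restored from the checkpoint as well.
ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10)

if ckpt_manager.latest_checkpoint:
    # expect_partial() only silences warnings about variables that are
    # created lazily; the restore itself is unchanged.
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    initial_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])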

Why is this happening?
