Text Generation consistently results in blank characters


The following code shows a perfectly normal loss curve after 20 epochs, but when I test the model with a seed text it consistently outputs blank characters (' '). Either I simply do not understand how the seed text should be prepared and the predictions processed, or something subtle is wrong that I have not been able to notice, hence my appeal to this community.

Loss chart after training: [1]

The data used here is a small portion of nietzsche.txt, available here:

https://s3.amazonaws.com/text-datasets/nietzsche.txt
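
For anyone trying to reproduce this, one possible way to fetch the corpus and keep only a short prefix is sketched below (the 200,000-character cutoff is just an illustrative choice, not the exact portion I used):

import tensorflow as tf

# Download and cache the full corpus, then keep a short prefix for quick experiments.
path = tf.keras.utils.get_file(
    "nietzsche.txt",
    origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
raw_text = open(path, "r", encoding="utf-8").read().lower()[:200000]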

The main issue is that the predictions always have argmax() = 1. The character dictionary built from the text has ' ' at that index, so the output is always a series of blank characters. The data being fed to the model does seem reasonable, as shown by the x sequences of integers in the [output] section below: after every iteration a new (and different) character is appended to the pattern list, and the list is sliced back to a constant length by dropping its first element.
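
As a quick, self-contained check of that claim (assuming a lowercased corpus that contains newlines, so that sorted(list(set(raw_text))) places '\n' at index 0 and ' ' at index 1; the exact indices of course depend on the characters present in the file):

sample_text = "the spirit of gravity.\nthus spoke zarathustra."
chars = sorted(list(set(sample_text)))
int_to_char = dict((i, c) for i, c in enumerate(chars))
print(chars[:3])             # ['\n', ' ', '.']
print(repr(int_to_char[1]))  # ' ' -- so argmax() == 1 always decodes to a blank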

I have been trying to figure out where I am going wrong in generating non-blank texts and have not been able to identify the issue. I would be grateful for any help.

import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LSTM
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import to_categorical

# load ascii text and convert to lowercase
filename = "/content/drive/MyDrive/Colab Notebooks/TexGen/nietzsche-short.txt"
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = raw_text.lower()

# create mapping of unique chars to integers, and a reverse mapping
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
int_to_char = dict((i, c) for i, c in enumerate(chars))

# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)

# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []

for i in range(0, n_chars - seq_length, 1):
 seq_in = raw_text[i:i + seq_length]
 seq_out = raw_text[i + seq_length]
 dataX.append([char_to_int[char] for char in seq_in])
 dataY.append(char_to_int[seq_out])

n_patterns = len(dataX)

# reshape X to be [samples, time steps, features]
X = np.reshape(dataX, (n_patterns, seq_length, 1))

# normalize
X = X / float(n_vocab)

# one hot encode the output variable
y = to_categorical(dataY)

# define the LSTM model
model = tf.keras.models.Sequential([
    tf.keras.layers.Embedding(n_vocab, 50, input_length=seq_length),
    tf.keras.layers.Conv1D(128, 5, activation='relu'),  # CNN layer
    tf.keras.layers.MaxPooling1D(pool_size=4),
    tf.keras.layers.LSTM(256, return_sequences=True),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.LSTM(256),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(y.shape[1], activation='softmax')  # one output unit per character class
], name="LSTM_Model")

model.compile(loss='categorical_crossentropy', optimizer='adam')

history = model.fit(X, y,
          epochs=25,
          batch_size=128
)

# Test the model with a seed
start = np.random.randint(0, len(dataX)-1)
pattern = dataX[start] # dataX is a list of lists, 100 character indices each
print("Seed:")
print("\"", ''.join([int_to_char[value] for value in pattern]), "\"")

# seed:
# " ay upon words, a deception on the part of grammar, or an
# audacious generalization of very restricted "

# pattern[:10]
# [13, 37, 1, 33, 28, 27, 26, 1, 35, 27]

# generate characters
for i in range(5):
  x = np.reshape(pattern, (1, len(pattern), 1))
  x = x / float(n_vocab)
  print("\n==================")
  print("x[:10] : ", x[:, :10, :])
  prediction = model.predict(x, verbose=0)
  index = np.argmax(prediction)
  result = int_to_char[index]
  print("index = ", index)
  print("result = ", result)
  # seq_in = [int_to_char[value] for value in pattern]
  #sys.stdout.write(result)
  pattern.append(index)  # append the newly predicted character index
  pattern = pattern[1:]  # drop the oldest index to keep the window at seq_length
print("\nDone.")


[output]

# predictions:

array([[0.01481258, 0.1349412 , 0.00109681, 0.00254168, 0.00037268,
        0.00085235, 0.00087828, 0.01556777, 0.00960505, 0.00228236,
        0.00174035, 0.00123438, 0.00138978, 0.07334321, 0.01076234,
        0.01881236, 0.03297085, 0.0944486 , 0.0203624 , 0.01628518,
        0.04297792, 0.05854145, 0.00125222, 0.00374453, 0.02957868,
        0.0199816 , 0.05518206, 0.05479056, 0.02143795, 0.000657  ,
        0.04608261, 0.06542768, 0.08481915, 0.02293939, 0.00776505,
        0.01505006, 0.00033998, 0.01451185, 0.00062015]], dtype=float32)

==================
x[:10] :  [[[0.33333333]
  [0.94871795]
  [0.02564103]
  [0.84615385]
  [0.71794872]
  [0.69230769]
  [0.66666667]
  [0.02564103]
  [0.8974359 ]
  [0.69230769]]]
index =  1
result =   

==================
x[:10] :  [[[0.94871795]
  [0.02564103]
  [0.84615385]
  [0.71794872]
  [0.69230769]
  [0.66666667]
  [0.02564103]
  [0.8974359 ]
  [0.69230769]
  [0.76923077]]]
index =  1
result =   

==================
x[:10] :  [[[0.02564103]
  [0.84615385]
  [0.71794872]
  [0.69230769]
  [0.66666667]
  [0.02564103]
  [0.8974359 ]
  [0.69230769]
  [0.76923077]
  [0.41025641]]]
index =  1
result =   

==================
x[:10] :  [[[0.84615385]
  [0.71794872]
  [0.69230769]
  [0.66666667]
  [0.02564103]
  [0.8974359 ]
  [0.69230769]
  [0.76923077]
  [0.41025641]
  [0.79487179]]]
index =  1
result =   

==================
x[:10] :  [[[0.71794872]
  [0.69230769]
  [0.66666667]
  [0.02564103]
  [0.8974359 ]
  [0.69230769]
  [0.76923077]
  [0.41025641]
  [0.79487179]
  [0.17948718]]]
index =  1
result =   

Done.
[/output]
model.summary()

Model: "LSTM_Model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       (None, 100, 50)           1950      
                                                                 
 conv1d (Conv1D)             (None, 96, 128)           32128     
                                                                 
 max_pooling1d (MaxPooling1  (None, 24, 128)           0         
 D)                                                              
                                                                 
 lstm (LSTM)                 (None, 24, 256)           394240    
                                                                 
 dropout (Dropout)           (None, 24, 256)           0         
                                                                 
 lstm_1 (LSTM)               (None, 256)               525312    
                                                                 
 dropout_1 (Dropout)         (None, 256)               0         
                                                                 
 dense (Dense)               (None, 39)                10023     
                                                                 
=================================================================
Total params: 963653 (3.68 MB)
Trainable params: 963653 (3.68 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


  [1]: https://i.stack.imgur.com/QSWz3.jpg
  [2]: https://s3.amazonaws.com/text-datasets/nietzsche.txt
Answer by Nader Afshar:

I found the following on tensorflow.org, in the "Text generation with an RNN" tutorial:

To get actual predictions from the model you need to sample from the output distribution, to get actual character indices. This distribution is defined by the logits over the character vocabulary.

Note: It is important to sample from this distribution as taking the argmax of the distribution can easily get the model stuck in a loop.

sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

I still cannot quite understand what this is saying, but the code example provided did substantially improve the results.

If anyone can clearly explain how this is working, and exactly why argmax can easily get the model stuck in a loop, I would be grateful.
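
For what it is worth, here is a rough sketch (not the tutorial's exact code) of how that sampling step could be slotted into the generation loop from the question; the helper name sample_next_index and the temperature value are just illustrative choices. tf.random.categorical expects logits rather than probabilities, so the softmax output of model.predict() is converted to log-probabilities first. As for the loop issue, one plausible explanation is that argmax is deterministic: once the 100-character window fills up with the same repeated character, the identical input keeps producing the identical prediction forever, whereas sampling occasionally picks a lower-probability character and breaks the cycle.

import numpy as np
import tensorflow as tf

def sample_next_index(probs, temperature=1.0):
    # probs is the (1, n_vocab) softmax output of model.predict().
    # tf.random.categorical expects logits, so convert the probabilities to
    # log-probabilities; dividing by a temperature < 1.0 sharpens the
    # distribution, > 1.0 flattens it.
    logits = np.log(probs + 1e-9) / temperature
    return int(tf.random.categorical(logits, num_samples=1)[0, 0].numpy())

# In the generation loop, replace
#     index = np.argmax(prediction)
# with
#     index = sample_next_index(prediction, temperature=0.8)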