I'm currently trying to fine-tune a Whisper model by following the Fine Tune Whisper Model tutorial. However, when I call `trainer.train()`, the progress bar advances through training normally, but as soon as it reaches the evaluation step defined in the training arguments, it freezes and the progress bar stalls. There is no error output at all. It looks like this:

[screenshot of the stalled progress bar]
I'm writing the code in a Kaggle notebook with the P100 GPU enabled. Here is my model setup, my training arguments and the trainer, leading up to the training call:
```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.language = "en"

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-eng-gen",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
    ignore_data_skip=True,
)

from transformers import Seq2SeqTrainer

# common_voice_train / common_voice_test (both streamed), data_collator,
# compute_metrics and processor are defined in earlier steps of the tutorial.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
```
Initially, `max_steps` was 4000, and training always stalled at step 1001.
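To line the stall up with the schedule, here is a small plain-Python sketch of which global steps trigger an evaluation. It assumes the documented behaviour of the `"steps"` strategy (evaluate every `eval_steps` optimizer steps); the helper name `eval_steps_schedule` is hypothetical, not part of transformers.

```python
def eval_steps_schedule(max_steps: int, eval_steps: int) -> list[int]:
    """Global steps after which an evaluation is due (assumed "steps" strategy)."""
    return [step for step in range(1, max_steps + 1) if step % eval_steps == 0]


# Original run: max_steps=4000, eval_steps=1000
print(eval_steps_schedule(4000, 1000))  # [1000, 2000, 3000, 4000]
# Current arguments: max_steps=1000, eval_steps=1000
print(eval_steps_schedule(1000, 1000))  # [1000]
```

With either setting, the first evaluation is due right after step 1000, which matches where training appears to freeze.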
It is also worth noting that my dataset is streamed, i.e. it is an `IterableDataset`.
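One possible explanation worth checking (an assumption, not confirmed): a streamed split has no `__len__`, and as far as I can tell the Trainer only draws an evaluation progress bar when the eval dataloader reports a length. Evaluation over the whole streamed test set, with every batch running `generate()` up to `generation_max_length=225`, could then look like a freeze. The plain-Python sketch below mimics that situation; `StreamedSplit`, `stream_examples`, `capped`, `num_eval_batches` and the 10,000-example size are all hypothetical. `itertools.islice` stands in for capping the eval set, analogous to `IterableDataset.take(n)` in 🤗 Datasets.

```python
import itertools
from typing import Iterable, Iterator


def stream_examples(n: int) -> Iterator[dict]:
    """Hypothetical stand-in for a streamed split: examples are produced lazily."""
    for i in range(n):
        yield {"id": i}


class StreamedSplit:
    """Hypothetical iterable with no __len__, like a streamed IterableDataset."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self) -> Iterator[dict]:
        return stream_examples(self.n)


def has_length(dataset: object) -> bool:
    """True if len() works on the object; a streamed split raises TypeError."""
    try:
        len(dataset)  # type: ignore[arg-type]
        return True
    except TypeError:
        return False


def capped(dataset: Iterable[dict], n: int) -> list[dict]:
    """Keep at most n examples (analogous to IterableDataset.take(n))."""
    return list(itertools.islice(dataset, n))


def num_eval_batches(n_examples: int, batch_size: int) -> int:
    """Number of eval batches via ceiling division; each one runs generate()."""
    return -(-n_examples // batch_size)


split = StreamedSplit(10_000)
print(has_length(split))                # False: no total for a progress bar
subset = capped(split, 500)
print(has_length(subset), len(subset))  # True 500
print(num_eval_batches(500, 8))         # 63 batches at per_device_eval_batch_size=8
```

In the actual notebook, the equivalent experiment would be passing something like `common_voice_test.take(500)` as `eval_dataset` to see whether evaluation then finishes.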
Any help is appreciated!
**Update** I edited my code to enable verbose logging with:

```python
import transformers

transformers.logging.set_verbosity_info()
```
This is the log output once the evaluation step is reached:

> You have passed language=en, but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of language=en.
