I have a dataset with 113287 train rows. However, each 'caption' field is an array of multiple strings. I would like to flat-map this array so that each string becomes its own row.
The documentation for datasets states that the batch mapping feature may be used to achieve this:
"This means you can concatenate your examples, divide it up, and even add more examples!"
from datasets import load_dataset

dataset_name = "Jotschi/coco-karpathy-opus-de"
coco_dataset = load_dataset(dataset_name)

def chunk_examples(entry):
    captions = [caption for caption in entry["caption"][0]]
    return {"caption": captions}

print(coco_dataset)
chunked_dataset = coco_dataset.map(chunk_examples, batched=True, num_proc=4,
                                   remove_columns=["image_id", "caption", "image"])
print(chunked_dataset)
print(len(chunked_dataset['train']))
DatasetDict({
    train: Dataset({
        features: ['caption', 'image_id', 'image'],
        num_rows: 113287
    })
    validation: Dataset({
        features: ['caption', 'image_id', 'image'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['caption', 'image_id', 'image'],
        num_rows: 5000
    })
})
DatasetDict({
    train: Dataset({
        features: ['caption'],
        num_rows: 464
    })
    validation: Dataset({
        features: ['caption'],
        num_rows: 40
    })
    test: Dataset({
        features: ['caption'],
        num_rows: 40
    })
})
464
The problem I'm having is that the resulting dataset does not contain the expected number of rows.
The train split now reports num_rows: 464, which I suspect corresponds to the batches rather than to the individual captions. How can I normalize this back into a "regular" dataset? Is there something wrong with my mapping function?
- datasets==2.18.0
My mapping function was incorrect: I was only accessing the first entry of each batch via [0] instead of iterating over all of them. With a corrected function that flattens every caption list in the batch, the map yields one new row per caption.
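A minimal sketch of the corrected batched mapping function, reusing the dataset and column names from the question above:

from datasets import load_dataset

coco_dataset = load_dataset("Jotschi/coco-karpathy-opus-de")

def chunk_examples(batch):
    # With batched=True, batch["caption"] is a list of caption lists,
    # one per input row. Flatten all of them instead of only the first.
    captions = [caption
                for caption_list in batch["caption"]
                for caption in caption_list]
    return {"caption": captions}

chunked_dataset = coco_dataset.map(chunk_examples, batched=True, num_proc=4,
                                   remove_columns=["image_id", "caption", "image"])
print(chunked_dataset)

Because the returned list is longer than the input batch, map adds the extra rows, so each split ends up with one row per caption instead of one row per image.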