TensorFlow tf.data.Dataset API profiling


According to the trace viewer in TensorBoard and https://www.tensorflow.org/guide/data_performance_analysis, my data pipeline is not fast enough. My first version uses the following features of the tf.data API:

  • batch
  • cache
  • prefetch

However, when I look at ReduceDatasetOp::DoCompute in the trace viewer, my first version takes far too long (564,664,016 ns > 50 us).

[Screenshot: trace viewer, first version]
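To make the gap explicit, both figures can be converted to the same unit. This is only a small sketch of the arithmetic: the two numbers come from the measurement above, and the variable names are illustrative.

```python
# Figures taken from the question: measured duration and the 50 us threshold.
measured_ns = 564_664_016
threshold_us = 50

measured_us = measured_ns / 1_000        # nanoseconds -> microseconds
measured_ms = measured_ns / 1_000_000    # nanoseconds -> milliseconds
ratio = measured_us / threshold_us       # how many times over the threshold

print(f"measured: {measured_us:,.0f} us ({measured_ms:.1f} ms)")
print(f"over the threshold by a factor of about {ratio:,.0f}")
```

In other words, the measured call is in the range of hundreds of milliseconds, several orders of magnitude above the 50 us figure.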

The second version using interleave gives results that are even worse:

[Screenshot: trace viewer, second version with interleave]

Here is my current code showing the dataset creation and training step (1 epoch):

import tensorflow as tf


def process_dataset(
    X_train,
    Y_train,
    batch_size=200000,
    shuffle=False,
    reshuffle_each_iteration=False,
):
    ds = tf.data.Dataset.from_tensor_slices((X_train, Y_train))

    # Version 1 (disabled): batch -> cache -> prefetch
    # ds = ds.batch(batch_size)
    # ds = ds.cache()
    # ds = ds.prefetch(tf.data.AUTOTUNE)

    # Version 2 (active): interleave four copies of the dataset in parallel
    ds = tf.data.Dataset.range(4).interleave(
        lambda _: ds,
        deterministic=False,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.batch(batch_size)
    ds = ds.cache()
    ds = ds.prefetch(tf.data.AUTOTUNE)

    n_batch_train = tf.data.experimental.cardinality(ds)
    # Applied after batch(), so this shuffles whole batches, not single samples
    if shuffle:
        ds = ds.shuffle(n_batch_train, reshuffle_each_iteration=reshuffle_each_iteration)
    return ds, n_batch_train

# Note: mse is a loss function defined elsewhere in my script (not shown here).
@tf.function
def _training_step(model, dataset, optimizer):
    lossTotal = 0.0
    # One pass over the whole dataset (1 epoch) inside a single tf.function
    for X_batch, Y_batch in dataset:
        Y_pred = model(X_batch)
        loss = mse(Y_batch, Y_pred)

        gradients = tf.gradients(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        lossTotal += loss

    return lossTotal

I am open to any suggestion.
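One way to separate input-pipeline cost from model cost is to iterate over the dataset on its own, without a training step. The following is a minimal sketch under stated assumptions: the arrays are small synthetic stand-ins for X_train and Y_train, the batch size is reduced accordingly, and time_pipeline is a hypothetical helper, not part of TensorFlow.

```python
import time

import numpy as np
import tensorflow as tf


def time_pipeline(ds, epochs=2):
    """Iterate over ds without a model; return (batch count, seconds) per epoch."""
    results = []
    for _ in range(epochs):
        start = time.perf_counter()
        n_batches = 0
        for _batch in ds:
            n_batches += 1
        results.append((n_batches, time.perf_counter() - start))
    return results


# Synthetic stand-in data (assumption: the real arrays are much larger).
X = np.random.rand(10_000, 8).astype("float32")
Y = np.random.rand(10_000, 1).astype("float32")

# Same batch -> cache -> prefetch ordering as version 1 above.
ds = (
    tf.data.Dataset.from_tensor_slices((X, Y))
    .batch(1_000)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

timings = time_pipeline(ds)
for epoch, (n_batches, seconds) in enumerate(timings):
    print(f"epoch {epoch}: {n_batches} batches in {seconds:.4f} s")
```

The first epoch fills the cache, so comparing it with the second gives a rough idea of how much of the time is spent producing batches rather than reading them back.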

Additional information

I ran my custom code on TF2.8 with the following packages:

absl-py==2.0.0
airfrans==0.1.5.1
anyio==4.0.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.1
beautifulsoup4==4.12.2
bleach==6.1.0
cachetools==4.2.4
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
comm==0.2.0
contourpy==1.2.0
cycler==0.12.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.19.0
filelock==3.13.1
flatbuffers==23.5.26
fonttools==4.44.3
fqdn==1.5.1
fsspec==2023.10.0
gast==0.5.4
google-auth==1.35.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.59.2
gviz-api==1.10.0
h5py==3.10.0
idna==3.4
importlib-metadata==6.8.0
importlib-resources==6.1.1
ipykernel==6.26.0
ipython==8.17.2
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.20.0
jsonschema-specifications==2023.11.1
jupyter-events==0.9.0
jupyter-lsp==2.2.0
jupyter_client==8.6.0
jupyter_core==5.5.0
jupyter_server==2.10.1
jupyter_server_terminals==0.4.4
jupyterlab==4.0.8
jupyterlab-pygments==0.2.2
jupyterlab_server==2.25.1
keras==2.8.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.5
libclang==16.0.6
lips @ file:///home/max/Devel/LIPS
llvmlite==0.41.1
Markdown==3.5.1
MarkupSafe==2.1.3
matplotlib==3.8.1
matplotlib-inline==0.1.6
mistune==3.0.2
mpmath==1.3.0
nbclient==0.9.0
nbconvert==7.11.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2.1
notebook_shim==0.2.3
numba==0.58.0
numpy==1.21.5
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opt-einsum==3.3.0
overrides==7.4.0
packaging==23.2
pandocfilters==1.5.0
parso==0.8.3
pathlib==1.0.1
pexpect==4.8.0
Pillow==10.1.0
platformdirs==4.0.0
pooch==1.8.0
prometheus-client==0.18.0
prompt-toolkit==3.0.41
protobuf==3.20.1
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
Pygments==2.16.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pyvista==0.42.3
PyYAML==6.0.1
pyzmq==25.1.1
referencing==0.31.0
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.13.0
rsa==4.9
scikit-learn==1.3.2
scipy==1.6.0
scooby==0.9.2
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-profile==2.15.0
tensorboard-plugin-wit==1.8.1
tensorflow==2.8.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.3.0
terminado==0.18.0
tf-estimator-nightly==2.8.0.dev2021122109
threadpoolctl==3.2.0
tinycss2==1.2.1
tomli==2.0.1
torch==2.1.1
tornado==6.3.3
tqdm==4.66.1
traitlets==5.13.0
triton==2.1.0
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
uri-template==1.3.0
urllib3==2.1.0
vtk==9.3.0
wcwidth==0.2.10
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
Werkzeug==3.0.1
wrapt==1.16.0

TensorBoard runs in a separate environment with the following packages:

absl-py==2.0.0
astunparse==1.6.3
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
flatbuffers==23.5.26
gast==0.5.4
google-auth==2.25.2
google-auth-oauthlib==1.2.0
google-pasta==0.2.0
grpcio==1.60.0
gviz-api==1.10.0
h5py==3.10.0
idna==3.6
importlib-metadata==7.0.0
keras==2.15.0
Keras-Preprocessing==1.1.2
libclang==16.0.6
Markdown==3.5.1
MarkupSafe==2.1.3
ml-dtypes==0.2.0
numpy==1.26.2
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.2
protobuf==4.23.4
pyasn1==0.5.1
pyasn1-modules==0.3.0
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
six==1.16.0
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboard-plugin-profile==2.15.0
tensorboard-plugin-wit==1.8.1
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.4.0
tf-estimator-nightly==2.8.0.dev2021122109
typing_extensions==4.9.0
urllib3==2.1.0
Werkzeug==3.0.1
wrapt==1.14.1
zipp==3.17.0