I have a TensorFlow model that uses a class implementation. The model was designed using the functional API but wrapped in a subclassed model (the class API), because it requires a custom train step. The model does use the GPU, but the issue lies in how much memory it requires: during construction of the computation graph it needs around 200 GB of memory, while the model itself only has about 1.5 million parameters. Furthermore, the model takes ages to actually start training, and sometimes while training the CPU and GPU load drop while the memory usage remains, yet the model still appears to 'train'. A final note: the model uses unsupervised learning, hence why we need a custom train step.
- The machine I am using is a MacBook Pro with the M1 Pro chip. It has 16 GB of RAM, and swap is enabled.
- The TensorFlow version I am using is 2.13.0.
- The Python version is 3.11.4.
The input dimension to the model is (16, 387, 826, 1), and the data type is tf.float32. Here is the repo that I adapted the class implementation from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/TensorFlow/Basics/tutorial15-customizing-modelfit.py
@keras.saving.register_keras_serializable()
class CustomFit(keras.Model):
    def __init__(self, model):
        super(CustomFit, self).__init__()
        self.model = model

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            "model": self.model.get_config()
        })
        return config

    def compile(self, optimizer, loss):
        super(CustomFit, self).compile()
        self.optimizer = optimizer
        self.loss = loss

    def call(self, image):
        return self.model(image)

    def train_step(self, image):
        """
        Performs a single training step.

        args:
            image: image to be trained on
        returns:
            loss: loss value
        """
        # (train_step body omitted from the question)
The model was not using a lot of RAM when we were using just the functional API, but once we switched to the class implementation out of necessity, it started using an insane amount of RAM. We made sure that the loss function and the train step don't take too much RAM, and they don't. We were expecting the model to take much less RAM and not to spend so much time building the computation graph.
Let me start by saying that I'm not that familiar with custom models/training loops. You can have a look at https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch. For speeding up the model, you can add the @tf.function decorator to your training step, as in the example from that guide. This will enable graph execution for your model.
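As a rough illustration, here is a minimal sketch of a from-scratch loop in the style of that guide. The names model, loss_fn, optimizer, dataset, and num_epochs are assumptions standing in for your own objects, and the reconstruction-style objective is a guess, since your train_step body isn't shown:

import tensorflow as tf

# model, optimizer, loss_fn: your functional model and compile() arguments
# (assumed names, for illustration only).

@tf.function  # compiles the step into a graph after the first trace
def train_step(images):
    with tf.GradientTape() as tape:
        outputs = model(images, training=True)
        # Assumed unsupervised objective: compare the output against the input.
        loss = loss_fn(images, outputs)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for epoch in range(num_epochs):
    for batch in dataset:  # dataset yields (16, 387, 826, 1) float32 batches
        loss = train_step(batch)

With @tf.function, the step is traced once and then replayed as a graph, which avoids re-running the Python code for every batch.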
As for the 200 GB, I can't say exactly without seeing the data (processing), but I don't think it is the model; more likely it is the data that occupies the RAM. You can take a look at the tf.data API for optimization. Try out a generator, loading images directly from a folder (if you have image data), and implementing data transformations as layers. Batched loading/transformation of the data and prefetching could ease your RAM problem, for example as sketched below.
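A minimal tf.data pipeline along those lines might look like this; the image_paths list and the PNG decode step are assumptions for illustration, and the sizes come from the input shape stated in the question:

import tensorflow as tf

# image_paths: list of file paths to your images (assumed, for illustration).
def load_image(path):
    data = tf.io.read_file(path)
    image = tf.io.decode_png(data, channels=1)  # swap in the decoder for your format
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [387, 826])  # matches the stated input shape
    return image

dataset = (
    tf.data.Dataset.from_tensor_slices(image_paths)
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)                   # batch dimension from the question
    .prefetch(tf.data.AUTOTUNE)  # overlap data loading with training
)

Because the files are read lazily, batch by batch, only the batches currently in flight sit in RAM instead of the whole dataset, and prefetching keeps the GPU fed while the CPU prepares the next batch.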