TensorFlow Model Is Using Too Much RAM


I have a TensorFlow model that uses a class implementation. The model was designed with the functional API but wrapped in a subclassed model, because it requires a custom train step. The model does use the GPU, but the issue lies in how much memory it requires: while tracing the function graph it uses around 200 GB of memory, even though the model only has about 1.5 million parameters. Furthermore, the model takes ages to actually start training, and sometimes during training the CPU and GPU load drops while the memory usage remains and the model still 'trains'. A final note: the model uses unsupervised learning, which is why we need the custom train step.

- The machine I am using is a MacBook Pro with the M1 Pro chip. It has 16 GB of RAM, and swap is enabled.

- The TensorFlow version I am using is 2.13.0.

- The Python version is 3.11.4.

import tensorflow as tf
from tensorflow import keras


@keras.saving.register_keras_serializable()
class CustomFit(keras.Model):
    def __init__(self, model):
        super(CustomFit, self).__init__()
        self.model = model

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            "model": self.model.get_config()
        })
        return config

    def compile(self, optimizer, loss):
        super(CustomFit, self).compile()
        self.optimizer = optimizer
        self.loss = loss

    def call(self, image):
        return self.model(image)

    def train_step(self, image):
        """
        Performs a single training step.

        Args:
            image: image to be trained on
        Returns:
            loss: loss value
        """

The model was not using a lot of RAM when we were using just the functional API, but once we switched to the class implementation out of necessity, it started using an insane amount of RAM. We made sure that the loss function and the train step don't take too much RAM, and they don't. We were expecting the model to take much less RAM and not spend so much time building the function graph.

There is 1 answer below.

mhenning:

Let me start by saying that I'm not that familiar with custom models/training loops. You can have a look at https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch. To speed the model up, you can add the @tf.function decorator to your training step, like in the examples from that link. This enables graph execution for your model.
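A minimal sketch of what that can look like, assuming an autoencoder-style setup where the input is also the reconstruction target (the loss_fn name and that target are my assumptions, since the question omits the train_step body):

import tensorflow as tf

@tf.function  # trace once, then reuse the compiled graph on later calls
def train_step(model, optimizer, loss_fn, image):
    with tf.GradientTape() as tape:
        reconstruction = model(image, training=True)  # forward pass
        loss = loss_fn(image, reconstruction)         # unsupervised loss
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

As far as I know, when you train through Model.fit(), Keras already wraps train_step in a tf.function unless you pass run_eagerly=True to compile(), so the decorator mainly matters for hand-written loops.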
As for the 200 GB, I can't say exactly without seeing the data (processing), but I don't think it is the model; it is more likely the data that occupies the RAM. You can take a look at the tf.data Dataset API for optimization. Try out a generator, loading images directly from a folder (if you have image data), and doing data transformations as layers. Batched loading/transformation of the data and prefetching could ease your RAM problem, as in the sketch below.
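For instance, a sketch of such an input pipeline (the directory path, image size, and batch size are placeholders): it streams batches from disk instead of materializing the whole dataset in RAM, and prefetches the next batch while the current one is training.

import tensorflow as tf

# Hypothetical image folder; adjust the path, image size, and batch size.
dataset = tf.keras.utils.image_dataset_from_directory(
    "path/to/images",
    labels=None,            # unsupervised: load images only
    image_size=(128, 128),
    batch_size=32,
)

dataset = dataset.map(
    lambda x: x / 255.0,    # normalization inside the pipeline
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)

model.fit(dataset, epochs=10)   # only a batch or two resides in RAM at a time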