Compatibility of tf.estimator high level API with tf.slim network definition


I use nasnet.py from tf.slim together with the high-level tf.estimator API. The problem is that the `allow_growth` option is not respected, i.e. the entire GPU memory is allocated anyway.

Here is a condensed version of the code (only relevant parts):

from lib.nasnet.nasnet import build_nasnet_mobile

def model_fn(features, labels, mode, params):
    ...
    # build model (based on tf.slim)
    net_out, cells_out = build_nasnet_mobile(
        features, 2, is_training=mode == tf.estimator.ModeKeys.TRAIN)

    predictions = ...
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions)

    loss = ...

    optimizer = tf.train.AdamOptimizer()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(loss=loss,
                                      train_op=train_op,
                                      mode=mode)


def main():
    ...
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    session_config.allow_soft_placement = True

    config = tf.estimator.RunConfig(session_config=session_config)
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=model_dir,
                                       config=config)
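For reference, from TensorFlow 1.14 onward (an assumption about the version in use) allow-growth can also be forced through an environment variable set before the process starts, which sidesteps the ConfigProto plumbing entirely:

```shell
# Assumption: TensorFlow 1.14+ honors this variable; it enables
# allow-growth for all visible GPUs regardless of ConfigProto settings.
export TF_FORCE_GPU_ALLOW_GROWTH=true
```

If this makes the memory behavior change, it would at least narrow the problem down to the `session_config` not reaching the training session.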

Is training tf.slim-defined networks with tf.estimator fully supported, or do I have to use tf.slim's own high-level API instead?
