AttributeError: Can't pickle local object


I'm working on a university machine learning project and I need to save an "agent" (an object) that contains some complex state I need to reuse later. I'm using pickle, but unfortunately it fails with this error: AttributeError: Can't pickle local object 'constant_fn.<locals>.func'

This is a piece of my code:


from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os

if os.path.isfile("./filename_pi.obj"):
    print("-FILE FOUND-")
    file_pi = open('filename_pi.obj', 'rb')
    trained_a2c = pickle.load(file_pi)
    file_pi.close()
else:
    print("-FILE NOT FOUND-")
    #A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    file_pi = open('filename_pi.obj', 'wb') 
    pickle.dump(trained_a2c, file_pi)
    file_pi.close()

From reading about similar problems, I understand that the issue is an object that is not defined at global (module) level. The problem is that I cannot modify anything inside .get_model and .train_model, because they are methods of a library that I did not write and cannot touch. Is there anything I can do? Maybe I shouldn't pass "trained_a2c" to pickle? Or would you recommend a different approach?
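A minimal reproduction in plain Python, independent of FinRL, shows what pickle is objecting to. The function constant_fn below is a hypothetical stand-in that only mirrors the name in the traceback, not the library's actual code, and ConstantSchedule is likewise an illustrative name. pickle stores functions by reference to their module and qualified name, so a function defined inside another function (its qualified name contains <locals>) cannot be pickled, while an instance of a module-level class can. Since the offending function presumably lives inside the trained model object, it cannot be avoided from the calling code alone.

```python
import pickle


def constant_fn(val):
    # Hypothetical stand-in: returns a nested ("local") function that
    # always yields `val`, like a constant learning-rate schedule.
    def func(_progress_remaining):
        return val

    return func


class ConstantSchedule:
    """Module-level callable: picklable because its class can be looked up by name."""

    def __init__(self, val):
        self.val = val

    def __call__(self, _progress_remaining):
        return self.val


def try_pickle(obj):
    """Return "ok" if obj pickles, otherwise a short description of the failure."""
    try:
        pickle.dumps(obj)
        return "ok"
    except (AttributeError, pickle.PicklingError) as exc:
        # Depending on the Python version, a local function raises
        # AttributeError or pickle.PicklingError.
        return f"failed: {exc}"


if __name__ == "__main__":
    print(try_pickle(constant_fn(0.001)))       # fails: local object
    print(try_pickle(ConstantSchedule(0.001)))  # ok
```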

Answer by AudioBubble

A better-structured version, with the training step moved into its own function:

from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os

def train_a2c():
    # A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    return trained_a2c

if os.path.isfile("./trained_a2c.obj"):
    print("-FILE FOUND-")
    # "with" closes the file even if loading raises
    with open("trained_a2c.obj", "rb") as file_pi:
        trained_a2c = pickle.load(file_pi)
else:
    print("-FILE NOT FOUND-")
    trained_a2c = train_a2c()
    with open("trained_a2c.obj", "wb") as file_pi:
        pickle.dump(trained_a2c, file_pi)

Answer by Jack O'Neill

If you look at the source code of the library, you can see how stored models are loaded and adapt that to your own needs.

Models from stable-baselines3 can be loaded with modeltype.load(filename), where modeltype is one of the library's model classes, such as A2C.

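Also, use the save() method provided by stable-baselines3 to save a trained model, so that it is stored properly. Plain pickle is unlikely to achieve the same: it cannot serialize locally defined functions such as constant_fn.<locals>.func, which is exactly what the error in the question complains about.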

from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3 import A2C

filename = "my_a2c_model" # don't have to include .zip extension, if using load()

# loading a trained model from file; pass the environment so it can be trained further
model = A2C.load(filename, env=env_train)

# train the model again
agent = DRLAgent(env=env_train)
trained_a2c = agent.train_model(model=model, tb_log_name="a2c", total_timesteps=50000)

# saving the new model with the provided save() method from the library:
trained_a2c.save("my_new_model") # will be saved to my_new_model.zip
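The file-exists check from the question can be kept while swapping pickle for the library's own save/load. Below is a minimal sketch, assuming you pass in the save and load functions yourself; load_or_train is a hypothetical helper name, not part of FinRL or stable-baselines3, and it is written generically so it runs without either library installed.

```python
import os
from typing import Callable, TypeVar

T = TypeVar("T")


def load_or_train(
    path: str,
    train: Callable[[], T],
    save: Callable[[T, str], None],
    load: Callable[[str], T],
) -> T:
    """Hypothetical helper: load the object at `path` if it exists, else train and save it.

    Serialization is delegated to `save`/`load`, so a model's own
    save()/load() methods can be used instead of pickle.
    """
    if os.path.isfile(path):
        print("-FILE FOUND-")
        return load(path)
    print("-FILE NOT FOUND-")
    obj = train()
    save(obj, path)
    return obj
```

In the A2C script, train would be a function such as train_a2c from the first answer, save could be lambda m, p: m.save(p), and load could be lambda p: A2C.load(p, env=env_train). Because save() appends ".zip" when the name has no extension, the path to check would be something like "my_a2c_model.zip".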

More information can be found here: