Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Save and load model optimizer state

I am training a set of fairly complicated models and am looking for a way to save and load their optimizer states. The "trainer models" consist of different combinations of several other "weight models"; some of these share weights, some have frozen weights depending on the trainer, and so on. The setup is too complicated to share as an example, but in short, I cannot use model.save('model_file.h5') and keras.models.load_model('model_file.h5') when stopping and restarting training.

Using model.load_weights('weight_file.h5') works fine for testing my model once training has finished, but if I try to continue training this way, the loss does not come even close to where it left off. I have read that this is because this method does not save the optimizer state, which makes sense. I therefore need a way to save and load the optimizer states of my trainer models. It seems that Keras once had model.optimizer.get_state() and model.optimizer.set_state(), which would accomplish what I am after, but that no longer seems to be the case (at least for the Adam optimizer). Are there any other solutions with the current Keras?
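To make the symptom concrete, here is a minimal pure-Python sketch of why restoring only the weights changes the training trajectory. It does not use Keras: `ScalarAdam` and `train` are hypothetical names written for this illustration, implementing the standard Adam update on a single float. The point is only that Adam keeps running moment estimates and a step counter, and these must be restored along with the weights.

```python
import math
import pickle


class ScalarAdam:
    """Toy Adam optimizer for a single float parameter (illustration only)."""

    def __init__(self, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.t = 0      # step counter
        self.m = 0.0    # running mean of gradients
        self.v = 0.0    # running mean of squared gradients

    def step(self, w, grad):
        """Apply one Adam update and return the new weight."""
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)  # bias correction
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return w - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

    def get_state(self):
        return {"t": self.t, "m": self.m, "v": self.v}

    def set_state(self, state):
        self.t, self.m, self.v = state["t"], state["m"], state["v"]


def train(w, opt, steps):
    """Minimize f(w) = (w - 3)^2 and return the final weight."""
    for _ in range(steps):
        w = opt.step(w, 2.0 * (w - 3.0))
    return w


# Reference run: 10 uninterrupted steps.
reference = train(0.0, ScalarAdam(), 10)

# Interrupted run: 5 steps, then "save" the weight and the optimizer state.
opt = ScalarAdam()
w_mid = train(0.0, opt, 5)
blob = pickle.dumps(opt.get_state())

# Resume with the optimizer state restored.
restored = ScalarAdam()
restored.set_state(pickle.loads(blob))
resumed_with_state = train(w_mid, restored, 5)

# Resume with only the weight restored (fresh optimizer, like load_weights alone).
resumed_without_state = train(w_mid, ScalarAdam(), 5)
```

In this sketch `resumed_with_state` reproduces `reference`, while `resumed_without_state` follows a different path because the moment estimates and step counter start from zero again. On a real model with many parameters the same effect can show up as a visible jump in the loss.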


1 Reply


You can extract the important lines from the load_model and save_model functions.

For saving optimizer states, in save_model:

# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
if symbolic_weights:
    optimizer_weights_group = f.create_group('optimizer_weights')
    weight_values = K.batch_get_value(symbolic_weights)

For loading optimizer states, in load_model:

# Set optimizer weights.
if 'optimizer_weights' in f:
    # Build train function (to get weight updates).
    if isinstance(model, Sequential):
        model.model._make_train_function()
    else:
        model._make_train_function()

    # ...

    try:
        model.optimizer.set_weights(optimizer_weight_values)

Combining the lines above, here's an example:

  1. First fit the model for 5 epochs.
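import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Random toy data: 100 samples, 50 features, binary labels.
X, y = np.random.rand(100, 50), np.random.randint(2, size=100)
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(X, y, epochs=5)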

Epoch 1/5
100/100 [==============================] - 0s 4ms/step - loss: 0.7716
Epoch 2/5
100/100 [==============================] - 0s 64us/step - loss: 0.7678
Epoch 3/5
100/100 [==============================] - 0s 82us/step - loss: 0.7665
Epoch 4/5
100/100 [==============================] - 0s 56us/step - loss: 0.7647
Epoch 5/5
100/100 [==============================] - 0s 76us/step - loss: 0.7638
  2. Now save the weights and optimizer states.
import pickle
from keras import backend as K

model.save_weights('weights.h5')
# Fetch the current values of the optimizer's weight variables.
symbolic_weights = getattr(model.optimizer, 'weights')
weight_values = K.batch_get_value(symbolic_weights)
with open('optimizer.pkl', 'wb') as f:
    pickle.dump(weight_values, f)
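  3. Rebuild the model in another python session, and load weights.
import pickle
from keras.layers import Input, Dense
from keras.models import Model

# The architecture and compile settings must match the saved model.
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')

model.load_weights('weights.h5')
# Build the train function first so the optimizer's weight variables exist.
model._make_train_function()
with open('optimizer.pkl', 'rb') as f:
    weight_values = pickle.load(f)
model.optimizer.set_weights(weight_values)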
  4. Continue model training.
# X and y must be the same training data as in step 1 (reload them in the new session).
model.fit(X, y, epochs=5)

Epoch 1/5
100/100 [==============================] - 0s 674us/step - loss: 0.7629
Epoch 2/5
100/100 [==============================] - 0s 49us/step - loss: 0.7617
Epoch 3/5
100/100 [==============================] - 0s 49us/step - loss: 0.7611
Epoch 4/5
100/100 [==============================] - 0s 55us/step - loss: 0.7601
Epoch 5/5
100/100 [==============================] - 0s 49us/step - loss: 0.7594
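The save and restore steps above could be wrapped in two small helpers. This is only a sketch: the function names `save_optimizer_state` and `load_optimizer_state` are made up here, and it assumes a Keras version whose optimizers expose `get_weights()` and `set_weights()` (Keras 2.x optimizers do; newer releases may differ).

```python
import pickle


def save_optimizer_state(optimizer, path):
    """Pickle the optimizer's current weight values to `path`."""
    # Assumption: `optimizer.get_weights()` returns a list of arrays,
    # equivalent to K.batch_get_value(optimizer.weights) used above.
    with open(path, "wb") as f:
        pickle.dump(optimizer.get_weights(), f)


def load_optimizer_state(optimizer, path):
    """Restore weight values written by save_optimizer_state."""
    # The optimizer's variables must already exist before this call,
    # e.g. after model._make_train_function() as in step 3 above.
    with open(path, "rb") as f:
        optimizer.set_weights(pickle.load(f))
```

Usage would look like `save_optimizer_state(model.optimizer, 'optimizer.pkl')` after training, and `load_optimizer_state(model.optimizer, 'optimizer.pkl')` after rebuilding and compiling the model and loading its weights. With several trainer models, each optimizer would presumably need its own file.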
