Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

multithreading - Keras Tensorflow - Exception while predicting from multiple threads

I am using Keras 2.0.8 with the TensorFlow 1.3.0 backend.

I load a model in the class's __init__ and then use it to make predictions from multiple threads.

import tensorflow as tf
from keras import backend as K
from keras.models import load_model


class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()

    def query_cnn(self, data):
        X = self.preproccesing(data)  # preprocessing method not shown in the question
        with self.session.as_default():
            with self.graph.as_default():
                return self.cnn_model.predict(X)

I instantiate CNN once, and query_cnn is then called from multiple threads.

The exception I get in my log is:

  File "/home/*/Similarity/CNN.py", line 43, in query_cnn
    return self.cnn_model.predict(X)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 913, in predict
    return self.model.predict(x, batch_size=batch_size, verbose=verbose)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1713, in predict
    verbose=verbose, steps=steps)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1269, in _predict_loop
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 2273, in __call__
    **self.session_kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1124, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps 

The code works fine most of the time, so it is probably a multithreading problem.

How can I fix it?

Whoever fights dragons too long becomes a dragon; gaze into the abyss too long, and the abyss gazes back…

1 Reply

Make sure you finish building the graph before you start the other threads.

Calling finalize() on the graph may help you with that.

def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()

Update 1: finalize() makes your graph read-only, so it can be safely used from multiple threads. As a side effect, it helps you find unintentional behavior, and sometimes memory leaks, because it raises an exception whenever something tries to modify the graph.
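To make the read-only rule concrete without needing TensorFlow installed, here is a minimal pure-Python sketch. ToyGraph is a hypothetical stand-in, not tf.Graph; it only models the idea that once finalize() has been called, any attempt to add a node raises a RuntimeError (the same exception type TF 1.x raises for a finalized graph).

```python
import threading


class ToyGraph:
    """Hypothetical stand-in for a TF 1.x graph: a list of named nodes
    that can be frozen with finalize()."""

    def __init__(self):
        self._nodes = []
        self._finalized = False
        self._lock = threading.Lock()

    def add_node(self, name):
        with self._lock:
            if self._finalized:
                # Models TF 1.x behaviour: a finalized graph rejects new ops.
                raise RuntimeError("Graph is finalized and cannot be modified.")
            self._nodes.append(name)
            return name

    def finalize(self):
        # After this call the node list is read-only, so concurrent readers
        # can never observe a graph that is still being extended.
        with self._lock:
            self._finalized = True

    def num_nodes(self):
        return len(self._nodes)


g = ToyGraph()
g.add_node("dense/MatMul")  # built during setup, before any worker thread
g.finalize()
try:
    g.add_node("one_hot")  # an accidental modification after setup
except RuntimeError as err:
    print("caught:", err)
print(g.num_nodes())  # 1
```

The point is that the mistake surfaces as a loud exception at the offending line, instead of as an intermittent failure in some other thread.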

Imagine, for instance, that you have a thread that one-hot encodes your inputs (bad example):

def preprocessing(self, data):
    one_hot_data = tf.one_hot(data, depth=self.num_classes)
    return self.session.run(one_hot_data)

If you print the number of nodes in the graph, you will notice that it keeps growing over time:

# number of nodes in the TF graph
print(len(list(tf.get_default_graph().as_graph_def().node)))

But if you define the operation up front, that won't happen (slightly better code):

def preprocessing(self, data):
    # run the op created once (e.g. in __init__), feeding data through the self.input placeholder
    return self.session.run(self.one_hot_data, feed_dict={self.input: data})
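The difference between the two preprocessing versions can be shown with a TensorFlow-free sketch. ToyGraph, LeakyPreprocessor and PrebuiltPreprocessor are hypothetical names, and one_hot is a plain-Python stand-in for tf.one_hot; the sketch only models the rule that every op you create adds a node to the graph.

```python
class ToyGraph:
    """Hypothetical stand-in for a TF 1.x graph: every op you create adds a node."""

    def __init__(self):
        self.nodes = []
        self.finalized = False

    def add_op(self, name, fn):
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.nodes.append(name)
        return fn  # in this toy, an "op" is just a Python callable

    def finalize(self):
        self.finalized = True


def one_hot(values, depth):
    # Plain-Python one-hot encoding standing in for tf.one_hot.
    return [[1 if i == v else 0 for i in range(depth)] for v in values]


class LeakyPreprocessor:
    """Like the 'bad example': creates a new op on every call."""

    def __init__(self, graph, depth):
        self.graph = graph
        self.depth = depth

    def preprocessing(self, data):
        op = self.graph.add_op("one_hot", lambda x: one_hot(x, self.depth))
        return op(data)


class PrebuiltPreprocessor:
    """Like the 'slightly better code': creates the op once, then reuses it."""

    def __init__(self, graph, depth):
        self.op = graph.add_op("one_hot", lambda x: one_hot(x, depth))

    def preprocessing(self, data):
        return self.op(data)


leaky_graph = ToyGraph()
leaky = LeakyPreprocessor(leaky_graph, depth=3)
for _ in range(5):
    leaky.preprocessing([0, 2])
print(len(leaky_graph.nodes))  # 5: one new node per call

good_graph = ToyGraph()
good = PrebuiltPreprocessor(good_graph, depth=3)
good_graph.finalize()  # safe: no further ops are needed after setup
for _ in range(5):
    good.preprocessing([0, 2])
print(len(good_graph.nodes))  # 1
print(good.preprocessing([0, 2]))  # [[1, 0, 0], [0, 0, 1]]
```

If you call leaky_graph.finalize() and then leaky.preprocessing(...), the toy raises RuntimeError on the first call, which is how finalize() exposes this kind of leak early.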

Update 2: According to this thread, you need to call model._make_predict_function() on a Keras model before you start the threads.

Keras builds its predict function lazily, the first time you call predict(). That way, if you never call predict(), you save some time and resources; the trade-off is that the first call to predict() is slightly slower than every later one.

The updated code:

def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.cnn_model._make_predict_function() # have to initialize before threading
    self.session = K.get_session()
    self.graph = tf.get_default_graph() 
    self.graph.finalize() # make graph read-only
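Why building the predict function before threading matters can be sketched without TensorFlow. ToyModel is hypothetical: its _make_predict_function only mimics the lazy "build it if it does not exist yet" pattern described above, and time.sleep stands in for slow graph construction. Building eagerly on the main thread guarantees a single build no matter how many workers call predict() afterwards.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToyModel:
    """Hypothetical stand-in for a Keras 2.x model whose predict function
    is built lazily on the first predict() call."""

    def __init__(self):
        self.predict_function = None
        self.builds = 0

    def _make_predict_function(self):
        # Check-then-build with no lock: NOT thread-safe on its own.
        if self.predict_function is None:
            self.builds += 1
            time.sleep(0.01)  # simulate slow graph construction
            self.predict_function = lambda xs: [2 * x for x in xs]

    def predict(self, xs):
        self._make_predict_function()
        return self.predict_function(xs)


model = ToyModel()
model._make_predict_function()  # build eagerly, before any worker starts

batches = [[1, 2], [3], [4, 5, 6]] * 5
with ThreadPoolExecutor(max_workers=5) as pool:
    results = list(pool.map(model.predict, batches))

print(model.builds)  # 1
print(results[:3])  # [[2, 4], [6], [8, 10, 12]]
```

If you delete the eager call, workers that arrive during the simulated build can all see predict_function is None and build it again, so builds may end up greater than 1. In TF 1.x, several threads building graph pieces at the same time is consistent with the error in the question, although this toy does not reproduce TensorFlow itself.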

Update 3: Because _make_predict_function() doesn't seem to work as expected, I put together a proof of concept that warms the model up instead. First, I created a dummy model:

from keras.layers import Dense
from keras.models import Sequential

model = Sequential()
model.add(Dense(256, input_shape=(2,)))
# A single-unit softmax always outputs 1.0; sigmoid gives a meaningful output.
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='mean_squared_error', optimizer='adam')

model.save("dummymodel")

Then, in another script, I loaded that model and ran it from multiple threads:

import tensorflow as tf
from keras import backend as K
from keras.models import load_model
import threading as t
import numpy as np

K.clear_session()

class CNN:
    def __init__(self, model_path):

        self.cnn_model = load_model(model_path)
        self.cnn_model.predict(np.array([[0,0]])) # warmup
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize() # finalize

    def preproccesing(self, data):
        # dummy
        return data

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                prediction = self.cnn_model.predict(X)
        print(prediction)
        return prediction


cnn = CNN("dummymodel")

# Start five threads that all query the same CNN instance concurrently.
threads = [t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
           for _ in range(5)]
for th in threads:
    th.start()
for th in threads:
    th.join()

With the warm-up and finalize() lines commented out, I was able to reproduce your original error.
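A different technique, not the one this answer uses, is to serialize every predict() call behind a threading.Lock. It avoids the race without relying on warm-up, at the cost of running predictions one at a time. The sketch below is TensorFlow-free: LockedPredictor and ToyModel are hypothetical names, and with a real Keras model you would still enter the session and graph contexts inside the locked block, as query_cnn does.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class ToyModel:
    """Hypothetical stand-in for a Keras model (doubles every input)."""

    def __init__(self):
        self.calls = 0

    def predict(self, xs):
        self.calls += 1  # unsynchronized on purpose; the wrapper's lock protects it
        return [2 * x for x in xs]


class LockedPredictor:
    """Hypothetical wrapper that lets only one thread call predict() at a time."""

    def __init__(self, model):
        self.model = model
        self._lock = threading.Lock()

    def predict(self, xs):
        with self._lock:  # serializes access to the shared model
            return self.model.predict(xs)


predictor = LockedPredictor(ToyModel())
batches = [[i, i + 1] for i in range(20)]
with ThreadPoolExecutor(max_workers=5) as pool:
    results = list(pool.map(predictor.predict, batches))

print(results[:2])  # [[0, 2], [2, 4]]
print(predictor.model.calls)  # 20
```

The lock is the simpler guarantee, while warm-up plus finalize() keeps predictions parallel; which one fits depends on how much concurrency you actually need.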


...