Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - How to get weights from a .pb model in TensorFlow

I trained a model and then created a .pb file by freezing it. How do I get the weights from the .pb file, or do I need to do further processing to extract them?

@mrry, please guide me.



1 Reply


Let us first load the graph from the .pb file.

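import tensorflow as tf

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'  # path to your .pb file

# tf.compat.v1 works on TF 2.x and recent 1.x releases; on older 1.x use tf.Session / tf.GraphDef directly.
# The original snippet passed an undefined `config`; add config=... here only if you need one.
with tf.compat.v1.Session() as sess:
  print("load graph")
  # Read and parse the serialized GraphDef
  with tf.io.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
  # as_default() only takes effect when used as a context manager
  with sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')
  graph_nodes = [n for n in graph_def.node]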

When you freeze a graph to a .pb file, its variables are converted to Const nodes, so the weights that were trainable variables are also stored as Const in the .pb file. graph_nodes contains all the nodes in the graph, but we are only interested in the Const nodes.

wts = [n for n in graph_nodes if n.op=='Const']
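Note that not every Const node is a trained weight: a frozen graph typically also stores small constants such as axes or shapes as Const. As a rough sketch, the dtype and shape of each Const node can be listed first to tell them apart. The helper name summarize_consts and the tiny in-memory graph below are assumptions made for illustration (they are not part of the original answer); in practice you would pass the graph_def parsed from your own .pb file.

```python
import tensorflow as tf


def summarize_consts(graph_def):
    """Return (name, dtype_name, shape) for every Const node in a GraphDef."""
    rows = []
    for node in graph_def.node:
        if node.op != 'Const':
            continue
        tensor = node.attr['value'].tensor  # TensorProto holding the constant
        dtype_name = tf.as_dtype(tensor.dtype).name
        shape = tuple(d.size for d in tensor.tensor_shape.dim)
        rows.append((node.name, dtype_name, shape))
    return rows


# Hypothetical stand-in for a frozen graph: one weight-like matrix, one scalar.
demo_graph = tf.Graph()
with demo_graph.as_default():
    tf.constant([[1.0, 2.0], [3.0, 4.0]], name='kernel')
    tf.constant(0, name='axis')
demo_graph_def = demo_graph.as_graph_def()

for name, dtype_name, shape in summarize_consts(demo_graph_def):
    print(name, dtype_name, shape)
```

Scalars and very small integer tensors in this listing are usually bookkeeping constants rather than weights, though that is a heuristic and depends on the model.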

Each element of wts is a NodeDef. It has several attributes, such as name and op. The values can be extracted as follows:

from tensorflow.python.framework import tensor_util

for n in wts:
    print("Name of the node - %s" % n.name)
    print("Value - ")
    print(tensor_util.MakeNdarray(n.attr['value'].tensor))
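If you would rather keep the weights than print them, one possible approach is to collect them into a dictionary keyed by node name. The sketch below uses the public tf.make_ndarray function, which performs the same conversion as tensor_util.MakeNdarray. The helper name extract_weights, its min_elements threshold, and the small demo graph are assumptions for illustration only.

```python
import tensorflow as tf


def extract_weights(graph_def, min_elements=2):
    """Map Const node names to NumPy arrays, skipping very small constants."""
    weights = {}
    for node in graph_def.node:
        if node.op != 'Const':
            continue
        value = tf.make_ndarray(node.attr['value'].tensor)
        # Heuristic (assumption): scalars are rarely trained weights.
        if value.size >= min_elements:
            weights[node.name] = value
    return weights


# Hypothetical stand-in for a frozen graph; use your parsed graph_def instead.
demo_graph = tf.Graph()
with demo_graph.as_default():
    tf.constant([[1.0, 2.0], [3.0, 4.0]], name='kernel')
    tf.constant(0, name='axis')

demo_weights = extract_weights(demo_graph.as_graph_def())
for name, value in demo_weights.items():
    print(name, value.shape, value.dtype)
```

The resulting arrays are ordinary NumPy arrays, so they can be inspected or saved with whatever NumPy tooling you already use.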

Hope this solves your concern.

