Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

tensorflow - Can cond support TF ops with side effects?

The (source code) documentation for tf.cond is unclear on whether the functions it runs, depending on how the predicate evaluates, can have side effects. I've done some tests, but I'm getting conflicting results. For example, the code below does not work:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)

# The side-effecting ops are created outside of cond()...
adder = count.assign_add(1)
subtractor = count.assign_sub(2)

# ...and each lambda just returns one of the existing tensors.
my_op = control_flow_ops.cond(pred, lambda: adder, lambda: subtractor)

sess = tf.InteractiveSession()
tf.initialize_all_variables().run()

my_op.eval(feed_dict={pred: True})
count.eval() # returns -1 (expected 1)

my_op.eval(feed_dict={pred: False})
count.eval() # returns -2 (the net change is again -1, not -2)

In other words, no matter what value the predicate takes, both assignments run, so the net result is a subtraction of 1. On the other hand, the following snippet does work; the only difference is that the ops are created inside the lambdas, which I assumed adds new ops to the graph every time my_op is called:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)

# Here the side-effecting ops are created inside the lambdas passed to cond().
my_op = control_flow_ops.cond(pred, lambda: count.assign_add(1), lambda: count.assign_sub(2))

sess = tf.InteractiveSession()
tf.initialize_all_variables().run()

my_op.eval(feed_dict={pred: False})
count.eval() # returns -2

my_op.eval(feed_dict={pred: True})
count.eval() # returns -1

I'm not sure why creating new ops every time works while the other case doesn't, but I'd obviously rather not keep adding nodes, since the graph would eventually become too big.


1 Reply


Your second version, where the assign_add() and assign_sub() ops are created inside the lambdas passed to cond(), is the correct way to do this. Fortunately, each of the two lambdas is evaluated only once, during the call to cond(), so your graph will not grow without bound.

Essentially, cond() does the following:

  1. Create a Switch node, which forwards its input to only one of two outputs, depending on the value of pred. Let's call the outputs pred_true and pred_false. (They have the same value as pred, but that's unimportant since they are never evaluated directly.)

  2. Build the subgraph corresponding to the if_true lambda, where all of the nodes have a control dependency on pred_true.

  3. Build the subgraph corresponding to the if_false lambda, where all of the nodes have a control dependency on pred_false.

  4. Zip together the lists of return values from the two lambdas, and create a Merge node for each of these. A Merge node takes two inputs, of which only one is expected to be produced, and forwards it to its output.

  5. Return the tensors that are the outputs of the Merge nodes.
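
The five steps above can be sketched as a toy, pure-Python model. This is not TensorFlow's implementation: DEAD, switch, merge, toy_cond and node are hypothetical names, and each branch is modeled as a list of already-built nodes (plain callables), so building and running collapse into a single call. It only illustrates how gating every branch node on one Switch output, then merging the zipped return values, means that only the taken branch executes.

```python
DEAD = object()  # marks an output that was never produced


def switch(value, pred):
    """Step 1: forward `value` to only one of two outputs, (false_out, true_out)."""
    return (DEAD, value) if pred else (value, DEAD)


def merge(a, b):
    """Step 4: forward whichever of the two inputs was actually produced."""
    produced = [x for x in (a, b) if x is not DEAD]
    assert len(produced) == 1, "exactly one input should be produced"
    return produced[0]


def toy_cond(pred, true_nodes, false_nodes):
    pred_false, pred_true = switch(pred, pred)  # step 1: outputs carry pred's value
    # Steps 2-3: every node of a branch is gated on that branch's Switch output.
    true_outs = [DEAD if pred_true is DEAD else op() for op in true_nodes]
    false_outs = [DEAD if pred_false is DEAD else op() for op in false_nodes]
    # Steps 4-5: one Merge per zipped pair of return values; return their outputs.
    return [merge(t, f) for t, f in zip(true_outs, false_outs)]


log = []


def node(name, value):
    """Hypothetical node whose side effect is recording that it ran."""
    def run():
        log.append(name)
        return value
    return run


out = toy_cond(True, [node("t0", 1), node("t1", "a")], [node("f0", -1), node("f1", "b")])
print(out, log)  # → [1, 'a'] ['t0', 't1']
```

Only the nodes of the true branch appear in the log, because the false branch's gate is DEAD and its outputs are never produced.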

This means you can run your second version and be confident that the graph stays a fixed size, regardless of how many steps you run.
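
A minimal sketch of why the graph stays a fixed size, assuming a hypothetical ToyGraph that simply records node names at build time (this is not TensorFlow's API; ToyGraph, make_assign_add and toy_cond are illustrative names). The lambdas are called once, inside toy_cond, so running the resulting op many times only selects between prebuilt ops and adds no nodes.

```python
class ToyGraph:
    """Hypothetical stand-in for a dataflow graph."""

    def __init__(self):
        self.nodes = []  # node names, appended only while building
        self.count = 0   # stands in for the tf.Variable `count`

    def add_node(self, name):
        self.nodes.append(name)


def make_assign_add(graph, delta):
    """Add one node to the graph and return a callable that runs it."""
    graph.add_node(f"AssignAdd({delta})")

    def run():
        graph.count += delta
        return graph.count

    return run


def toy_cond(graph, true_fn, false_fn):
    graph.add_node("Switch")
    true_op = true_fn()    # each lambda is called exactly once, right here
    false_op = false_fn()
    graph.add_node("Merge")

    def run(pred):
        # Running only selects a prebuilt op; no nodes are added.
        return true_op() if pred else false_op()

    return run


g = ToyGraph()
my_op = toy_cond(g, lambda: make_assign_add(g, 1), lambda: make_assign_add(g, -2))
size_after_build = len(g.nodes)

for step in range(1000):
    my_op(step % 2 == 0)  # alternate True / False

print(size_after_build, len(g.nodes), g.count)  # → 4 4 -500
```

In TensorFlow 1.x itself, a similar check could compare len(tf.get_default_graph().get_operations()) before and after a loop of sess.run calls.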

The reason your first version doesn't work is that, when a Tensor created outside cond() is captured (like adder or subtractor in your example), an additional Switch node is added to enforce the logic that the value of the tensor is forwarded only to the branch that actually executes. This is an artifact of how TensorFlow combines feed-forward dataflow and control flow in its execution model. Because that Switch node needs its input value before it can route it, the captured tensors (in this case the results of assign_add and assign_sub) will always be evaluated, even if they aren't used, and you'll see their side effects. This is something we need to document better, and as Michael says, we're going to make this more usable in the future.
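
The difference between the two versions can be sketched with a toy, pure-Python dataflow model. It is not TensorFlow's code: DEAD, switch, merge, assign_add, run_captured_version and run_inside_version are hypothetical names. In the captured version both assignments sit upstream of the extra Switch nodes, so they run before anything is routed; in the inside version they are gated on the Switch outputs of pred, so only the taken branch runs. The printed counts mirror the ones reported in the question.

```python
DEAD = object()  # an output that was never produced


def switch(value, pred):
    """Route `value` to (false_out, true_out); `value` must be computed first."""
    return (DEAD, value) if pred else (value, DEAD)


def merge(a, b):
    produced = [x for x in (a, b) if x is not DEAD]
    assert len(produced) == 1
    return produced[0]


state = {"count": 0}  # stands in for the tf.Variable `count`


def assign_add(delta):
    state["count"] += delta  # the side effect
    return state["count"]


def run_captured_version(pred):
    # adder/subtractor live OUTSIDE cond, upstream of the extra Switch nodes,
    # so both must be computed before they can be routed.
    adder = assign_add(1)
    subtractor = assign_add(-2)
    _, adder_if_true = switch(adder, pred)
    subtractor_if_false, _ = switch(subtractor, pred)
    return merge(adder_if_true, subtractor_if_false)


def run_inside_version(pred):
    # The assignments live INSIDE the branches and are gated on the Switch
    # outputs of pred, so only the taken branch executes.
    pred_false, pred_true = switch(pred, pred)
    t = DEAD if pred_true is DEAD else assign_add(1)
    f = DEAD if pred_false is DEAD else assign_add(-2)
    return merge(t, f)


run_captured_version(True)
print(state["count"])  # → -1 (both side effects happened)
run_captured_version(False)
print(state["count"])  # → -2

state["count"] = 0
run_inside_version(False)
print(state["count"])  # → -2
run_inside_version(True)
print(state["count"])  # → -1
```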

