Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


tensorflow - How to make an if statement using a boolean Tensor

How do I make an if statement using a boolean tensor? To be more precise, I'm trying to compare a tensor of size 1 to a constant, checking whether the value in the tensor is less than the constant. I figured out that I have to make the constant its own size-1 tensor and use a comparison op to check whether the first tensor is less than the second, but I'm not sure how to make the resulting boolean tensor work in an if statement. Just using it as the condition of the if statement makes the statement always take the true branch.

EDIT: This is more or less what the code looked like. However, I'm getting the error 'bool' object has no attribute 'name' regardless of whether the branch function takes parameters, which makes me think the real problem is that the function is not returning a TensorFlow object.

pred = tf.placeholder(tf.bool)

def if_true(x, y, z):
  #act on x, y, and z
  return True

def if_false():
  return False

# Will be `tf.cond()` in the next release.
from tensorflow.python.ops import control_flow_ops
from functools import partial
x = ...
y = ...
z = ...

result = control_flow_ops.cond(pred, partial(if_true, x, y, z), if_false)


1 Reply


TL;DR: You need to use Session.run() to get a Python boolean, but there are other ways to achieve the same result that might be more efficient.

It looks like you've already figured out how to get a boolean tensor from your value, but for the benefit of other readers, it would look something like this:

import tensorflow as tf

computed_val = ...
constant_val = tf.constant(37.0)
pred = tf.less(computed_val, constant_val)  # N.B. Types of the two args must match

The next part is how to use it as a conditional. The simplest thing to do is to use a Python if statement, but to do that you must evaluate the tensor pred using Session.run():

sess = tf.Session()

if sess.run(pred):
  pass  # Do something.
else:
  pass  # Do something else.
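
For readers wondering why putting the tensor straight into the if statement always takes the true branch: a tensor in the graph is a Python object that describes a computation, not the computed value, and Python treats an object as true unless its class says otherwise. The sketch below is plain Python, not TensorFlow code. It assumes a hypothetical `SymbolicLess` class as a stand-in for a boolean tensor, with a hypothetical `evaluate()` method playing the role of Session.run():

```python
class SymbolicLess:
    """Hypothetical stand-in for a boolean tensor: stores the comparison, not its result."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def evaluate(self):
        # Plays the role of Session.run(): actually performs the comparison.
        return self.left < self.right


pred = SymbolicLess(50.0, 37.0)  # "50 < 37" is false once evaluated

# The class defines neither __bool__ nor __len__, so Python treats the object as true.
print(bool(pred))       # True

# Evaluating first yields a real Python boolean that an if statement can test.
print(pred.evaluate())  # False
```

This is why the predicate has to be evaluated before a Python if statement can branch on it.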

One caveat about using a Python if statement is that you have to evaluate the whole expression up to pred, which makes it tricky to reuse intermediate values that have already been computed. I'd like to draw your attention to two other ways you can compute conditional expressions using TensorFlow, which don't require you to evaluate the predicate and get a Python value back.

The first way uses the tf.select() op to conditionally pass through values from two tensors passed as arguments:

pred = tf.placeholder(tf.bool)  # Can be any computed boolean expression.
val_if_true = tf.constant(28.0)
val_if_false = tf.constant(12.0)
result = tf.select(pred, val_if_true, val_if_false)

sess = tf.Session()
sess.run(result, feed_dict={pred: True})   # ==> 28.0
sess.run(result, feed_dict={pred: False})  # ==> 12.0

The tf.select() op works element-wise on all of its arguments, which allows you to combine values from the two input tensors. See its documentation for more details. (In TensorFlow 1.0 and later this op is named tf.where(); tf.select() was removed.) The drawback of tf.select() is that it evaluates both val_if_true and val_if_false before computing the result, which might be expensive if they are complicated expressions.
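
To make the element-wise behaviour and the cost concrete, here is a rough plain-Python model of a select-style op. It is only an analogy under stated assumptions: `select`, `expensive`, and the `calls` log are hypothetical helpers, not TensorFlow functions, and lists stand in for tensors.

```python
calls = []  # records which inputs were actually computed


def expensive(name, values):
    # Stand-in for a costly expression: log that it ran, then return its value.
    calls.append(name)
    return values


def select(pred, if_true, if_false):
    # Element-wise choice: take if_true[i] where pred[i] holds, else if_false[i].
    return [t if p else f for p, t, f in zip(pred, if_true, if_false)]


result = select(
    [True, False, True],
    expensive("true_input", [28.0, 28.0, 28.0]),
    expensive("false_input", [12.0, 12.0, 12.0]),
)

print(result)  # [28.0, 12.0, 28.0]
print(calls)   # ['true_input', 'false_input']
```

Both inputs are fully computed before the selection happens, even for elements that end up discarded.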

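The second way uses the tf.cond() op, which conditionally evaluates one of two expressions. This is particularly useful if the expressions are expensive, and it is essential if they have side effects. The basic pattern is to specify two Python functions (or lambda expressions) that build subgraphs that will execute on the true or false branches. Each function should return a tensor rather than a plain Python value such as True or False; returning a Python bool is the likely cause of the 'bool' object has no attribute 'name' error in the question: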

# Define some large matrices
a = ...
b = ...
c = ...

pred = tf.placeholder(tf.bool)

def if_true():
  return tf.matmul(a, b)

def if_false():
  return tf.matmul(b, c)

# Build the conditional; only the selected branch's subgraph is executed.
result = tf.cond(pred, if_true, if_false)

sess = tf.Session()
sess.run(result, feed_dict={pred: True})   # ==> executes only (a x b)
sess.run(result, feed_dict={pred: False})  # ==> executes only (b x c)
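
As a rough mental model of what happens here, the plain-Python sketch below separates the two stages: both branch functions are called once while the graph is built, but only the selected branch's node runs at execution time. Everything in it is a hypothetical stand-in (`matmul_node`, this `cond`, the `executed` log, and the dict used as a feed), not TensorFlow's implementation.

```python
executed = []  # records which branch nodes actually ran


def matmul_node(label):
    # Building a node does no work; it returns a thunk to be run later.
    def run():
        executed.append(label)
        return "value of " + label
    return run


def cond(pred_name, true_fn, false_fn):
    # Graph construction: both branch functions are called to build their nodes.
    true_node, false_node = true_fn(), false_fn()

    def run(feed):
        # Execution: only the node on the selected branch is run.
        return true_node() if feed[pred_name] else false_node()
    return run


result = cond("pred", lambda: matmul_node("a x b"), lambda: matmul_node("b x c"))

first = result({"pred": True})
executed_after_true = list(executed)
second = result({"pred": False})

print(first, executed_after_true)  # value of a x b ['a x b']
print(second, executed)            # value of b x c ['a x b', 'b x c']
```

The same split suggests why the expensive or side-effecting ops should be created inside the branch functions: in this model, work done outside them is not tied to either branch.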
