Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


TensorFlow: numpy.repeat() alternative

I want to compare the predicted values yp from my neural network in a pairwise fashion, and so I was using (back in my old numpy implementation):

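import numpy as np

idx = np.repeat(np.arange(len(yp)), len(yp))  # each index repeated len(yp) times: [0,0,...,1,1,...]
jdx = np.tile(np.arange(len(yp)), len(yp))    # whole index range cycled len(yp) times: [0,1,...,0,1,...]
s = yp[idx] - yp[jdx]                         # s[i*len(yp) + j] = yp[i] - yp[j]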

This basically creates an indexing mesh, which I then use: idx = [0,0,0,1,1,1,...] while jdx = [0,1,2,0,1,2,...]. I don't know if there is a simpler way of doing it...
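For reference, here is a small NumPy sketch (the values in yp are made up) showing what the mesh looks like, and that np.meshgrid(..., indexing="ij"), flattened row by row, builds the same two index vectors. It indexes with plain yp[idx]; recent NumPy versions no longer treat the list-wrapped form yp[[idx]] as equivalent and return an extra leading dimension instead.

```python
import numpy as np

# Toy predictions standing in for the network output yp (hypothetical values).
yp = np.array([0.5, 2.0, -1.0])
n = len(yp)

# Index mesh as in the question: idx repeats each index, jdx cycles through them.
idx = np.repeat(np.arange(n), n)   # [0 0 0 1 1 1 2 2 2]
jdx = np.tile(np.arange(n), n)     # [0 1 2 0 1 2 0 1 2]

# The same mesh from np.meshgrid with matrix ('ij') indexing, flattened row-major.
I, J = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
assert np.array_equal(I.ravel(), idx) and np.array_equal(J.ravel(), jdx)

# Pairwise differences: s[i*n + j] = yp[i] - yp[j].
s = yp[idx] - yp[jdx]
```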

Anyhow, TensorFlow has tf.tile(), but it seems to lack a tf.repeat(). I tried indexing a TensorFlow tensor v of length n directly with a NumPy index array:

idx = np.repeat(np.arange(n), n)
v2 = v[idx]

And I get the error:

TypeError: Bad slice index [  0   0   0 ..., 215 215 215] of type <type 'numpy.ndarray'>

Using a TensorFlow constant as the index does not work either:

idx = tf.constant(np.repeat(np.arange(n), n))
v2 = v[idx]

This fails with:

TypeError: Bad slice index Tensor("Const:0", shape=TensorShape([Dimension(46656)]), dtype=int64) of type <class 'tensorflow.python.framework.ops.Tensor'>

The idea is to convert my RankNet implementation to TensorFlow.



1 Reply


Newer TensorFlow releases include tf.repeat(). On versions without it, you can achieve the effect of np.repeat() by combining tf.tile() and tf.reshape():

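import tensorflow as tf

n = tf.shape(yp)[0]               # len() does not work on graph-mode tensors; use the dynamic shape.
idx = tf.range(n)
idx = tf.reshape(idx, [-1, 1])    # Convert to an n x 1 matrix.
idx = tf.tile(idx, [1, n])        # Create n identical columns.
idx = tf.reshape(idx, [-1])       # Flatten row by row back to a vector.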
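The tile-then-reshape trick does not depend on TensorFlow, so it can be checked against np.repeat() in plain NumPy. A minimal sketch mirroring the steps above (the helper name repeat_via_tile is hypothetical, and it only handles 1-D input):

```python
import numpy as np

def repeat_via_tile(x, k):
    """Mimic np.repeat(x, k) for a 1-D array using only tile and reshape,
    the same steps the TensorFlow snippet performs with tf.tile/tf.reshape."""
    col = np.reshape(x, [-1, 1])   # shape (len(x), 1): one element per row
    col = np.tile(col, [1, k])     # shape (len(x), k): each row holds k copies
    return np.reshape(col, [-1])   # flatten row by row -> x0,...,x0, x1,...,x1, ...

x = np.arange(4)
print(repeat_via_tile(x, 3))  # [0 0 0 1 1 1 2 2 2 3 3 3]
```

Flattening works because reshape reads the n x k matrix in row-major order, so all copies of element 0 come out before any copy of element 1.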

You can simply compute jdx using tf.tile():

n = tf.shape(yp)[0]       # Number of predictions (dynamic shape).
jdx = tf.range(n)
jdx = tf.tile(jdx, [n])   # Repeat the whole range n times.

For the indexing, you could try using tf.gather() to extract non-contiguous slices from the yp tensor:

s = tf.gather(yp, idx) - tf.gather(yp, jdx)
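As a sanity check, the same gather-and-subtract step can be mirrored in NumPy, where np.take along axis 0 plays the role of tf.gather for a 1-D array (the values in yp are hypothetical). The sketch also shows that broadcasting a column against a row yields the full pairwise matrix S[i, j] = yp[i] - yp[j] without any index vectors; flattened row by row, it equals s.

```python
import numpy as np

yp = np.array([0.3, 1.2, -0.7, 2.5])   # hypothetical network outputs
n = yp.shape[0]

idx = np.repeat(np.arange(n), n)
jdx = np.tile(np.arange(n), n)

# np.take picks elements by index, like tf.gather(yp, idx) on a 1-D tensor.
s = np.take(yp, idx) - np.take(yp, jdx)

# Broadcasting an (n, 1) column against a (1, n) row gives the n x n matrix
# S[i, j] = yp[i] - yp[j] directly.
S = yp[:, None] - yp[None, :]

# Flattening S row by row reproduces the gathered vector s.
assert np.allclose(S.ravel(), s)
```

TensorFlow's subtraction broadcasts the same way, so tf.expand_dims(yp, 1) - tf.expand_dims(yp, 0) should give the n x n matrix directly; this sketch only checks the idea in NumPy.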

...