Cfis Yoi - 1 month ago
Python Question

Use tf.scatter_update in a two dimensional tf.Variable

I'm following the answer to "Manipulating matrix elements in tensorflow", which uses tf.scatter_update. My problem is this: what happens if my tf.Variable is 2D? Let's say:

a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])


How can I update, for example, the first element of every row and assign the value 1 to it?

I tried something like

for line in range(2):
    sess.run(tf.scatter_update(a[line], [0], [1]))


but it fails (as I expected) with the error:


TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input


How can I fix this kind of problem?


Answer

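In TensorFlow you cannot update a Tensor, but you can update a Variable. Indexing the variable with a[line] produces a new Tensor (the output of a slicing op), not a reference to the variable, which is why ScatterUpdate complains that it requires an l-value input.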

tf.scatter_update can only index the first dimension of the variable: each index selects a whole row, and each update must be a complete row. You always have to pass the variable itself as the reference (a instead of a[line]).
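If you need to write individual elements rather than whole rows, tf.scatter_nd_update takes full coordinates as indices instead. A minimal sketch, assuming TensorFlow 2.x with eager execution, where tf.Variable exposes the same operation as its scatter_nd_update method; in 1.x graph code the equivalent would be tf.scatter_nd_update(a, indices, updates) evaluated in a session:

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution).
a = tf.Variable([[0, 0, 0, 0], [0, 0, 0, 0]])

# Each index is a full [row, column] coordinate, so only the listed
# elements are written; every other element keeps its value.
num_rows = a.shape[0]
indices = [[row, 0] for row in range(num_rows)]  # [[0, 0], [1, 0]]
updates = [1] * num_rows                         # one scalar per coordinate
a.scatter_nd_update(indices, updates)

print(a.numpy())
# [[1 0 0 0]
#  [1 0 0 0]]
```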

This is how you can set the first element of every row to 1 with tf.scatter_update, by rewriting both rows in full:

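import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # A 2x4 int32 variable filled with zeros.
    a = tf.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
    # Indices [0, 1] select whole rows; each update is a complete row
    # whose first element is 1.
    b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))  # value before the update op has run
    print(sess.run(b))  # running b applies the update and returns the new value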

Output:

[[0 0 0 0]
 [0 0 0 0]]
[[1 0 0 0]
 [1 0 0 0]]
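In plain-Python terms, the update that produced this output replaces whole rows, so it also overwrites whatever the other columns held. The sketch below is a hypothetical model on nested lists (emulate_scatter_update is an illustrative name, not a TensorFlow function), assuming the row-replacement semantics of tf.scatter_update described in this answer:

```python
def emulate_scatter_update(ref, indices, updates):
    """Hypothetical pure-Python model of tf.scatter_update on a 2-D value.

    Indices address the first dimension only, so each updates[i] must be a
    complete row, and it replaces ref[indices[i]] in full.
    """
    for row_index, new_row in zip(indices, updates):
        ref[row_index] = list(new_row)
    return ref


# Columns 1..3 hold non-zero data so the effect on them is visible.
data = [[0, 7, 7, 7], [0, 9, 9, 9]]
overwritten = emulate_scatter_update(
    [row[:] for row in data], [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])
print(overwritten)  # [[1, 0, 0, 0], [1, 0, 0, 0]]

# To keep the other columns, each replacement row must be built from the
# current values, i.e. the whole row is still read and rewritten.
kept = emulate_scatter_update(
    [row[:] for row in data], [0, 1], [[1] + row[1:] for row in data])
print(kept)  # [[1, 7, 7, 7], [1, 9, 9, 9]]
```

Keeping the other columns therefore means reading each current row first, which is roughly the same work as writing a new value for the whole variable.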

But since every row has to be rewritten in full anyway, it may be simpler, and possibly faster, to assign a completely new value to the whole variable.
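A sketch of that alternative, assuming TensorFlow 2.x with eager execution: compute the full new value as an ordinary tensor and write it with assign. In 1.x graph code, tf.assign(a, new_value) run in a session plays the same role. The concat construction is just one illustrative way to build the value.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution).
a = tf.Variable([[0, 0, 0, 0], [0, 0, 0, 0]])

# Build the complete new value as an ordinary tensor: a column of ones
# followed by the current contents of columns 1..3.
first_column = tf.ones_like(a[:, :1])  # shape [2, 1], same dtype as a
new_value = tf.concat([first_column, a[:, 1:]], axis=1)

# Overwrite the whole variable in one operation.
a.assign(new_value)
print(a.numpy())
```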
