Alex I - 16 days ago

Tensorflow: How to write op with gradient in python?

I would like to write a TensorFlow op in Python, but I would like it to be differentiable, so that a gradient can be computed through it.

The question "Tensorflow: Writing an Op in Python" asks how to write an op in Python, and its answer suggests using py_func, which has no gradient.

The TF documentation describes how to add an op starting from C++ code only:

In my case, I am prototyping, so I don't care whether it runs on a GPU, and I don't care about it being usable from anything other than the TF Python API.


Here's an example of adding a gradient to a specific py_func
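For comparison, here is a minimal sketch of the same idea, assuming TensorFlow 2.x with eager execution. It is not necessarily the approach used in the linked example: it wraps arbitrary NumPy code with tf.py_function for the forward pass and attaches a hand-written gradient with tf.custom_gradient. The names _square_np and py_square are hypothetical, and squaring is only a stand-in for whatever Python code you are prototyping.

```python
import numpy as np
import tensorflow as tf


def _square_np(x):
    # Hypothetical forward computation in plain NumPy; TF cannot differentiate this.
    return np.square(x)


@tf.custom_gradient
def py_square(x):
    # Forward pass: run the NumPy function through tf.py_function.
    # Inside py_function the input arrives as an eager tensor, so .numpy() works.
    y = tf.py_function(func=lambda t: _square_np(t.numpy()), inp=[x], Tout=tf.float32)

    def grad(dy):
        # Hand-written backward pass: d(x^2)/dx = 2x, chained with the upstream gradient dy.
        # It uses TF ops, so it can itself be differentiated if needed.
        return dy * 2.0 * x

    return y, grad


x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = py_square(x)
g = tape.gradient(y, x)
```

TensorFlow never inspects the NumPy forward code. It simply calls grad during backpropagation, so keeping grad consistent with the forward function is entirely up to you. That trade-off seems acceptable for CPU-only prototyping from the Python API.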

Here's the issue discussion