diff --git a/pygrad/tensor.py b/pygrad/tensor.py
index 0c0e9e9..0f0cc8b 100644
--- a/pygrad/tensor.py
+++ b/pygrad/tensor.py
@@ -92,6 +92,18 @@ class Tensor:
         tensor._back = back
         return tensor
 
+    def tanh(self):
+        tensor = Tensor(np.tanh(self.value))
+        tensor._save(self)
+
+        def back(upstream):
+            # dtanh(x)/dx = 1 - tanh(x)**2
+            a, = tensor._parents
+            return [(1 - np.tanh(a.value) ** 2) * upstream]
+
+        tensor._back = back
+        return tensor
+
     # TODO Compute gradients only for tensors that need it.
     def _backprop(self, upstream):
         # Backprop through the tensor iff it has any parents.
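
Note on the backward rule: since `tanh` is applied elementwise, its Jacobian is diagonal, so the vector-Jacobian product reduces to an elementwise multiply of the local derivative `1 - tanh(x)**2` with `upstream`; the original `np.dot` (and its operator grouping, which computed `1 - dot(tanh(x)**2, upstream)`) was incorrect. A minimal numpy-only sanity check of the corrected rule, independent of pygrad's `Tensor` plumbing (`_save`/`_back` are not shown in full here), might look like:

```python
import numpy as np

# Compare the analytic tanh backward rule against a central
# finite-difference estimate at a few sample points.
x = np.array([-1.5, 0.0, 0.7])
upstream = np.full_like(x, 2.0)  # arbitrary upstream gradient

# Analytic: elementwise local derivative times upstream.
analytic = (1 - np.tanh(x) ** 2) * upstream

# Numeric: central difference of tanh, then the same elementwise product.
eps = 1e-6
numeric = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps) * upstream

assert np.allclose(analytic, numeric, atol=1e-6)
```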