From 12d229e50228602351edc5a2ab4b9c138a4d6879 Mon Sep 17 00:00:00 2001
From: aodhneine
Date: Mon, 16 Nov 2020 00:44:12 +0000
Subject: [PATCH] Add sub, neg and log methods to pygrad.tensor

---
 pygrad/tensor.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/pygrad/tensor.py b/pygrad/tensor.py
index 400a2a8..00d94ec 100644
--- a/pygrad/tensor.py
+++ b/pygrad/tensor.py
@@ -36,6 +36,16 @@ class Tensor:
         tensor._back = back
         return tensor
 
+    def sub(self, other):
+        tensor = Tensor(np.subtract(self.value, other.value))
+        tensor._save(self, other)
+
+        def back(upstream):
+            return np.dot(np.ones_like(self.value).T, upstream), -np.dot(np.ones_like(self.value).T, upstream)
+
+        tensor._back = back
+        return tensor
+
     def mul(self, other):
         tensor = Tensor(np.dot(self.value, other.value))
         tensor._save(self, other)
@@ -58,6 +68,16 @@ class Tensor:
         tensor._back = back
         return tensor
 
+    def neg(self):
+        tensor = Tensor(-self.value)
+        tensor._save(self)
+
+        def back(upstream):
+            return [np.dot(-np.ones_like(self.value), upstream)]
+
+        tensor._back = back
+        return tensor
+
     def expt(self, exponent):
         tensor = Tensor(self.value ** exponent)
         tensor._save(self)
@@ -91,6 +111,17 @@ class Tensor:
         tensor._back = back
         return tensor
 
+    def log(self):
+        tensor = Tensor(np.log(self.value))
+        tensor._save(self)
+
+        def back(upstream):
+            a, = tensor._parents
+            return [np.dot(1 / a.value, upstream)]
+
+        tensor._back = back
+        return tensor
+
     def tanh(self):
         tensor = Tensor(np.tanh(self.value))
         tensor._save(self)
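
A quick sanity check for the three new methods, walking the backward closures by hand. This is a minimal sketch: the import path follows the file touched by this patch, but the manual `_back` walk is an assumption, since the patch does not show how pygrad seeds and traverses the graph.

    import numpy as np
    from pygrad.tensor import Tensor

    a = Tensor(np.array([[2.0]]))
    b = Tensor(np.array([[3.0]]))

    # Forward pass: e = log(-(a - b)) = log(b - a).
    c = a.sub(b)   # c.value == [[-1.0]]
    d = c.neg()    # d.value == [[1.0]]
    e = d.log()    # e.value == [[0.0]]

    # Backward pass, seeded with de/de = 1.
    (d_grad,) = e._back(np.ones_like(e.value))  # 1 / d.value
    (c_grad,) = d._back(d_grad)                 # sign flip from neg
    a_grad, b_grad = c._back(c_grad)            # sub fans out to both operands

    print(a_grad, b_grad)  # [[-1.]] [[1.]], i.e. d/da and d/db of log(b - a)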
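A finite-difference check confirms that log's backward matches the analytic derivative 1/x. This assumes, as the log hunk implies, that `_save` records the input tensor in `_parents`.

    import numpy as np
    from pygrad.tensor import Tensor

    x = np.array([[0.7]])
    eps = 1e-6

    (analytic,) = Tensor(x).log()._back(np.ones_like(x))       # 1 / x
    numeric = (np.log(x + eps) - np.log(x - eps)) / (2 * eps)  # central difference

    assert np.allclose(analytic, numeric)  # both are approximately 1.4285714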