Add sub, neg and log methods to pygrad.tensor
This commit is contained in:
parent
de084e8ce3
commit
12d229e502
|
@ -36,6 +36,16 @@ class Tensor:
|
||||||
tensor._back = back
|
tensor._back = back
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
def sub(self, other):
|
||||||
|
tensor = Tensor(np.add(self.value, other.value))
|
||||||
|
tensor._save(self, other)
|
||||||
|
|
||||||
|
def back(upstream):
|
||||||
|
return np.dot(np.ones_like(self.value).T, upstream), -np.dot(np.ones_like(self.value).T, upstream)
|
||||||
|
|
||||||
|
tensor._back = back
|
||||||
|
return tensor
|
||||||
|
|
||||||
def mul(self, other):
|
def mul(self, other):
|
||||||
tensor = Tensor(np.dot(self.value, other.value))
|
tensor = Tensor(np.dot(self.value, other.value))
|
||||||
tensor._save(self, other)
|
tensor._save(self, other)
|
||||||
|
@ -58,6 +68,16 @@ class Tensor:
|
||||||
tensor._back = back
|
tensor._back = back
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
def neg(self):
|
||||||
|
tensor = Tensor(-self.value)
|
||||||
|
tensor._save(self)
|
||||||
|
|
||||||
|
def back(upstream):
|
||||||
|
return [np.dot(-np.ones_like(self.value), upstream)]
|
||||||
|
|
||||||
|
tensor._back = back
|
||||||
|
return tensor
|
||||||
|
|
||||||
def expt(self, exponent):
|
def expt(self, exponent):
|
||||||
tensor = Tensor(self.value ** exponent)
|
tensor = Tensor(self.value ** exponent)
|
||||||
tensor._save(self)
|
tensor._save(self)
|
||||||
|
@ -91,6 +111,17 @@ class Tensor:
|
||||||
tensor._back = back
|
tensor._back = back
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
tensor = Tensor(np.log(self.value))
|
||||||
|
tensor._save(self)
|
||||||
|
|
||||||
|
def back(upstream):
|
||||||
|
a, = tensor._parents
|
||||||
|
return [np.dot(1 / a.value, upstream)]
|
||||||
|
|
||||||
|
tensor._back = back
|
||||||
|
return tensor
|
||||||
|
|
||||||
def tanh(self):
|
def tanh(self):
|
||||||
tensor = Tensor(np.tanh(self.value))
|
tensor = Tensor(np.tanh(self.value))
|
||||||
tensor._save(self)
|
tensor._save(self)
|
||||||
|
|
Loading…
Reference in a new issue