Compare commits
2 commits
de084e8ce3
...
d31a72374d
Author | SHA1 | Date | |
---|---|---|---|
Aodhnait Étaín | d31a72374d | ||
Aodhnait Étaín | 12d229e502 |
|
@ -7,6 +7,8 @@ class Tensor:
|
|||
# TODO Implement 'requires_grad' functionality.
|
||||
def __init__(self, value):
|
||||
# TODO Add support for scalar values.
|
||||
if isinstance(value, list):
|
||||
value = np.array(value)
|
||||
if not isinstance(value, np.ndarray):
|
||||
print(f"{type(value)} is not compatible with {np.ndarray}")
|
||||
exit(-1)
|
||||
|
@ -36,6 +38,16 @@ class Tensor:
|
|||
tensor._back = back
|
||||
return tensor
|
||||
|
||||
def sub(self, other):
|
||||
tensor = Tensor(np.add(self.value, other.value))
|
||||
tensor._save(self, other)
|
||||
|
||||
def back(upstream):
|
||||
return np.dot(np.ones_like(self.value).T, upstream), -np.dot(np.ones_like(self.value).T, upstream)
|
||||
|
||||
tensor._back = back
|
||||
return tensor
|
||||
|
||||
def mul(self, other):
|
||||
tensor = Tensor(np.dot(self.value, other.value))
|
||||
tensor._save(self, other)
|
||||
|
@ -58,6 +70,16 @@ class Tensor:
|
|||
tensor._back = back
|
||||
return tensor
|
||||
|
||||
def neg(self):
|
||||
tensor = Tensor(-self.value)
|
||||
tensor._save(self)
|
||||
|
||||
def back(upstream):
|
||||
return [np.dot(-np.ones_like(self.value), upstream)]
|
||||
|
||||
tensor._back = back
|
||||
return tensor
|
||||
|
||||
def expt(self, exponent):
|
||||
tensor = Tensor(self.value ** exponent)
|
||||
tensor._save(self)
|
||||
|
@ -91,6 +113,17 @@ class Tensor:
|
|||
tensor._back = back
|
||||
return tensor
|
||||
|
||||
def log(self):
|
||||
tensor = Tensor(np.log(self.value))
|
||||
tensor._save(self)
|
||||
|
||||
def back(upstream):
|
||||
a, = tensor._parents
|
||||
return [np.dot(1 / a.value, upstream)]
|
||||
|
||||
tensor._back = back
|
||||
return tensor
|
||||
|
||||
def tanh(self):
|
||||
tensor = Tensor(np.tanh(self.value))
|
||||
tensor._save(self)
|
||||
|
|
Loading…
Reference in a new issue