Compare commits

No commits in common. "f8101400c2a4964302d59bc75961db54d70ada55" and "52d266ef91f7a075f5e0dfae94d28753ef659ca4" have entirely different histories.

@@ -1,13 +1,8 @@
import numpy as np

def tensor(*args, **kwargs):
    return Tensor(*args, **kwargs)

class Tensor:
    # TODO Implement 'requires_grad' functionality.
    def __init__(self, value):
        # NOTE We technically could support both numpy arrays and scalar values,
        # but it is too much work.
        if not isinstance(value, np.ndarray):
            print(f"{type(value)} is not compatible with {np.ndarray}")
            exit(-1)
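A quick usage sketch of the constructor guard above (not part of the diff; it assumes the tensor helper from this hunk is importable, and the values are made up):

import numpy as np

t = tensor(np.array([1.0, 2.0]))  # fine: already an ndarray
# tensor(3.0) would fail the isinstance check, print the error, and exit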
@@ -48,17 +43,6 @@ class Tensor:
        tensor._back = back
        return tensor

    def div(self, other):
        tensor = Tensor(self.value / other.value)
        tensor._save(self, other)
        def back(upstream):
            a, b = tensor._parents
            # Quotient rule: d(a/b)/da = 1/b and d(a/b)/db = -a/b^2,
            # combined with the upstream gradient via np.dot.
            return np.dot(1 / b.value, upstream), np.dot(-a.value / b.value ** 2, upstream)
        tensor._back = back
        return tensor
    def expt(self, exponent):
        tensor = Tensor(self.value ** exponent)
        tensor._save(self)
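As a standalone sanity check on the quotient-rule gradients that div encodes above, a short numeric sketch (illustrative values; written in plain elementwise math rather than the class's np.dot convention):

import numpy as np

a, b = np.array([2.0, 6.0]), np.array([4.0, 3.0])
upstream = np.ones_like(a)
grad_a = upstream / b            # d(a/b)/da = 1/b
grad_b = -a * upstream / b ** 2  # d(a/b)/db = -a/b^2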
@@ -92,18 +76,6 @@ class Tensor:
        tensor._back = back
        return tensor
    def tanh(self):
        tensor = Tensor(np.tanh(self.value))
        tensor._save(self)
        def back(upstream):
            # dtanh(x)/dx = 1 - tanh^2(x)
            a, = tensor._parents
            return [np.dot(1 - np.tanh(a.value) ** 2, upstream)]
        tensor._back = back
        return tensor
    # TODO Compute gradients only for tensors that need it.
    def _backprop(self, upstream):
        # Backprop through the tensor iff it has any parents.
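Putting the pieces together, here is a compact, self-contained sketch of the save-parents / closure-backward pattern these hunks implement. Names like _parents, _save, and _back mirror the diff, but the grad attribute and the recursive driver are assumptions for illustration, and gradients are simplified to elementwise products:

import numpy as np

class MiniTensor:
    def __init__(self, value):
        self.value = value
        self._parents, self._back = (), None
        self.grad = np.zeros_like(value)

    def _save(self, *parents):
        self._parents = parents

    def tanh(self):
        out = MiniTensor(np.tanh(self.value))
        out._save(self)
        def back(upstream):
            a, = out._parents
            # dtanh(x)/dx = 1 - tanh^2(x), applied elementwise
            return [(1 - np.tanh(a.value) ** 2) * upstream]
        out._back = back
        return out

    def _backprop(self, upstream):
        # Accumulate this tensor's gradient, then recurse into parents.
        self.grad = self.grad + upstream
        if self._parents:
            for parent, grad in zip(self._parents, self._back(upstream)):
                parent._backprop(grad)

x = MiniTensor(np.array([0.5, -1.0]))
y = x.tanh()
y._backprop(np.ones_like(y.value))
print(x.grad)  # 1 - tanh^2(x), elementwise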