Add utility tensor function

This commit is contained in:
Aodhnait Étaín 2020-11-15 21:51:56 +00:00
parent 52d266ef91
commit d39a843ef9

View file

@ -1,8 +1,13 @@
import numpy as np
def tensor(*args, **kwargs):
return Tensor(*args, **kwargs)
class Tensor:
# TODO Implement 'requires_grad' functionality.
def __init__(self, value):
# NOTE We technically could support both numpy arrays and scalar values,
# but it is too much work.
if not isinstance(value, np.ndarray):
print(f"{type(value)} is not compatible with {np.ndarray}")
exit(-1)