Add utility tensor function
This commit is contained in:
parent
52d266ef91
commit
d39a843ef9
|
@ -1,8 +1,13 @@
|
|||
import numpy as np
|
||||
|
||||
def tensor(*args, **kwargs):
|
||||
return Tensor(*args, **kwargs)
|
||||
|
||||
class Tensor:
|
||||
# TODO Implement 'requires_grad' functionality.
|
||||
def __init__(self, value):
|
||||
# NOTE We technically could support both numpy arrays and scalar values,
|
||||
# but it is too much work.
|
||||
if not isinstance(value, np.ndarray):
|
||||
print(f"{type(value)} is not compatible with {np.ndarray}")
|
||||
exit(-1)
|
||||
|
|
Loading…
Reference in a new issue