Make single pass use tensors instead of ndarrays
This commit is contained in:
parent
b0bb7b523d
commit
45168730e8
39
pygrad/nn.py
39
pygrad/nn.py
|
@ -1,38 +1,43 @@
|
|||
# Neural networks from scratch with numpy.
|
||||
|
||||
import numpy as np
|
||||
import pygrad.tensor as tensor
|
||||
from pygrad.tensor import tensor, Tensor
|
||||
|
||||
def mean_absolute_error(x, y):
|
||||
return np.mean(np.abs(x - y))
|
||||
|
||||
def mean_squared_error(x, y):
|
||||
return np.mean(np.power(x - y, 2))
|
||||
def mean_squared_error(x: Tensor, y: Tensor):
|
||||
return x.sub(y).expt(2).div(tensor([[2.0]]))
|
||||
|
||||
def cross_entropy_loss(x, y):
|
||||
return -np.log(np.exp(y) / np.sum(np.exp(x)))
|
||||
def cross_entropy_loss(x: Tensor, y: Tensor):
|
||||
return y.exp().div(np.sum(x.exp())).log().neg()
|
||||
|
||||
# prepare inputs and outputs
|
||||
x = np.array([[1, 0]])
|
||||
y = np.array([[1]])
|
||||
x = tensor(np.array([[1, 0]]))
|
||||
y = tensor(np.array([[1]]))
|
||||
|
||||
# we're doing xavier initialisation - see <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>
|
||||
w1 = np.random.randn(2, 3) / np.sqrt(2)
|
||||
w2 = np.random.randn(3, 1) / np.sqrt(3)
|
||||
w1 = tensor(np.random.randn(2, 3) / np.sqrt(2))
|
||||
w2 = tensor(np.random.randn(3, 1) / np.sqrt(3))
|
||||
|
||||
def single_pass():
|
||||
global w1, w2
|
||||
|
||||
# forward pass
|
||||
h = np.matmul(x, w1)
|
||||
h_hat = np.tanh(h)
|
||||
j = np.matmul(h_hat, w2)
|
||||
print("prediction {}".format(j))
|
||||
h = x.mul(w1)
|
||||
h_hat = h.tanh()
|
||||
j = h_hat.mul(w2)
|
||||
print(f"prediction {j}")
|
||||
|
||||
# loss calculation
|
||||
loss = cross_entropy_loss(j, y)
|
||||
print("loss {}".format(loss))
|
||||
loss = mean_squared_error(j, y)
|
||||
print(f"loss {loss}")
|
||||
|
||||
# TODO Backward pass.
|
||||
return
|
||||
loss.backward()
|
||||
print(w1.grad, w2.grad)
|
||||
|
||||
w1.value -= 0.1 * w1.grad
|
||||
w2.value -= 0.1 * w2.grad
|
||||
|
||||
# initialise layers
|
||||
# self.lin1 = nn.Linear(2, 3)
|
||||
|
|
Loading…
Reference in a new issue