A repository where I mostly put random Python scripts.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

22 lines
662 B

  1. from torch import nn
  2. class BasicNeuralNet(nn.Module):
  3. def __init__(self):
  4. super().__init__()
  5. # Inputs to hidden layer linear transformation
  6. self.hidden = nn.Linear(784, 256)
  7. # Output layer, 10 units
  8. self.output = nn.Linear(256, 10)
  9. # Define sigmoid activation and softmax output
  10. self.sigmoid = nn.Sigmoid()
  11. self.softmax = nn.Softmax(dim=1)
  12. def forward(self, x):
  13. # Pass the input tensor through each of the operations
  14. x = self.hidden(x)
  15. x = self.sigmoid(x)
  16. x = self.output(x)
  17. x = self.softmax(x)
  18. return x