Skip to content

Commit

Permalink
Fix Conv2D tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Orif Milod committed Mar 14, 2024
1 parent 6ba0e28 commit c7bde98
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 21 deletions.
6 changes: 4 additions & 2 deletions gigatorch/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def __init__(self, in_channels, out_channels, kernel_size, activation_fn, stride
self.activation_fn = activation_fn
self.stride = stride

def compute(self, input):
def compute(self, input):
assert len(input.shape) == 4, f"Can't Conv2D {input.shape}"

(batch_size, _, height, width) = input.shape
output_height = (height - self.kernel_size) // self.stride + 1
output_width = (width - self.kernel_size) // self.stride + 1
Expand All @@ -112,7 +114,7 @@ def compute(self, input):
w_start = j * self.stride
w_end = w_start + self.kernel_size
output[b, k, i, j] = self.activation_fn(
np.sum(input[b, :, h_start:h_end, w_start:w_end] * self.kernels[k])
np.sum((input[b, :, h_start:h_end, w_start:w_end] * self.kernels[k]).data)
)

return output
Expand Down
57 changes: 38 additions & 19 deletions tests/cnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,51 @@
from pytest import raises
from gigatorch.cnn import Conv2D, MaxPool2D
from gigatorch.activation_fn import relu


def test_conv2d_success():
conv2d = Conv2D(1, 10, 2, relu)
conv2d.kernels = Tensor(
import numpy as np

def test_conv2d_shape_success():
# Test case 1: Basic functionality
conv2d = Conv2D(in_channels=3, out_channels=2, kernel_size=2, activation_fn=np.tanh, stride=1)
input_tensor = Tensor(np.random.rand(1, 3, 4, 4)) # 1 image, 3 channels, 4x4 size
output = conv2d.compute(input_tensor)
assert output.shape == (1, 2, 3, 3), "Output shape mismatch in Test Case 1"

# Test case 2: Stride functionality
conv2d = Conv2D(in_channels=3, out_channels=2, kernel_size=2, activation_fn=np.tanh, stride=2)
input_tensor = Tensor(np.random.rand(1, 3, 4, 4)) # 1 image, 3 channels, 4x4 size
output = conv2d.compute(input_tensor)
assert output.shape == (1, 2, 2, 2), "Output shape mismatch in Test Case 2"

# Test case 3: Multiple images
conv2d = Conv2D(in_channels=3, out_channels=2, kernel_size=2, activation_fn=np.tanh, stride=1)
input_tensor = Tensor(np.random.rand(2, 3, 4, 4)) # 2 images, 3 channels, 4x4 size
output = conv2d.compute(input_tensor)
assert output.shape == (2, 2, 3, 3), "Output shape mismatch in Test Case 3"

def test_conv2d_compute_success():
# Test case: Check actual values
conv2d = Conv2D(in_channels=1, out_channels=1, kernel_size=2, activation_fn=lambda x: x, stride=1)
conv2d.kernels = Tensor(np.array([
[
[
[1, 2],
[3, 4],
[1, 0],
[0, 1]
]
]
)
sample_data = Tensor(
[ # Batch 1
[ # Channel 1
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
])) # Set a fixed kernel
input_tensor = Tensor(np.array([
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
]
)

output = conv2d.compute(sample_data)
expected = Tensor([[[10, 10], [10, 10]]]) # for layer 1
]))

assert all(output.item() == expected)
output = conv2d.compute(input_tensor)
expected_output = np.array([[[[6, 8], [12, 14]]]])
assert np.allclose(output.data, expected_output), "Output values mismatch"


def test_conv2d_kernel_size_larger_than_input():
Expand Down

0 comments on commit c7bde98

Please sign in to comment.