From c7bde98a07656b3d344e8321a010deedc02e749f Mon Sep 17 00:00:00 2001 From: Orif Milod Date: Fri, 15 Mar 2024 00:16:06 +0100 Subject: [PATCH] Fix Conv2D tests --- gigatorch/cnn.py | 6 +++-- tests/cnn_test.py | 57 +++++++++++++++++++++++++++++++---------------- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/gigatorch/cnn.py b/gigatorch/cnn.py index d29dfb9bf..52535c435 100644 --- a/gigatorch/cnn.py +++ b/gigatorch/cnn.py @@ -97,7 +97,9 @@ def __init__(self, in_channels, out_channels, kernel_size, activation_fn, stride self.activation_fn = activation_fn self.stride = stride - def compute(self, input): + def compute(self, input): + assert len(input.shape) == 4, f"Can't Conv2D {input.shape}" + (batch_size, _, height, width) = input.shape output_height = (height - self.kernel_size) // self.stride + 1 output_width = (width - self.kernel_size) // self.stride + 1 @@ -112,7 +114,7 @@ def compute(self, input): w_start = j * self.stride w_end = w_start + self.kernel_size output[b, k, i, j] = self.activation_fn( - np.sum(input[b, :, h_start:h_end, w_start:w_end] * self.kernels[k]) + np.sum((input[b, :, h_start:h_end, w_start:w_end] * self.kernels[k]).data) ) return output diff --git a/tests/cnn_test.py b/tests/cnn_test.py index cdb75a77e..bc0e5dbb4 100644 --- a/tests/cnn_test.py +++ b/tests/cnn_test.py @@ -2,32 +2,51 @@ from pytest import raises from gigatorch.cnn import Conv2D, MaxPool2D from gigatorch.activation_fn import relu - - -def test_conv2d_success(): - conv2d = Conv2D(1, 10, 2, relu) - conv2d.kernels = Tensor( +import numpy as np + +def test_conv2d_shape_success(): + # Test case 1: Basic functionality + conv2d = Conv2D(in_channels=3, out_channels=2, kernel_size=2, activation_fn=np.tanh, stride=1) + input_tensor = Tensor(np.random.rand(1, 3, 4, 4)) # 1 image, 3 channels, 4x4 size + output = conv2d.compute(input_tensor) + assert output.shape == (1, 2, 3, 3), "Output shape mismatch in Test Case 1" + + # Test case 2: Stride functionality + conv2d = Conv2D(in_channels=3, out_channels=2, kernel_size=2, activation_fn=np.tanh, stride=2) + input_tensor = Tensor(np.random.rand(1, 3, 4, 4)) # 1 image, 3 channels, 4x4 size + output = conv2d.compute(input_tensor) + assert output.shape == (1, 2, 2, 2), "Output shape mismatch in Test Case 2" + + # Test case 3: Multiple images + conv2d = Conv2D(in_channels=3, out_channels=2, kernel_size=2, activation_fn=np.tanh, stride=1) + input_tensor = Tensor(np.random.rand(2, 3, 4, 4)) # 2 images, 3 channels, 4x4 size + output = conv2d.compute(input_tensor) + assert output.shape == (2, 2, 3, 3), "Output shape mismatch in Test Case 3" + +def test_conv2d_compute_success(): + # Test case: Check actual values + conv2d = Conv2D(in_channels=1, out_channels=1, kernel_size=2, activation_fn=lambda x: x, stride=1) + conv2d.kernels = Tensor(np.array([ [ [ - [1, 2], - [3, 4], + [1, 0], + [0, 1] ] ] - ) - sample_data = Tensor( - [ # Batch 1 - [ # Channel 1 - [1, 1, 1], - [1, 1, 1], - [1, 1, 1], + ])) # Set a fixed kernel + input_tensor = Tensor(np.array([ + [ + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] ] ] - ) - - output = conv2d.compute(sample_data) - expected = Tensor([[[10, 10], [10, 10]]]) # for layer 1 + ])) - assert all(output.item() == expected) + output = conv2d.compute(input_tensor) + expected_output = np.array([[[[6, 8], [12, 14]]]]) + assert np.allclose(output.data, expected_output), "Output values mismatch" def test_conv2d_kernel_size_larger_than_input():