Skip to content

Commit

Permalink
Merge pull request #14 from 42pde-bakk/peer/matrixTests
Browse files Browse the repository at this point in the history
Matrix tests incorporation
  • Loading branch information
pde-bakk authored Oct 15, 2024
2 parents 45d4cd6 + 75f7790 commit b61c1d2
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 51 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ FetchContent_MakeAvailable(Catch2)

# Add the directory with KalmanFilter.hpp and Matrix.hpp to the include directories
include_directories(srcs)
add_executable(tests tests/main.cpp ${SOURCES})
add_executable(tests
tests/mainTests.cpp
tests/MatrixTests.cpp
${SOURCES})
target_link_libraries(tests PRIVATE Catch2::Catch2WithMain)

#add_executable(test tests/main.cpp ${SOURCES})
33 changes: 18 additions & 15 deletions srcs/KalmanFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class KalmanFilter {
using KalmanGain = Matrix<double, Nx, Nz>;

private:
Vector<double, Nx> state{};
Vector<double, Nx> previousState{};
Vector<double, Nx> currenState{};
Vector<double, Nx> predictedState{};
unsigned int k;
// Matrices
Matrix<double, Nx, Nx> state_covariance_matrix;
Expand Down Expand Up @@ -111,7 +113,7 @@ class KalmanFilter {

public:
explicit KalmanFilter() {
this->state = Matrix<double, Nx, 1>();
this->currenState = Matrix<double, Nx, 1>();
}

// Vector3d predict(size_t time_step, const InputVector &inputs);
Expand All @@ -131,8 +133,7 @@ class KalmanFilter {
///

Vector3d predict(size_t time_step, const InputVector& inputs) {
std::cout << "IN" << this->state << std::endl;
Vector3d predicted_pos;
std::cout << "IN" << this->currenState << std::endl;

auto time = (double)time_step / 1000;

Expand All @@ -151,9 +152,9 @@ class KalmanFilter {
std::array<double, 9>({ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1}),
});

this->state = F_mat * this->state;
this->currenState = F_mat * this->currenState;

std::cout << "STATE\n" << this->state << std::endl;
std::cout << "STATE\n" << this->currenState << std::endl;

this->P_mat = this->extrapolate_covariance(F_mat, this->P_mat);

Expand All @@ -167,21 +168,23 @@ class KalmanFilter {

std::cout << "KALMAN GAIN\n" << kalman << std::endl;

auto updated_state = this->update_state_matrix(kalman, this->state, measurement);
this->predictedState = this->update_state_matrix(kalman, this->currenState, measurement);

std::cout << "UPDATED\n" << updated_state << std::endl;
std::cout << "UPDATED\n" << this->predictedState << std::endl;

this->state = updated_state;
this->currenState = this->predictedState;

auto updated_covariance = this->update_covariance_matrix(kalman);

std::cout << "COV\n" << updated_covariance << std::endl;

this->P_mat = updated_covariance;

predicted_pos[0][0] = this->state[0][0];
predicted_pos[1][0] = this->state[1][0];
predicted_pos[2][0] = this->state[2][0];
Vector3d predicted_pos;

predicted_pos[0][0] = this->predictedState[0][0];
predicted_pos[1][0] = this->predictedState[1][0];
predicted_pos[2][0] = this->predictedState[2][0];

// auto predicted_mu = A * mu_t + B * u_t;
// auto predicted_sigma = A * sigma_t * A.transpose() + Q;
Expand All @@ -193,7 +196,7 @@ class KalmanFilter {
}

[[nodiscard]] const Vector<double, Nx>& get_state() const {
return (this->state);
return (this->currenState);
}

Matrix<double, Nx, Nx> extrapolate_covariance(const StateTransitionMatrix F_mat, const EstimateCovarianceMatrix P_mat) {
Expand Down Expand Up @@ -265,7 +268,7 @@ class KalmanFilter {
}

void set_state(std::array<double, Nx> &state) {
this->state = Matrix<double, Nx, 1>(state);
this->currenState = Matrix<double, Nx, 1>(state);
}

Matrix<double, Nx, 1> get_initial_process_noise() {
Expand All @@ -278,7 +281,7 @@ class KalmanFilter {
}

double get_current_speed() {
auto speed = std::sqrt(std::pow(this->state[0][3], 2) + std::pow(this->state[0][4], 2) + std::pow(this->state[0][5], 2));
auto speed = std::sqrt(std::pow(this->currenState[0][3], 2) + std::pow(this->currenState[0][4], 2) + std::pow(this->currenState[0][5], 2));

return speed;
}
Expand Down
31 changes: 24 additions & 7 deletions srcs/Matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ class Matrix {

public:
template <typename, size_t, size_t> friend class Matrix;
Matrix() = default;
Matrix() {
for (size_t row = 0; row < ROW_AMOUNT; row++) {
for (size_t col = 0; col < COLUMN_AMOUNT; col++) {
this->data[row][col] = T();
}
}
};

Matrix(const Matrix& rhs) = default;

Expand All @@ -29,13 +35,22 @@ class Matrix {
}
}

Matrix(T n) {
for (size_t row = 0; row < ROW_AMOUNT; row++) {
for (size_t col = 0; col < COLUMN_AMOUNT; col++) {
this->data[row][col] = n;
}
}
}

Matrix(const std::vector<T>& vec) {
assert(ROW_AMOUNT == vec.size());
assert(COLUMN_AMOUNT == 1);
for (size_t row = 0; row < ROW_AMOUNT; row++) {
this->data[row][0] = vec[row];
}
}

Matrix(const std::array<T, ROW_AMOUNT>& vec) {
for (size_t row = 0; row < ROW_AMOUNT; row++) {
this->data[row][0] = vec[row];
Expand Down Expand Up @@ -64,6 +79,8 @@ class Matrix {
return (this->data[i]);
}

// For matrix multiplication, the number of columns in the first matrix must be equal to the number of rows in the second matrix.
// The result matrix has the number of rows of the first and the number of columns of the second matrix.
template<size_t ROW_AMOUNT_2, size_t COLUMN_AMOUNT_2>
Matrix<T, ROW_AMOUNT, COLUMN_AMOUNT_2> operator*(const Matrix<T, ROW_AMOUNT_2, COLUMN_AMOUNT_2>& rhs) const {
assert(COLUMN_AMOUNT == ROW_AMOUNT_2);
Expand Down Expand Up @@ -92,8 +109,8 @@ class Matrix {
return (out);
}

[[nodiscard]] Matrix transpose() const {
auto out = Matrix<double, ROW_AMOUNT, COLUMN_AMOUNT>();
[[nodiscard]] Matrix<T, COLUMN_AMOUNT, ROW_AMOUNT> transpose() const {
auto out = Matrix<T, COLUMN_AMOUNT, ROW_AMOUNT>();

for (size_t row = 0; row < ROW_AMOUNT; row++) {
for (size_t column = 0; column < COLUMN_AMOUNT; column++) {
Expand All @@ -104,8 +121,8 @@ class Matrix {
}

template<size_t R>
[[nodiscard]] Matrix<double, ROW_AMOUNT + R, COLUMN_AMOUNT> vstack(const Matrix<T, R, COLUMN_AMOUNT>& rhs) const {
auto out = Matrix<double, ROW_AMOUNT + R, COLUMN_AMOUNT>();
[[nodiscard]] Matrix<T, ROW_AMOUNT + R, COLUMN_AMOUNT> vstack(const Matrix<T, R, COLUMN_AMOUNT>& rhs) const {
auto out = Matrix<T, ROW_AMOUNT + R, COLUMN_AMOUNT>();
for (size_t r = 0; r < ROW_AMOUNT; r++) {
out.data[r] = this->data[r];
}
Expand All @@ -117,8 +134,8 @@ class Matrix {
}

template<size_t SIZE>
static Matrix<double, SIZE, SIZE> identity() {
Matrix<double, SIZE, SIZE> out;
static Matrix<T, SIZE, SIZE> identity() {
Matrix<T, SIZE, SIZE> out;
for (size_t i = 0; i < SIZE; i++) {
out[i][i] = 1;
}
Expand Down
131 changes: 131 additions & 0 deletions tests/MatrixTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//
// Created by peer on 8-10-24.
//
#include <catch2/catch_all.hpp>
#include "Matrix.hpp"
constexpr double EPSILON = 0.0001;


TEST_CASE("Matrix default constructor", "[Matrix]") {
const Matrix<double, 2, 2> mat{};
REQUIRE(mat.get_row_amount() == 2);
REQUIRE(mat.get_column_amount() == 2);
}

TEST_CASE("Matrix constructor from std::array", "[Matrix]") {
const std::array<std::array<double, 2>, 2> arr = {{{1.0, 2.0}, {3.0, 4.0}}};
const Matrix<double, 2, 2> mat(arr);

REQUIRE_THAT(mat[0][0], Catch::Matchers::WithinAbs(1.0, EPSILON));
REQUIRE_THAT(mat[0][1], Catch::Matchers::WithinAbs(2.0, EPSILON));
REQUIRE_THAT(mat[1][0], Catch::Matchers::WithinAbs(3.0, EPSILON));
REQUIRE_THAT(mat[1][1], Catch::Matchers::WithinAbs(4.0, EPSILON));
}

TEST_CASE("Matrix constructor from scalar", "[Matrix]") {
const Matrix<double, 2, 2> mat(5.0);

REQUIRE_THAT(mat[0][0], Catch::Matchers::WithinAbs(5.0, EPSILON));
REQUIRE_THAT(mat[0][1], Catch::Matchers::WithinAbs(5.0, EPSILON));
REQUIRE_THAT(mat[1][0], Catch::Matchers::WithinAbs(5.0, EPSILON));
REQUIRE_THAT(mat[1][1], Catch::Matchers::WithinAbs(5.0, EPSILON));
}

TEST_CASE("Matrix multiplication order", "[Matrix]") {
constexpr size_t m = 5, n = 3, p = 4;
const Matrix<double, m, n> lhs(8.0);
const Matrix<double, n, p> rhs(2.0);

const auto result = lhs * rhs;
REQUIRE(result.get_row_amount() == m);
REQUIRE(result.get_column_amount() == p);
}

TEST_CASE("Matrix multiplication", "[Matrix]") {
const std::array<std::array<double, 2>, 2> lhs_arr = {{{1.0, 2.0}, {3.0, 4.0}}};
const std::array<std::array<double, 2>, 2> rhs_arr = {{{5.0, 6.0}, {7.0, 8.0}}};
const Matrix<double, 2, 2> lhs(lhs_arr);
const Matrix<double, 2, 2> rhs(rhs_arr);

const Matrix<double, 2, 2> result = lhs * rhs;

REQUIRE_THAT(result[0][0], Catch::Matchers::WithinAbs(19.0, EPSILON));
REQUIRE_THAT(result[0][1], Catch::Matchers::WithinAbs(22.0, EPSILON));
REQUIRE_THAT(result[1][0], Catch::Matchers::WithinAbs(43.0, EPSILON));
REQUIRE_THAT(result[1][1], Catch::Matchers::WithinAbs(50.0, EPSILON));
}

TEST_CASE("Matrix scalar multiplication", "[Matrix]") {
const Matrix<double, 2, 2> mat(2.0);
const Matrix<double, 2, 2> result = mat * 3.0;

REQUIRE_THAT(result[0][0], Catch::Matchers::WithinAbs(6.0, EPSILON));
REQUIRE_THAT(result[0][1], Catch::Matchers::WithinAbs(6.0, EPSILON));
REQUIRE_THAT(result[1][0], Catch::Matchers::WithinAbs(6.0, EPSILON));
REQUIRE_THAT(result[1][1], Catch::Matchers::WithinAbs(6.0, EPSILON));
}

TEST_CASE("Matrix transpose order", "[Matrix]") {
const std::array<std::array<double, 4>, 2> arr = {{{1.0, 2.0, 3.0, 4.0}, {3.0, 4.0, 5.0, 6.0}}};
const Matrix<double, 2, 4> mat(arr);

const auto transposed = mat.transpose();
REQUIRE(transposed.get_row_amount() == mat.get_column_amount());
REQUIRE(transposed.get_column_amount() == mat.get_row_amount());
}

TEST_CASE("Matrix transpose", "[Matrix]") {
const std::array<std::array<double, 2>, 2> arr = {{{1.0, 2.0}, {3.0, 4.0}}};
const Matrix<double, 2, 2> mat(arr);

const Matrix<double, 2, 2> transposed = mat.transpose();

REQUIRE_THAT(transposed[0][0], Catch::Matchers::WithinAbs(1.0, EPSILON));
REQUIRE_THAT(transposed[0][1], Catch::Matchers::WithinAbs(3.0, EPSILON));
REQUIRE_THAT(transposed[1][0], Catch::Matchers::WithinAbs(2.0, EPSILON));
REQUIRE_THAT(transposed[1][1], Catch::Matchers::WithinAbs(4.0, EPSILON));
}

TEST_CASE("Matrix addition", "[Matrix]") {
const std::array<std::array<double, 2>, 2> lhs_arr = {{{1.0, 2.0}, {3.0, 4.0}}};
const std::array<std::array<double, 2>, 2> rhs_arr = {{{5.0, 6.0}, {7.0, 8.0}}};
const Matrix<double, 2, 2> lhs(lhs_arr);
const Matrix<double, 2, 2> rhs(rhs_arr);

const Matrix<double, 2, 2> result = lhs + rhs;

REQUIRE_THAT(result[0][0], Catch::Matchers::WithinAbs(6.0, EPSILON));
REQUIRE_THAT(result[0][1], Catch::Matchers::WithinAbs(8.0, EPSILON));
REQUIRE_THAT(result[1][0], Catch::Matchers::WithinAbs(10.0, EPSILON));
REQUIRE_THAT(result[1][1], Catch::Matchers::WithinAbs(12.0, EPSILON));
}

TEST_CASE("Matrix subtraction", "[Matrix]") {
const std::array<std::array<double, 2>, 2> lhs_arr = {{{5.0, 6.0}, {7.0, 8.0}}};
const std::array<std::array<double, 2>, 2> rhs_arr = {{{1.0, 2.0}, {3.0, 4.0}}};
const Matrix<double, 2, 2> lhs(lhs_arr);
const Matrix<double, 2, 2> rhs(rhs_arr);

const Matrix<double, 2, 2> result = lhs - rhs;

REQUIRE_THAT(result[0][0], Catch::Matchers::WithinAbs(4.0, EPSILON));
REQUIRE_THAT(result[0][1], Catch::Matchers::WithinAbs(4.0, EPSILON));
REQUIRE_THAT(result[1][0], Catch::Matchers::WithinAbs(4.0, EPSILON));
REQUIRE_THAT(result[1][1], Catch::Matchers::WithinAbs(4.0, EPSILON));
}

TEST_CASE("Matrix identity", "[Matrix]") {
Matrix<double, 3, 3> identity = Matrix<double, 3, 3>::identity<3>();

REQUIRE_THAT(identity[0][0], Catch::Matchers::WithinAbs(1.0, EPSILON));
REQUIRE_THAT(identity[0][1], Catch::Matchers::WithinAbs(0.0, EPSILON));
REQUIRE_THAT(identity[0][2], Catch::Matchers::WithinAbs(0.0, EPSILON));

REQUIRE_THAT(identity[1][0], Catch::Matchers::WithinAbs(0.0, EPSILON));
REQUIRE_THAT(identity[1][1], Catch::Matchers::WithinAbs(1.0, EPSILON));
REQUIRE_THAT(identity[1][2], Catch::Matchers::WithinAbs(0.0, EPSILON));

REQUIRE_THAT(identity[2][0], Catch::Matchers::WithinAbs(0.0, EPSILON));
REQUIRE_THAT(identity[2][1], Catch::Matchers::WithinAbs(0.0, EPSILON));
REQUIRE_THAT(identity[2][2], Catch::Matchers::WithinAbs(1.0, EPSILON));
}
56 changes: 28 additions & 28 deletions tests/main.cpp → tests/mainTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,31 +62,31 @@ TEST_CASE("Covariance Matrix Update") {
}
}

TEST_CASE("Kalman Gain Calculation") {
KalmanFilter<Nx, Nz, Nu> filter;

SECTION("Kalman Gain Matrix Calculation") {
auto kalman_gain = filter.calculate_kalman_gain();

// Ensure Kalman Gain matrix has reasonable values
REQUIRE(kalman_gain[0][0] >= 0);
}
}

TEST_CASE("State Update with Measurement") {
KalmanFilter<Nx, Nz, Nu> filter;
std::array<double, Nx> initial_state = {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
filter.set_state(initial_state);

SECTION("State Update After Measurement") {
Matrix<double, Nx, 1> measurement;
measurement[0][0] = 2.0; // Simulate a new measurement on position x

auto kalman_gain = filter.calculate_kalman_gain();
auto updated_state = filter.update_state_matrix(kalman_gain, filter.get_state(), measurement);

// Ensure the state is updated towards the measurement
REQUIRE(updated_state[0][0] > 1.0);
REQUIRE(updated_state[0][0] <= 2.0);
}
}
//TEST_CASE("Kalman Gain Calculation") {
// KalmanFilter<Nx, Nz, Nu> filter;
//
// SECTION("Kalman Gain Matrix Calculation") {
// auto kalman_gain = filter.calculate_kalman_gain();
//
// // Ensure Kalman Gain matrix has reasonable values
// REQUIRE(kalman_gain[0][0] >= 0);
// }
//}

//TEST_CASE("State Update with Measurement") {
// KalmanFilter<Nx, Nz, Nu> filter;
// std::array<double, Nx> initial_state = {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
// filter.set_state(initial_state);
//
// SECTION("State Update After Measurement") {
// Matrix<double, Nx, 1> measurement{};
// measurement[0][0] = 2.0; // Simulate a new measurement on position x
//
// auto kalman_gain = filter.calculate_kalman_gain();
// auto updated_state = filter.update_state_matrix(kalman_gain, filter.get_state(), measurement);
//
// // Ensure the state is updated towards the measurement
// REQUIRE(updated_state[0][0] > 1.0);
// REQUIRE(updated_state[0][0] <= 2.0);
// }
//}

0 comments on commit b61c1d2

Please sign in to comment.