diff --git a/include/testlib.hpp b/include/testlib.hpp index fbc0d17..eee3f31 100644 --- a/include/testlib.hpp +++ b/include/testlib.hpp @@ -44,19 +44,55 @@ void m_assert(bool expr, std::string expr_str, std::string func, std::string file, int line, std::string msg); -/** @brief Test if two armadillo vectors are close to each other. +/** @brief Test if two armadillo matrices/vectors are close to each other. * - * @details This function takes in 2 vectors and checks if they are + * @details This function takes in 2 matrices/vectors and checks if they are * approximately equal to each other given a tolerance. * - * @param a Vector a - * @param b Vector b + * @param a Matrix/vector a + * @param b Matrix/vector b * @param tol The tolerance * * @return bool * */ -bool close_to(arma::vec &a, arma::vec &b, double tol = 1e-8); +template ::value>::type> +static bool close_to(arma::Mat &a, arma::Mat &b, double tol = 1e-8) +{ + if (a.n_elem != b.n_elem) { + return false; + } + for (size_t i = 0; i < a.n_elem; i++) { + if (std::abs(a(i) - b(i)) >= tol) { + return false; + } + } + return true; + +} + +/** @brief Test if two armadillo matrices/vectors are equal. + * + * @details This function takes in 2 matrices/vectors and checks if they are + * equal to each other. This should only be used for integral types. + * + * @param a Matrix/vector a + * @param b Matrix/vector b + * + * @return bool + * */ +template ::value>::type> +static bool is_equal(arma::Mat &a, arma::Mat &b) +{ + for (size_t i = 0; i < a.n_elem; i++) { + if (!(a(i) == b(i))) { + return false; + } + } + return true; +} /** @brief Test that all elements fulfill the condition. * * @param expr The boolean expression to apply to each element diff --git a/src/testlib.cpp b/src/testlib.cpp index 4a4d9ea..7cce040 100644 --- a/src/testlib.cpp +++ b/src/testlib.cpp @@ -12,8 +12,6 @@ #include "testlib.hpp" -#include - static void print_message(std::string msg) { if (msg.size() > 0) { @@ -42,17 +40,3 @@ void m_assert(bool expr, std::string expr_str, std::string f, std::string file, abort(); } } - -bool close_to(arma::vec &a, arma::vec &b, double tol) -{ - if (a.n_elem != b.n_elem) { - return false; - } - - for (size_t i = 0; i < a.n_elem; i++) { - if (std::abs(a(i) - b(i)) >= tol) { - return false; - } - } - return true; -}