diff --git a/include/utils.hpp b/include/utils.hpp index a41ce7b..95657ba 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -15,10 +15,11 @@ #ifndef __UTILS__ #define __UTILS__ -#include -#include +#include #include #include +#include +#include /** @def DEBUG(msg) * @brief Writes a debug message @@ -84,10 +85,24 @@ std::string scientific_format(const std::vector& v, * @param msg The message to be displayed * */ void m_assert(bool expr, - const char* expr_str, - const char* func, - const char* file, + std::string expr_str, + std::string func, + std::string file, int line, - const char* msg); + std::string msg); + + +/** @brief Test if two armadillo vectors are close to each other. + * + * This function takes in 2 vectors and checks if they are approximately + * equal to each other given a tolerance. + * + * @param a Vector a + * @param b Vector b + * @param tol The tolerance + * + * @return Boolean + * */ +bool arma_vector_close_to(arma::vec &a, arma::vec &b, double tol=1e-8); #endif diff --git a/src/utils.cpp b/src/utils.cpp index f326b19..9694610 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -10,10 +10,6 @@ * @bug No known bugs * */ #include "utils.hpp" -#include -#include -#include -#include std::string scientific_format(double d, int width, int prec) { @@ -31,9 +27,9 @@ std::string scientific_format(const std::vector& v, int width, int prec) return ss.str(); } -static void print_message(const char* msg) +static void print_message(std::string msg) { - if (strlen(msg) > 0) { + if (msg.size() > 0) { std::cout << "message: " << msg << "\n\n"; } else { @@ -42,13 +38,13 @@ static void print_message(const char* msg) } void m_assert(bool expr, - const char* expr_str, - const char* f, - const char* file, + std::string expr_str, + std::string f, + std::string file, int line, - const char* msg) + std::string msg) { - std::string new_assert(strlen(f) + (expr ? 4 : 6), '-'); + std::string new_assert(f.size() + (expr ? 4 : 6), '-'); std::cout << "\x1B[36m" << new_assert << "\033[0m\n"; std::cout << f << ": "; if (expr) { @@ -63,3 +59,17 @@ void m_assert(bool expr, abort(); } } + +bool arma_vector_close_to(arma::vec &a, arma::vec &b, double tol) +{ + if (a.n_elem != b.n_elem) { + return false; + } + + for (int i=0; i < a.n_elem; i++) { + if (std::abs(a(i) - b(i)) >= tol) { + return false; + } + } + return true; +}