diff --git a/include/ttl/nn/bits/engines/eigen_impl.hpp b/include/ttl/nn/bits/engines/eigen_impl.hpp new file mode 100644 index 0000000..44ae472 --- /dev/null +++ b/include/ttl/nn/bits/engines/eigen_impl.hpp @@ -0,0 +1,75 @@ +#include +#include + +#include + +namespace ttl::nn::engines +{ +template +struct eigen_matrix; + +template <> +struct eigen_matrix { + using type = Eigen::MatrixXf; +}; + +template <> +struct eigen_matrix { + using type = Eigen::MatrixXd; +}; + +template +class eigen_impl +{ + using m_ref_t = ttl::matrix_ref; + using m_view_t = ttl::matrix_view; + using v_ref_t = ttl::vector_ref; + using v_view_t = ttl::vector_view; + + using eigen_mat = typename eigen_matrix::type; + + static eigen_mat to(const m_view_t &a) + { + eigen_mat m(std::get<0>(a.dims()), std::get<1>(a.dims())); + return m; + } + + public: + static void mm(const m_view_t &a, const m_view_t &b, const m_ref_t &c) + { + // FIXME: copy elison + const eigen_mat A = to(a); + const eigen_mat B = to(b); + eigen_mat C = A * B; + std::copy(C.data(), C.data() + c.size(), c.data()); + } + + static void mmt(const m_view_t &a, const m_view_t &b, const m_ref_t &c) + { + throw std::runtime_error("TODO"); + } + + static void mtm(const m_view_t &a, const m_view_t &b, const m_ref_t &c) + { + throw std::runtime_error("TODO"); + } + + static void mv(const m_view_t &a, const v_view_t &b, const v_ref_t &c) + { + throw std::runtime_error("TODO"); + } + + static void vm(const v_view_t &a, const m_view_t &b, const v_ref_t &c) + { + throw std::runtime_error("TODO"); + } +}; + +struct eigen; + +template <> +struct backend { + template + using type = eigen_impl; +}; +} // namespace ttl::nn::engines