整理
This commit is contained in:
288
include/boost/numeric/ublas/tensor/expression_evaluation.hpp
Normal file
288
include/boost/numeric/ublas/tensor/expression_evaluation.hpp
Normal file
@@ -0,0 +1,288 @@
|
||||
//
|
||||
// Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See
|
||||
// accompanying file LICENSE_1_0.txt or copy at
|
||||
// http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
// The authors gratefully acknowledge the support of
|
||||
// Fraunhofer IOSB, Ettlingen, Germany
|
||||
//
|
||||
|
||||
#ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
|
||||
#define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
|
||||
|
||||
#include <type_traits>
|
||||
#include <stdexcept>
|
||||
|
||||
|
||||
namespace boost::numeric::ublas {
|
||||
|
||||
template<class element_type, class storage_format, class storage_type>
|
||||
class tensor;
|
||||
|
||||
template<class size_type>
|
||||
class basic_extents;
|
||||
|
||||
}
|
||||
|
||||
namespace boost::numeric::ublas::detail {
|
||||
|
||||
template<class T, class D>
|
||||
struct tensor_expression;
|
||||
|
||||
template<class T, class EL, class ER, class OP>
|
||||
struct binary_tensor_expression;
|
||||
|
||||
template<class T, class E, class OP>
|
||||
struct unary_tensor_expression;
|
||||
|
||||
}
|
||||
|
||||
namespace boost::numeric::ublas::detail {
|
||||
|
||||
template<class T, class E>
|
||||
struct has_tensor_types
|
||||
{ static constexpr bool value = false; };
|
||||
|
||||
template<class T>
|
||||
struct has_tensor_types<T,T>
|
||||
{ static constexpr bool value = true; };
|
||||
|
||||
template<class T, class D>
|
||||
struct has_tensor_types<T, tensor_expression<T,D>>
|
||||
{ static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };
|
||||
|
||||
|
||||
template<class T, class EL, class ER, class OP>
|
||||
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
|
||||
{ static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value; };
|
||||
|
||||
template<class T, class E, class OP>
|
||||
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
|
||||
{ static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };
|
||||
|
||||
} // namespace boost::numeric::ublas::detail
|
||||
|
||||
|
||||
namespace boost::numeric::ublas::detail {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/** @brief Retrieves extents of the tensor
|
||||
*
|
||||
*/
|
||||
template<class T, class F, class A>
|
||||
auto retrieve_extents(tensor<T,F,A> const& t)
|
||||
{
|
||||
return t.extents();
|
||||
}
|
||||
|
||||
/** @brief Retrieves extents of the tensor expression
|
||||
*
|
||||
* @note tensor expression must be a binary tree with at least one tensor type
|
||||
*
|
||||
* @returns extents of the child expression if it is a tensor or extents of one child of its child.
|
||||
*/
|
||||
template<class T, class D>
|
||||
auto retrieve_extents(tensor_expression<T,D> const& expr)
|
||||
{
|
||||
static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
|
||||
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
|
||||
|
||||
auto const& cast_expr = static_cast<D const&>(expr);
|
||||
|
||||
if constexpr ( std::is_same<T,D>::value )
|
||||
return cast_expr.extents();
|
||||
else
|
||||
return retrieve_extents(cast_expr);
|
||||
}
|
||||
|
||||
/** @brief Retrieves extents of the binary tensor expression
|
||||
*
|
||||
* @note tensor expression must be a binary tree with at least one tensor type
|
||||
*
|
||||
* @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
|
||||
*/
|
||||
template<class T, class EL, class ER, class OP>
|
||||
auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
|
||||
{
|
||||
static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
|
||||
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
|
||||
|
||||
if constexpr ( std::is_same<T,EL>::value )
|
||||
return expr.el.extents();
|
||||
|
||||
if constexpr ( std::is_same<T,ER>::value )
|
||||
return expr.er.extents();
|
||||
|
||||
else if constexpr ( detail::has_tensor_types<T,EL>::value )
|
||||
return retrieve_extents(expr.el);
|
||||
|
||||
else if constexpr ( detail::has_tensor_types<T,ER>::value )
|
||||
return retrieve_extents(expr.er);
|
||||
}
|
||||
|
||||
/** @brief Retrieves extents of the binary tensor expression
|
||||
*
|
||||
* @note tensor expression must be a binary tree with at least one tensor type
|
||||
*
|
||||
* @returns extents of the child expression if it is a tensor or extents of a child of its child.
|
||||
*/
|
||||
template<class T, class E, class OP>
|
||||
auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
|
||||
{
|
||||
|
||||
static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
|
||||
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
|
||||
|
||||
if constexpr ( std::is_same<T,E>::value )
|
||||
return expr.e.extents();
|
||||
|
||||
else if constexpr ( detail::has_tensor_types<T,E>::value )
|
||||
return retrieve_extents(expr.e);
|
||||
}
|
||||
|
||||
} // namespace boost::numeric::ublas::detail
|
||||
|
||||
|
||||
///////////////
|
||||
|
||||
namespace boost::numeric::ublas::detail {
|
||||
|
||||
template<class T, class F, class A, class S>
|
||||
auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
|
||||
{
|
||||
return extents == t.extents();
|
||||
}
|
||||
|
||||
template<class T, class D, class S>
|
||||
auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
|
||||
{
|
||||
static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
|
||||
"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
|
||||
auto const& cast_expr = static_cast<D const&>(expr);
|
||||
|
||||
|
||||
if constexpr ( std::is_same<T,D>::value )
|
||||
if( extents != cast_expr.extents() )
|
||||
return false;
|
||||
|
||||
if constexpr ( detail::has_tensor_types<T,D>::value )
|
||||
if ( !all_extents_equal(cast_expr, extents))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
template<class T, class EL, class ER, class OP, class S>
|
||||
auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
|
||||
{
|
||||
static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
|
||||
"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
|
||||
|
||||
if constexpr ( std::is_same<T,EL>::value )
|
||||
if(extents != expr.el.extents())
|
||||
return false;
|
||||
|
||||
if constexpr ( std::is_same<T,ER>::value )
|
||||
if(extents != expr.er.extents())
|
||||
return false;
|
||||
|
||||
if constexpr ( detail::has_tensor_types<T,EL>::value )
|
||||
if(!all_extents_equal(expr.el, extents))
|
||||
return false;
|
||||
|
||||
if constexpr ( detail::has_tensor_types<T,ER>::value )
|
||||
if(!all_extents_equal(expr.er, extents))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template<class T, class E, class OP, class S>
|
||||
auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
|
||||
{
|
||||
|
||||
static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
|
||||
"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
|
||||
|
||||
if constexpr ( std::is_same<T,E>::value )
|
||||
if(extents != expr.e.extents())
|
||||
return false;
|
||||
|
||||
if constexpr ( detail::has_tensor_types<T,E>::value )
|
||||
if(!all_extents_equal(expr.e, extents))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace boost::numeric::ublas::detail
|
||||
|
||||
|
||||
namespace boost::numeric::ublas::detail {
|
||||
|
||||
|
||||
/** @brief Evaluates expression for a tensor
|
||||
*
|
||||
* Assigns the results of the expression to the tensor.
|
||||
*
|
||||
* \note Checks if shape of the tensor matches those of all tensors within the expression.
|
||||
*/
|
||||
template<class tensor_type, class derived_type>
|
||||
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
|
||||
{
|
||||
if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
|
||||
if(!detail::all_extents_equal(expr, lhs.extents() ))
|
||||
throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
|
||||
|
||||
#pragma omp parallel for
|
||||
for(auto i = 0u; i < lhs.size(); ++i)
|
||||
lhs(i) = expr()(i);
|
||||
}
|
||||
|
||||
/** @brief Evaluates expression for a tensor
|
||||
*
|
||||
* Applies a unary function to the results of the expressions before the assignment.
|
||||
* Usually applied needed for unary operators such as A += C;
|
||||
*
|
||||
* \note Checks if shape of the tensor matches those of all tensors within the expression.
|
||||
*/
|
||||
template<class tensor_type, class derived_type, class unary_fn>
|
||||
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
|
||||
{
|
||||
|
||||
if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
|
||||
if(!detail::all_extents_equal( expr, lhs.extents() ))
|
||||
throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
|
||||
|
||||
#pragma omp parallel for
|
||||
for(auto i = 0u; i < lhs.size(); ++i)
|
||||
fn(lhs(i), expr()(i));
|
||||
}
|
||||
|
||||
|
||||
|
||||
/** @brief Evaluates expression for a tensor
|
||||
*
|
||||
* Applies a unary function to the results of the expressions before the assignment.
|
||||
* Usually applied needed for unary operators such as A += C;
|
||||
*
|
||||
* \note Checks if shape of the tensor matches those of all tensors within the expression.
|
||||
*/
|
||||
template<class tensor_type, class unary_fn>
|
||||
void eval(tensor_type& lhs, unary_fn const fn)
|
||||
{
|
||||
#pragma omp parallel for
|
||||
for(auto i = 0u; i < lhs.size(); ++i)
|
||||
fn(lhs(i));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
#endif
|
||||
Reference in New Issue
Block a user