Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/cuda.cmake
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# cmake 3.17 decouples C++ and CUDA standards, see https://gitlab.kitware.com/cmake/cmake/issues/19123
# cmake 3.18 knows that CUDA 11 provides cuda_std_17
cmake_minimum_required(VERSION 3.18.0)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_CUDA_EXTENSIONS OFF)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
Expand Down
87 changes: 54 additions & 33 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1745,9 +1745,11 @@ class Tensor {
/// \param right The tensor that will be added to this tensor
/// \return A new tensor where the elements are the sum of the elements of
/// \c this and \c other
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
Tensor add(const Right& right) const& {
template <typename Right>
requires(is_tensor<Right>::value &&
detail::sum_convertible_to<value_type, const value_type&,
const value_t<Right>&>)
Tensor add(const Right& right) const {
// early exit for empty right
if (right.empty()) return this->clone();

Expand Down Expand Up @@ -1780,13 +1782,31 @@ class Tensor {
/// \param right The tensor that will be added to this tensor
/// \return A new tensor where the elements are the sum of the elements of
/// \c this and \c other
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
template <typename Right>
requires(is_tensor<Right>::value &&
detail::addable_to<value_type&, const value_t<Right>&>)
Tensor add(const Right& right) && {
add_to(right);
return std::move(*this);
}

/// Add this and \c other to construct a new tensor of type that differs from
/// this

/// \tparam Right The right-hand tensor type
/// \param right The tensor that will be added to this tensor
/// \return A new tensor where the elements are the sum of the elements of
/// \c this and \c other
template <typename Right>
requires(detail::is_tensor_v<Right> &&
!detail::sum_convertible_to<value_type, const value_type&,
const value_t<Right>&>)
auto add(const Right& right) const {
return binary(right, [](const value_type& l, const value_t<Right>& r) {
return l + r;
});
}

/// Add this and \c other to construct a new, permuted tensor

/// \tparam Right The right-hand tensor type
Expand All @@ -1795,13 +1815,13 @@ class Tensor {
/// \param perm The permutation to be applied to this tensor
/// \return A new tensor where the elements are the sum of the elements of
/// \c this and \c other
template <
typename Right, typename Perm,
typename std::enable_if<is_tensor<Right>::value &&
detail::is_permutation_v<Perm>>::type* = nullptr>
Tensor add(const Right& right, const Perm& perm) const {
template <typename Right, typename Perm>
requires(is_tensor<Right>::value && detail::is_permutation_v<Perm> &&
detail::addable<const value_type&, const value_t<Right>&>)
auto add(const Right& right, const Perm& perm) const {
return binary(
right, [](const value_type& l, const value_type& r) { return l + r; },
right,
[](const value_type& l, const value_t<Right>& r) { return l + r; },
perm);
}

Expand All @@ -1813,14 +1833,14 @@ class Tensor {
/// \param factor The scaling factor
/// \return A new tensor where the elements are the sum of the elements of
/// \c this and \c other, scaled by \c factor
template <
typename Right, typename Scalar,
typename std::enable_if<is_tensor<Right>::value &&
detail::is_numeric_v<Scalar>>::type* = nullptr>
Tensor add(const Right& right, const Scalar factor) const {
return binary(right, [factor](const value_type& l, const value_type& r) {
return (l + r) * factor;
});
template <typename Right, typename Scalar>
requires(is_tensor<Right>::value && detail::is_numeric_v<Scalar> &&
detail::addable<const value_type&, const value_t<Right>&>)
auto add(const Right& right, const Scalar factor) const {
return binary(right,
[factor](const value_type& l, const value_t<Right>& r) {
return (l + r) * factor;
});
}

/// Scale and add this and \c other to construct a new, permuted tensor
Expand All @@ -1833,14 +1853,14 @@ class Tensor {
/// \param perm The permutation to be applied to this tensor
/// \return A new tensor where the elements are the sum of the elements of
/// \c this and \c other, scaled by \c factor
template <typename Right, typename Scalar, typename Perm,
typename std::enable_if<
is_tensor<Right>::value && detail::is_numeric_v<Scalar> &&
detail::is_permutation_v<Perm>>::type* = nullptr>
Tensor add(const Right& right, const Scalar factor, const Perm& perm) const {
template <typename Right, typename Scalar, typename Perm>
requires(is_tensor<Right>::value && detail::is_numeric_v<Scalar> &&
detail::is_permutation_v<Perm> &&
detail::addable<const value_type&, const value_t<Right>&>)
auto add(const Right& right, const Scalar factor, const Perm& perm) const {
return binary(
right,
[factor](const value_type& l, const value_type& r) {
[factor](const value_type& l, const value_t<Right>& r) {
return (l + r) * factor;
},
perm);
Expand Down Expand Up @@ -1879,8 +1899,9 @@ class Tensor {
/// \tparam Right The right-hand tensor type
/// \param right The tensor that will be added to this tensor
/// \return A reference to this tensor
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
template <typename Right>
requires(is_tensor<Right>::value &&
detail::addable_to<value_type&, const value_t<Right>&>)
Tensor& add_to(const Right& right) {
// early exit for empty right
if (right.empty()) return *this;
Expand All @@ -1902,10 +1923,9 @@ class Tensor {
/// \param right The tensor that will be added to this tensor
/// \param factor The scaling factor
/// \return A reference to this tensor
template <
typename Right, typename Scalar,
typename std::enable_if<is_tensor<Right>::value &&
detail::is_numeric_v<Scalar>>::type* = nullptr>
template <typename Right, typename Scalar>
requires(is_tensor<Right>::value && detail::is_numeric_v<Scalar> &&
detail::addable_to<value_type&, const value_t<Right>&>)
Tensor& add_to(const Right& right, const Scalar factor) {
return inplace_binary(
right, [factor](value_type& MADNESS_RESTRICT l,
Expand All @@ -1916,8 +1936,9 @@ class Tensor {

/// \param value The constant to be added
/// \return A reference to this tensor
template <typename Scalar,
typename = std::enable_if_t<detail::is_numeric_v<Scalar>>>
template <typename Scalar>
requires(detail::is_numeric_v<Scalar> &&
detail::addable_to<value_type&, const Scalar>)
Tensor& add_to(const Scalar value) {
return inplace_unary(
[value](value_type& MADNESS_RESTRICT res) { res += value; });
Expand Down
30 changes: 8 additions & 22 deletions src/TiledArray/tile_interface/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ template <typename Left, typename Right,
inline decltype(auto) add(Left&& left, Right&& right) {
constexpr auto left_right =
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&> &&
detail::has_member_function_add_anyreturn_v<Right&&, Left&&> &&
!std::is_reference_v<Right> && std::is_reference_v<Left>) ||
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&> &&
!detail::has_member_function_add_anyreturn_v<Right&&, Left&&>);
!std::is_reference_v<Left>) ||
(!detail::has_member_function_add_anyreturn_v<Right&&, Left&&>);
if constexpr (left_right)
return std::forward<Left>(left).add(std::forward<Right>(right));
else
Expand All @@ -76,12 +74,8 @@ inline decltype(auto) add(Left&& left, Right&& right, const Scalar factor) {
constexpr auto left_right =
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&,
const Scalar> &&
detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
const Scalar> &&
!std::is_reference_v<Right> && std::is_reference_v<Left>) ||
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&,
const Scalar> &&
!detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
!std::is_reference_v<Left>) ||
(!detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
const Scalar>);
if constexpr (left_right)
return std::forward<Left>(left).add(std::forward<Right>(right), factor);
Expand All @@ -108,12 +102,8 @@ inline decltype(auto) add(Left&& left, Right&& right, const Perm& perm) {
constexpr auto left_right =
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&,
const Perm&> &&
detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
const Perm&> &&
!std::is_reference_v<Right> && std::is_reference_v<Left>) ||
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&,
const Perm&> &&
!detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
!std::is_reference_v<Left>) ||
(!detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
const Perm&>);
if constexpr (left_right)
return std::forward<Left>(left).add(std::forward<Right>(right), perm);
Expand Down Expand Up @@ -143,12 +133,8 @@ inline decltype(auto) add(Left&& left, Right&& right, const Scalar factor,
constexpr auto left_right =
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&,
const Scalar, const Perm&> &&
detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
const Scalar, const Perm&> &&
!std::is_reference_v<Right> && std::is_reference_v<Left>) ||
(detail::has_member_function_add_anyreturn_v<Left&&, Right&&,
const Scalar, const Perm&> &&
!detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
!std::is_reference_v<Left>) ||
(!detail::has_member_function_add_anyreturn_v<Right&&, Left&&,
const Scalar, const Perm&>);
if constexpr (left_right)
return std::forward<Left>(left).add(std::forward<Right>(right), factor,
Expand Down
18 changes: 18 additions & 0 deletions src/TiledArray/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,24 @@ struct is_invocable_void : is_invocable_void_helper<void, F, Args...> {};

template <typename T>
struct type_printer;

// concepts

template <typename Summand1, typename Summand2>
concept addable = requires(Summand1 t, Summand2 u) {
{ t + u };
};

template <typename Sum, typename Summand>
concept addable_to = requires(Sum t, Summand u) {
{ t += u };
};

template <typename Result, typename Summand1, typename Summand2>
concept sum_convertible_to = requires(Summand1 t, Summand2 u) {
{ t + u } -> std::convertible_to<Result>;
};

} // namespace detail

} // namespace TiledArray
Expand Down
Loading