5 #ifndef GKO_PUBLIC_CORE_BASE_RANGE_HPP_
6 #define GKO_PUBLIC_CORE_BASE_RANGE_HPP_
11 #include <ginkgo/core/base/math.hpp>
12 #include <ginkgo/core/base/types.hpp>
13 #include <ginkgo/core/base/utils.hpp>
55 :
span{point, point + 1}
94 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator<(
const span& first,
// Span ordering (non-strict): true when `first` ends at or before the index
// where `second` begins, i.e. first.end <= second.begin.
101 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator<=(
const span& first,
104 return first.end <= second.begin;
// Span ordering: implemented by swapping the operands of operator<.
108 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator>(
const span& first,
111 return second < first;
// Span ordering: implemented by swapping the operands of operator<=.
115 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator>=(
const span& first,
118 return second <= first;
// Two spans are equal when both their `begin` and their `end` members match.
122 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator==(
const span& first,
125 return first.begin == second.begin && first.end == second.end;
// Span inequality: the negation of operator==.
129 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator!=(
const span& first,
132 return !(first == second);
// Overload of equal_dimensions that participates only once CurrentDimension
// has reached the larger of the two ranges' `dimensionality` values. Both
// range parameters are unnamed, so this overload does not inspect them.
146 template <
size_type CurrentDimension = 0,
typename FirstRange,
147 typename SecondRange>
148 GKO_ATTRIBUTES constexpr GKO_INLINE
149 std::enable_if_t<(CurrentDimension >=
max(FirstRange::dimensionality,
150 SecondRange::dimensionality)),
152 equal_dimensions(
const FirstRange&,
const SecondRange&)
// Overload of equal_dimensions that participates while CurrentDimension is
// still below the larger of the two ranges' `dimensionality` values.
// It compares the lengths of both ranges along CurrentDimension and, when
// they agree, continues the check with CurrentDimension + 1.
157 template <
size_type CurrentDimension = 0,
typename FirstRange,
158 typename SecondRange>
159 GKO_ATTRIBUTES constexpr GKO_INLINE
160 std::enable_if_t<(CurrentDimension <
max(FirstRange::dimensionality,
161 SecondRange::dimensionality)),
163 equal_dimensions(
const FirstRange& first,
const SecondRange& second)
165 return first.length(CurrentDimension) == second.length(CurrentDimension) &&
166 equal_dimensions<CurrentDimension + 1>(first, second);
// Partial specialization of `head` for a type pack that has at least one
// element (`First`), followed by any number of further types (`Rest`).
179 template <
class First,
class... Rest>
180 struct head<First, Rest...> {
// Convenience alias for the nested `type` of head<T...>.
187 template <
class... T>
188 using head_t =
typename head<T...>::type;
303 template <
typename Accessor>
// Forwarding constructor: every argument is perfectly forwarded to the
// accessor's constructor.
//
// The defaulted template parameter is a SFINAE guard. It keeps this overload
// whenever the number of arguments differs from one; for a single argument
// the type trait on the following lines decides (NOTE(review): the line
// holding the trait name is not part of this excerpt - it is presumably a
// negated std::is_same, which would disable the overload when the argument
// is itself a `range`, so that copy/move construction is used instead).
330 typename... AccessorParams,
331 typename = std::enable_if_t<
332 sizeof...(AccessorParams) != 1 ||
// Compare against the decayed *type* (std::decay_t). Using the trait struct
// std::decay<...> itself can never be the same type as `range`, which would
// leave the guard permanently satisfied.
334 range, std::decay_t<detail::head_t<AccessorParams...>>>::value>>
335 GKO_ATTRIBUTES constexpr
explicit range(AccessorParams&&... params)
336 : accessor_{std::forward<AccessorParams>(params)...}
// Call operator of the range: forwards the supplied dimension arguments to
// the stored accessor and returns whatever the accessor's call operator
// yields (the return type is deduced from that call).
351 template <
typename... DimensionTypes>
352 GKO_ATTRIBUTES constexpr
auto operator()(DimensionTypes&&... dimensions)
353 const -> decltype(std::declval<accessor>()(
354 std::forward<DimensionTypes>(dimensions)...))
357 "Too many dimensions in range call");
358 return accessor_(std::forward<DimensionTypes>(dimensions)...);
// Templated on the accessor type of the other operand: asserts (GKO_ASSERT)
// that this range and `other` have equal dimensions, then hands `other` to
// the accessor's copy_from.
369 template <
typename OtherAccessor>
373 GKO_ASSERT(detail::equal_dimensions(*
this, other));
374 accessor_.copy_from(other);
393 GKO_ASSERT(detail::equal_dimensions(*
this, other));
409 return accessor_.length(dimension);
// Tags the operand combination of a binary operation. Each enumerator has a
// matching implement_binary_operation specialization: both operands are
// ranges, the first operand is a scalar, or the second operand is a scalar.
447 enum class operation_kind { range_by_range, scalar_by_range, range_by_scalar };
// Accessor adaptor that applies a unary `Operation` to the values of another
// accessor. It stores a copy of the wrapped accessor and reports the same
// dimensionality as that accessor.
450 template <
typename Accessor,
typename Operation>
451 struct implement_unary_operation {
452 using accessor = Accessor;
453 static constexpr
size_type dimensionality = accessor::dimensionality;
455 GKO_ATTRIBUTES constexpr
explicit implement_unary_operation(
456 const Accessor& operand)
// Element access: delegates to Operation::evaluate, passing the wrapped
// accessor together with the requested dimensions.
460 template <
typename... DimensionTypes>
461 GKO_ATTRIBUTES constexpr
auto operator()(
462 const DimensionTypes&... dimensions)
const
463 -> decltype(Operation::evaluate(std::declval<accessor>(),
466 return Operation::evaluate(operand, dimensions...);
// The length along a dimension is taken from the wrapped accessor.
471 return operand.length(dimension);
// copy_from is explicitly deleted for this adaptor.
474 template <
typename OtherAccessor>
475 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
// The wrapped accessor, held by value and immutable.
477 const accessor operand;
// Primary template: intentionally empty. The functionality lives in the
// partial specializations, one per operation_kind value.
481 template <operation_kind Kind,
typename FirstOperand,
typename SecondOperand,
483 struct implement_binary_operation {};
// Specialization for operation_kind::range_by_range: both operands are
// accessors. Requires (at compile time) that both accessors have the same
// dimensionality, and stores a copy of each.
485 template <
typename FirstAccessor,
typename SecondAccessor,
typename Operation>
486 struct implement_binary_operation<operation_kind::range_by_range, FirstAccessor,
487 SecondAccessor, Operation> {
488 using first_accessor = FirstAccessor;
489 using second_accessor = SecondAccessor;
490 static_assert(first_accessor::dimensionality ==
491 second_accessor::dimensionality,
492 "Both ranges need to have the same number of dimensions");
493 static constexpr
size_type dimensionality = first_accessor::dimensionality;
// Unlike the scalar specializations, this constructor is not declared
// constexpr; it additionally asserts at run time that the two operands have
// equal lengths in every dimension.
495 GKO_ATTRIBUTES
explicit implement_binary_operation(
496 const FirstAccessor& first,
const SecondAccessor& second)
497 : first{first}, second{second}
499 GKO_ASSERT(gko::detail::equal_dimensions(first, second));
// Element access: delegates to Operation::evaluate_range_by_range with both
// stored accessors and the requested dimensions.
502 template <
typename... DimensionTypes>
503 GKO_ATTRIBUTES constexpr
auto operator()(
504 const DimensionTypes&... dimensions)
const
505 -> decltype(Operation::evaluate_range_by_range(
506 std::declval<first_accessor>(), std::declval<second_accessor>(),
509 return Operation::evaluate_range_by_range(first, second, dimensions...);
// The length is reported from the first operand.
514 return first.length(dimension);
// copy_from is explicitly deleted for this adaptor.
517 template <
typename OtherAccessor>
518 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
// The two operands, held by value and immutable.
520 const first_accessor first;
521 const second_accessor second;
// Specialization for operation_kind::scalar_by_range: the first operand is a
// plain value, the second an accessor. Dimensionality follows the accessor.
524 template <
typename FirstOperand,
typename SecondAccessor,
typename Operation>
525 struct implement_binary_operation<operation_kind::scalar_by_range, FirstOperand,
526 SecondAccessor, Operation> {
527 using second_accessor = SecondAccessor;
528 static constexpr
size_type dimensionality = second_accessor::dimensionality;
// Stores copies of both operands.
530 GKO_ATTRIBUTES constexpr
explicit implement_binary_operation(
531 const FirstOperand& first,
const SecondAccessor& second)
532 : first{first}, second{second}
// Element access: delegates to Operation::evaluate_scalar_by_range.
535 template <
typename... DimensionTypes>
536 GKO_ATTRIBUTES constexpr
auto operator()(
537 const DimensionTypes&... dimensions)
const
538 -> decltype(Operation::evaluate_scalar_by_range(
539 std::declval<FirstOperand>(), std::declval<second_accessor>(),
542 return Operation::evaluate_scalar_by_range(first, second,
// The length is reported from the accessor operand (the second one).
548 return second.length(dimension);
// copy_from is explicitly deleted for this adaptor.
551 template <
typename OtherAccessor>
552 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
// The scalar and the accessor operand, held by value and immutable.
554 const FirstOperand first;
555 const second_accessor second;
// Specialization for operation_kind::range_by_scalar: the first operand is an
// accessor, the second a plain value. Dimensionality follows the accessor.
558 template <
typename FirstAccessor,
typename SecondOperand,
typename Operation>
559 struct implement_binary_operation<operation_kind::range_by_scalar,
560 FirstAccessor, SecondOperand, Operation> {
561 using first_accessor = FirstAccessor;
562 static constexpr
size_type dimensionality = first_accessor::dimensionality;
// Stores copies of both operands.
564 GKO_ATTRIBUTES constexpr
explicit implement_binary_operation(
565 const FirstAccessor& first,
const SecondOperand& second)
566 : first{first}, second{second}
// Element access: delegates to Operation::evaluate_range_by_scalar.
569 template <
typename... DimensionTypes>
570 GKO_ATTRIBUTES constexpr
auto operator()(
571 const DimensionTypes&... dimensions)
const
572 -> decltype(Operation::evaluate_range_by_scalar(
573 std::declval<first_accessor>(), std::declval<SecondOperand>(),
576 return Operation::evaluate_range_by_scalar(first, second,
// The length is reported from the accessor operand (the first one).
582 return first.length(dimension);
// copy_from is explicitly deleted for this adaptor.
585 template <
typename OtherAccessor>
586 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
// The accessor and the scalar operand, held by value and immutable.
588 const first_accessor first;
589 const SecondOperand second;
// Declares, inside namespace `accessor`, a struct template named
// _operation_deprecated_name that derives from _operation_name<Operand> and
// is marked with GKO_DEPRECATED, pointing users to the new name.
// The trailing static_assert(true, ...) lets the macro invocation be followed
// by a semicolon without triggering extra-semicolon warnings.
595 #define GKO_DEPRECATED_UNARY_RANGE_OPERATION(_operation_deprecated_name, \
597 namespace accessor { \
598 template <typename Operand> \
599 struct GKO_DEPRECATED("Please use " #_operation_name) \
600 _operation_deprecated_name : _operation_name<Operand> {}; \
602 static_assert(true, \
603 "This assert is used to counter the false positive extra " \
604 "semi-colon warnings")
// Declares, inside namespace `accessor`, a struct template _operation_name
// that derives from detail::implement_unary_operation<Operand,
// ::gko::_operator> and inherits its constructors, then binds it to the
// free function/operator _operator_name through
// GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR.
607 #define GKO_ENABLE_UNARY_RANGE_OPERATION(_operation_name, _operator_name, \
609 namespace accessor { \
610 template <typename Operand> \
611 struct _operation_name \
612 : ::gko::detail::implement_unary_operation<Operand, \
613 ::gko::_operator> { \
614 using ::gko::detail::implement_unary_operation< \
615 Operand, ::gko::_operator>::implement_unary_operation; \
618 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name)
// Defines a function template _operator_name taking a range<Accessor> and
// returning a range whose accessor is accessor::_operation_name<Accessor>,
// constructed from the operand's accessor (operand.get_accessor()).
// The trailing static_assert(true, ...) lets the macro invocation be followed
// by a semicolon without triggering extra-semicolon warnings.
621 #define GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, \
623 template <typename Accessor> \
624 GKO_ATTRIBUTES constexpr GKO_INLINE \
625 range<accessor::_operation_name<Accessor>> \
626 _operator_name(const range<Accessor>& operand) \
628 return range<accessor::_operation_name<Accessor>>( \
629 operand.get_accessor()); \
631 static_assert(true, \
632 "This assert is used to counter the false positive extra " \
633 "semi-colon warnings")
// Generates the static members of a simple unary operation:
//  - simple_evaluate_impl(operand): returns the expression passed as the
//    variadic macro argument, which is written in terms of `operand`;
//  - evaluate(accessor, dimensions...): applies simple_evaluate_impl to the
//    accessor's value at the given dimensions.
636 #define GKO_DEFINE_SIMPLE_UNARY_OPERATION(_name, ...) \
639 template <typename Operand> \
640 GKO_ATTRIBUTES static constexpr auto simple_evaluate_impl( \
641 const Operand& operand) -> decltype(__VA_ARGS__) \
643 return __VA_ARGS__; \
647 template <typename AccessorType, typename... DimensionTypes> \
648 GKO_ATTRIBUTES static constexpr auto evaluate( \
649 const AccessorType& accessor, const DimensionTypes&... dimensions) \
650 -> decltype(simple_evaluate_impl(accessor(dimensions...))) \
652 return simple_evaluate_impl(accessor(dimensions...)); \
// Simple unary operations. The first macro argument names the operation, the
// second is the expression applied to each value (`operand`).
// Arithmetic: unary plus and negation.
662 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
unary_plus, +operand);
663 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
unary_minus, -operand);
// Logical negation.
666 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
logical_not, !operand);
// Bitwise complement.
669 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
bitwise_not, ~(operand));
// Exposes the simple unary operations as overloaded operators on ranges.
// Arguments: accessor struct name, operator to define, operation to apply.
686 GKO_ENABLE_UNARY_RANGE_OPERATION(unary_plus,
operator+,
687 accessor::detail::unary_plus);
688 GKO_ENABLE_UNARY_RANGE_OPERATION(
unary_minus,
operator-,
689 accessor::detail::unary_minus);
// Logical negation operator.
692 GKO_ENABLE_UNARY_RANGE_OPERATION(
logical_not,
operator!,
693 accessor::detail::logical_not);
// Bitwise complement operator.
696 GKO_ENABLE_UNARY_RANGE_OPERATION(
bitwise_not,
operator~,
697 accessor::detail::bitwise_not);
702 accessor::detail::zero_operation);
704 accessor::detail::one_operation);
706 accessor::detail::abs_operation);
708 accessor::detail::real_operation);
710 accessor::detail::imag_operation);
712 accessor::detail::conj_operation);
714 accessor::detail::squared_norm_operation);
// Accessor adaptor that presents another accessor with its first two
// dimensions swapped. It stores a copy of the wrapped accessor and reports
// the same dimensionality.
727 template <
typename Accessor>
729 using accessor = Accessor;
730 static constexpr
size_type dimensionality = accessor::dimensionality;
733 const Accessor& operand)
// Element access: the first two indices are exchanged before being passed
// to the wrapped accessor; any further indices are passed through unchanged.
737 template <
typename FirstDimensionType,
typename SecondDimensionType,
738 typename... DimensionTypes>
739 GKO_ATTRIBUTES constexpr
auto operator()(
740 const FirstDimensionType& first_dim,
741 const SecondDimensionType& second_dim,
742 const DimensionTypes&... dims)
const
743 -> decltype(std::declval<accessor>()(second_dim, first_dim, dims...))
745 return operand(second_dim, first_dim, dims...);
// Lengths of dimensions 0 and 1 are exchanged (`dimension ^ 1` maps 0 to 1
// and 1 to 0); all other dimensions keep the wrapped accessor's length.
750 return dimension < 2 ? operand.length(dimension ^ 1)
751 : operand.length(dimension);
// copy_from is explicitly deleted for this adaptor.
754 template <
typename OtherAccessor>
755 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
// The wrapped accessor, held by value and immutable.
757 const accessor operand;
// Defines the free function transpose(range) that wraps the operand's
// accessor in accessor::transpose_operation.
764 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(transpose_operation,
transpose);
767 #undef GKO_DEPRECATED_UNARY_RANGE_OPERATION
768 #undef GKO_DEFINE_SIMPLE_UNARY_OPERATION
769 #undef GKO_ENABLE_UNARY_RANGE_OPERATION
// Declares, inside namespace `accessor`, a struct template _operation_name
// parameterized on the operation kind and both operand types. It derives
// from detail::implement_binary_operation<Kind, FirstOperand, SecondOperand,
// ::gko::_operator> and inherits its constructors. The struct is then bound
// to _operator_name through GKO_BIND_RANGE_OPERATION_TO_OPERATOR.
// The trailing static_assert(true, ...) lets the macro invocation be followed
// by a semicolon without triggering extra-semicolon warnings.
772 #define GKO_ENABLE_BINARY_RANGE_OPERATION(_operation_name, _operator_name, \
774 namespace accessor { \
775 template <::gko::detail::operation_kind Kind, typename FirstOperand, \
776 typename SecondOperand> \
777 struct _operation_name \
778 : ::gko::detail::implement_binary_operation< \
779 Kind, FirstOperand, SecondOperand, ::gko::_operator> { \
780 using ::gko::detail::implement_binary_operation< \
781 Kind, FirstOperand, SecondOperand, \
782 ::gko::_operator>::implement_binary_operation; \
785 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name); \
786 static_assert(true, \
787 "This assert is used to counter the false positive extra " \
788 "semi-colon warnings")
// Defines four overloads of _operator_name, each returning a range whose
// accessor is accessor::_operation_name instantiated for the matching
// operation_kind:
//  1. range (op) range, both with the same accessor type (range_by_range);
//  2. range (op) range with two different accessor types (range_by_range);
//  3. range (op) scalar (range_by_scalar);
//  4. scalar (op) range (scalar_by_range).
// Range operands contribute their accessor (get_accessor()); scalar operands
// are passed through as they are.
// The trailing static_assert(true, ...) lets the macro invocation be followed
// by a semicolon without triggering extra-semicolon warnings.
791 #define GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name) \
792 template <typename Accessor> \
793 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
794 ::gko::detail::operation_kind::range_by_range, Accessor, Accessor>> \
795 _operator_name(const range<Accessor>& first, \
796 const range<Accessor>& second) \
798 return range<accessor::_operation_name< \
799 ::gko::detail::operation_kind::range_by_range, Accessor, \
800 Accessor>>(first.get_accessor(), second.get_accessor()); \
803 template <typename FirstAccessor, typename SecondAccessor> \
804 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
805 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
807 _operator_name(const range<FirstAccessor>& first, \
808 const range<SecondAccessor>& second) \
810 return range<accessor::_operation_name< \
811 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
812 SecondAccessor>>(first.get_accessor(), second.get_accessor()); \
815 template <typename FirstAccessor, typename SecondOperand> \
816 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
817 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
819 _operator_name(const range<FirstAccessor>& first, \
820 const SecondOperand& second) \
822 return range<accessor::_operation_name< \
823 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
824 SecondOperand>>(first.get_accessor(), second); \
827 template <typename FirstOperand, typename SecondAccessor> \
828 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
829 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
831 _operator_name(const FirstOperand& first, \
832 const range<SecondAccessor>& second) \
834 return range<accessor::_operation_name< \
835 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
836 SecondAccessor>>(first, second.get_accessor()); \
838 static_assert(true, \
839 "This assert is used to counter the false positive extra " \
840 "semi-colon warnings")
// Declares a struct _deprecated_name that derives from _name and is marked
// with GKO_DEPRECATED, so the old identifier keeps working while pointing
// users to the new one.
843 #define GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(_deprecated_name, _name) \
844 struct GKO_DEPRECATED("Please use " #_name) _deprecated_name : _name {}
// Generates the static members of a simple binary operation:
//  - simple_evaluate_impl(first, second): returns the expression passed as
//    the variadic macro argument, written in terms of `first` and `second`;
//  - evaluate_range_by_range: applies it to the values of both accessors at
//    the given dimensions;
//  - evaluate_scalar_by_range: applies it to the scalar `first` and the value
//    of the accessor `second` at the given dimensions;
//  - evaluate_range_by_scalar: applies it to the value of the accessor
//    `first` at the given dimensions and the scalar `second`.
846 #define GKO_DEFINE_SIMPLE_BINARY_OPERATION(_name, ...) \
849 template <typename FirstOperand, typename SecondOperand> \
850 GKO_ATTRIBUTES constexpr static auto simple_evaluate_impl( \
851 const FirstOperand& first, const SecondOperand& second) \
852 -> decltype(__VA_ARGS__) \
854 return __VA_ARGS__; \
858 template <typename FirstAccessor, typename SecondAccessor, \
859 typename... DimensionTypes> \
860 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_range( \
861 const FirstAccessor& first, const SecondAccessor& second, \
862 const DimensionTypes&... dims) \
863 -> decltype(simple_evaluate_impl(first(dims...), second(dims...))) \
865 return simple_evaluate_impl(first(dims...), second(dims...)); \
868 template <typename FirstOperand, typename SecondAccessor, \
869 typename... DimensionTypes> \
870 GKO_ATTRIBUTES static constexpr auto evaluate_scalar_by_range( \
871 const FirstOperand& first, const SecondAccessor& second, \
872 const DimensionTypes&... dims) \
873 -> decltype(simple_evaluate_impl(first, second(dims...))) \
875 return simple_evaluate_impl(first, second(dims...)); \
878 template <typename FirstAccessor, typename SecondOperand, \
879 typename... DimensionTypes> \
880 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_scalar( \
881 const FirstAccessor& first, const SecondOperand& second, \
882 const DimensionTypes&... dims) \
883 -> decltype(simple_evaluate_impl(first(dims...), second)) \
885 return simple_evaluate_impl(first(dims...), second); \
// Simple binary operations. The first macro argument names the operation, the
// second is the expression combining the two values (`first`, `second`).
// Arithmetic.
895 GKO_DEFINE_SIMPLE_BINARY_OPERATION(add, first + second);
896 GKO_DEFINE_SIMPLE_BINARY_OPERATION(sub, first - second);
897 GKO_DEFINE_SIMPLE_BINARY_OPERATION(mul, first* second);
898 GKO_DEFINE_SIMPLE_BINARY_OPERATION(div, first / second);
899 GKO_DEFINE_SIMPLE_BINARY_OPERATION(mod, first % second);
// Comparison.
902 GKO_DEFINE_SIMPLE_BINARY_OPERATION(less, first < second);
903 GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater, first > second);
904 GKO_DEFINE_SIMPLE_BINARY_OPERATION(less_or_equal, first <= second);
905 GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater_or_equal, first >= second);
906 GKO_DEFINE_SIMPLE_BINARY_OPERATION(equal, first == second);
907 GKO_DEFINE_SIMPLE_BINARY_OPERATION(not_equal, first != second);
// Logical.
910 GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_or, first || second);
911 GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_and, first&& second);
// Bitwise and shifts.
914 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_or, first | second);
915 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_and, first& second);
916 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_xor, first ^ second);
917 GKO_DEFINE_SIMPLE_BINARY_OPERATION(left_shift, first << second);
918 GKO_DEFINE_SIMPLE_BINARY_OPERATION(right_shift, first >> second);
// Element-wise maximum / minimum, expressed through max() and min().
921 GKO_DEFINE_SIMPLE_BINARY_OPERATION(max_operation,
max(first, second));
922 GKO_DEFINE_SIMPLE_BINARY_OPERATION(min_operation,
min(first, second));
// Deprecated aliases that keep the previous, misspelled identifiers
// ("operaton") available; they derive from the correctly spelled operations.
924 GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(max_operaton, max_operation);
925 GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(min_operaton, min_operation);
// Exposes the simple binary operations as overloaded operators on ranges.
// Arguments: accessor struct name, operator to define, operation to apply.
// Arithmetic operators.
931 GKO_ENABLE_BINARY_RANGE_OPERATION(
add,
operator+, accessor::detail::add);
932 GKO_ENABLE_BINARY_RANGE_OPERATION(
sub,
operator-, accessor::detail::sub);
933 GKO_ENABLE_BINARY_RANGE_OPERATION(
mul,
operator*, accessor::detail::mul);
934 GKO_ENABLE_BINARY_RANGE_OPERATION(
div,
operator/, accessor::detail::div);
935 GKO_ENABLE_BINARY_RANGE_OPERATION(
mod,
operator%, accessor::detail::mod);
// Comparison operators.
938 GKO_ENABLE_BINARY_RANGE_OPERATION(
less,
operator<, accessor::detail::less);
939 GKO_ENABLE_BINARY_RANGE_OPERATION(
greater,
operator>,
940 accessor::detail::greater);
942 accessor::detail::less_or_equal);
944 accessor::detail::greater_or_equal);
945 GKO_ENABLE_BINARY_RANGE_OPERATION(
equal,
operator==, accessor::detail::equal);
946 GKO_ENABLE_BINARY_RANGE_OPERATION(
not_equal,
operator!=,
947 accessor::detail::not_equal);
// Logical operators.
950 GKO_ENABLE_BINARY_RANGE_OPERATION(
logical_or,
operator||,
951 accessor::detail::logical_or);
952 GKO_ENABLE_BINARY_RANGE_OPERATION(
logical_and,
operator&&,
953 accessor::detail::logical_and);
// Bitwise and shift operators.
956 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_or,
operator|,
957 accessor::detail::bitwise_or);
958 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_and,
operator&,
959 accessor::detail::bitwise_and);
960 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_xor,
operator^,
961 accessor::detail::bitwise_xor);
962 GKO_ENABLE_BINARY_RANGE_OPERATION(
left_shift,
operator<<,
963 accessor::detail::left_shift);
964 GKO_ENABLE_BINARY_RANGE_OPERATION(
right_shift,
operator>>,
965 accessor::detail::right_shift);
969 accessor::detail::max_operation);
971 accessor::detail::min_operation);
// Accessor adaptor for the matrix product of two accessors. Only the
// range_by_range kind is accepted, and both accessors must have the same
// dimensionality (both enforced with static_assert). A copy of each operand
// is stored.
978 template <gko::detail::operation_kind Kind,
typename FirstAccessor,
979 typename SecondAccessor>
981 static_assert(Kind == gko::detail::operation_kind::range_by_range,
982 "Matrix multiplication expects both operands to be ranges");
983 using first_accessor = FirstAccessor;
984 using second_accessor = SecondAccessor;
985 static_assert(first_accessor::dimensionality ==
986 second_accessor::dimensionality,
987 "Both ranges need to have the same number of dimensions");
988 static constexpr
size_type dimensionality = first_accessor::dimensionality;
// The constructor asserts that the inner dimensions agree (length 1 of the
// first operand equals length 0 of the second) and that all dimensions from
// index 2 onwards have equal lengths.
990 GKO_ATTRIBUTES
explicit mmul_operation(
const FirstAccessor& first,
991 const SecondAccessor& second)
992 : first{first}, second{second}
994 GKO_ASSERT(first.length(1) == second.length(0));
995 GKO_ASSERT(gko::detail::equal_dimensions<2>(first, second));
// Element access for (row, col, rest...): accumulates
// first(row, i, rest...) * second(i, col, rest...) over all i in
// [0, first.length(1)). The return type is the type of a sum of two such
// products.
998 template <
typename FirstDimension,
typename SecondDimension,
999 typename... DimensionTypes>
1000 GKO_ATTRIBUTES
auto operator()(
const FirstDimension& row,
1001 const SecondDimension& col,
1002 const DimensionTypes&... rest)
const
1003 -> decltype(std::declval<FirstAccessor>()(row, 0, rest...) *
1004 std::declval<SecondAccessor>()(0, col, rest...) +
1005 std::declval<FirstAccessor>()(row, 1, rest...) *
1006 std::declval<SecondAccessor>()(1, col, rest...))
1009 decltype(first(row, 0, rest...) * second(0, col, rest...) +
1010 first(row, 1, rest...) * second(1, col, rest...));
1011 GKO_ASSERT(first.length(1) == second.length(0));
1012 auto result = zero<result_type>();
1013 const auto size = first.length(1);
1014 for (
auto i =
zero(size); i < size; ++i) {
1015 result += first(row, i, rest...) * second(i, col, rest...);
// Length of the product: dimension 1 comes from the second operand, every
// other dimension from the first operand.
1022 return dimension == 1 ? second.length(1) : first.length(dimension);
// copy_from is explicitly deleted for this adaptor.
1025 template <
typename OtherAccessor>
1026 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
// The two operands, held by value and immutable.
1028 const first_accessor first;
1029 const second_accessor second;
// Defines the overloads of the free function mmul(...) that wrap their
// operands in accessor::mmul_operation.
1036 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(mmul_operation, mmul);
1039 #undef GKO_DEFINE_SIMPLE_BINARY_OPERATION
1040 #undef GKO_ENABLE_BINARY_RANGE_OPERATION
1046 #endif // GKO_PUBLIC_CORE_BASE_RANGE_HPP_