5 #ifndef GKO_PUBLIC_CORE_BASE_RANGE_HPP_
6 #define GKO_PUBLIC_CORE_BASE_RANGE_HPP_
11 #include <ginkgo/core/base/math.hpp>
12 #include <ginkgo/core/base/types.hpp>
13 #include <ginkgo/core/base/utils.hpp>
55 :
span{point, point + 1}
94 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator<(
const span& first,
101 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator<=(
const span& first,
104 return first.end <= second.begin;
108 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator>(
const span& first,
111 return second < first;
115 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator>=(
const span& first,
118 return second <= first;
122 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator==(
const span& first,
125 return first.begin == second.begin && first.end == second.end;
129 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator!=(
const span& first,
132 return !(first == second);
139 template <
size_type CurrentDimension = 0,
typename FirstRange,
140 typename SecondRange>
141 GKO_ATTRIBUTES constexpr GKO_INLINE
142 std::enable_if_t<(CurrentDimension >=
max(FirstRange::dimensionality,
143 SecondRange::dimensionality)),
145 equal_dimensions(
const FirstRange&,
const SecondRange&)
150 template <
size_type CurrentDimension = 0,
typename FirstRange,
151 typename SecondRange>
152 GKO_ATTRIBUTES constexpr GKO_INLINE
153 std::enable_if_t<(CurrentDimension <
max(FirstRange::dimensionality,
154 SecondRange::dimensionality)),
156 equal_dimensions(
const FirstRange& first,
const SecondRange& second)
158 return first.length(CurrentDimension) == second.length(CurrentDimension) &&
159 equal_dimensions<CurrentDimension + 1>(first, second);
172 template <
class First,
class... Rest>
173 struct head<First, Rest...> {
// Shorthand for `head<T...>::type`. It is used by range's variadic
// constructor to inspect the type of its first constructor argument.
// NOTE(review): presumably yields the first type of the pack, given the
// `head<First, Rest...>` specialization (its body is not visible here).
180 template <
class... T>
181 using head_t =
typename head<T...>::type;
296 template <
typename Accessor>
323 typename... AccessorParams,
324 typename = std::enable_if_t<
325 sizeof...(AccessorParams) != 1 ||
327 range, std::decay<detail::head_t<AccessorParams...>>>::value>>
328 GKO_ATTRIBUTES constexpr
explicit range(AccessorParams&&... params)
329 : accessor_{std::forward<AccessorParams>(params)...}
344 template <
typename... DimensionTypes>
345 GKO_ATTRIBUTES constexpr
auto operator()(DimensionTypes&&... dimensions)
346 const -> decltype(std::declval<accessor>()(
347 std::forward<DimensionTypes>(dimensions)...))
350 "Too many dimensions in range call");
351 return accessor_(std::forward<DimensionTypes>(dimensions)...);
362 template <
typename OtherAccessor>
366 GKO_ASSERT(detail::equal_dimensions(*
this, other));
367 accessor_.copy_from(other);
386 GKO_ASSERT(detail::equal_dimensions(*
this, other));
402 return accessor_.length(dimension);
// Selects which implement_binary_operation specialization handles a binary
// range expression:
//   range_by_range  - both operands are accessors (ranges);
//   scalar_by_range - the first operand is a plain value, the second a range;
//   range_by_scalar - the first operand is a range, the second a plain value.
440 enum class operation_kind { range_by_range, scalar_by_range, range_by_scalar };
443 template <
typename Accessor,
typename Operation>
444 struct implement_unary_operation {
445 using accessor = Accessor;
446 static constexpr
size_type dimensionality = accessor::dimensionality;
448 GKO_ATTRIBUTES constexpr
explicit implement_unary_operation(
449 const Accessor& operand)
453 template <
typename... DimensionTypes>
454 GKO_ATTRIBUTES constexpr
auto operator()(
455 const DimensionTypes&... dimensions)
const
456 -> decltype(Operation::evaluate(std::declval<accessor>(),
459 return Operation::evaluate(operand, dimensions...);
464 return operand.length(dimension);
467 template <
typename OtherAccessor>
468 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
470 const accessor operand;
474 template <operation_kind Kind,
typename FirstOperand,
typename SecondOperand,
476 struct implement_binary_operation {};
478 template <
typename FirstAccessor,
typename SecondAccessor,
typename Operation>
479 struct implement_binary_operation<operation_kind::range_by_range, FirstAccessor,
480 SecondAccessor, Operation> {
481 using first_accessor = FirstAccessor;
482 using second_accessor = SecondAccessor;
483 static_assert(first_accessor::dimensionality ==
484 second_accessor::dimensionality,
485 "Both ranges need to have the same number of dimensions");
486 static constexpr
size_type dimensionality = first_accessor::dimensionality;
488 GKO_ATTRIBUTES
explicit implement_binary_operation(
489 const FirstAccessor& first,
const SecondAccessor& second)
490 : first{first}, second{second}
492 GKO_ASSERT(gko::detail::equal_dimensions(first, second));
495 template <
typename... DimensionTypes>
496 GKO_ATTRIBUTES constexpr
auto operator()(
497 const DimensionTypes&... dimensions)
const
498 -> decltype(Operation::evaluate_range_by_range(
499 std::declval<first_accessor>(), std::declval<second_accessor>(),
502 return Operation::evaluate_range_by_range(first, second, dimensions...);
507 return first.length(dimension);
510 template <
typename OtherAccessor>
511 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
513 const first_accessor first;
514 const second_accessor second;
517 template <
typename FirstOperand,
typename SecondAccessor,
typename Operation>
518 struct implement_binary_operation<operation_kind::scalar_by_range, FirstOperand,
519 SecondAccessor, Operation> {
520 using second_accessor = SecondAccessor;
521 static constexpr
size_type dimensionality = second_accessor::dimensionality;
523 GKO_ATTRIBUTES constexpr
explicit implement_binary_operation(
524 const FirstOperand& first,
const SecondAccessor& second)
525 : first{first}, second{second}
528 template <
typename... DimensionTypes>
529 GKO_ATTRIBUTES constexpr
auto operator()(
530 const DimensionTypes&... dimensions)
const
531 -> decltype(Operation::evaluate_scalar_by_range(
532 std::declval<FirstOperand>(), std::declval<second_accessor>(),
535 return Operation::evaluate_scalar_by_range(first, second,
541 return second.length(dimension);
544 template <
typename OtherAccessor>
545 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
547 const FirstOperand first;
548 const second_accessor second;
551 template <
typename FirstAccessor,
typename SecondOperand,
typename Operation>
552 struct implement_binary_operation<operation_kind::range_by_scalar,
553 FirstAccessor, SecondOperand, Operation> {
554 using first_accessor = FirstAccessor;
555 static constexpr
size_type dimensionality = first_accessor::dimensionality;
557 GKO_ATTRIBUTES constexpr
explicit implement_binary_operation(
558 const FirstAccessor& first,
const SecondOperand& second)
559 : first{first}, second{second}
562 template <
typename... DimensionTypes>
563 GKO_ATTRIBUTES constexpr
auto operator()(
564 const DimensionTypes&... dimensions)
const
565 -> decltype(Operation::evaluate_range_by_scalar(
566 std::declval<first_accessor>(), std::declval<SecondOperand>(),
569 return Operation::evaluate_range_by_scalar(first, second,
575 return first.length(dimension);
578 template <
typename OtherAccessor>
579 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
581 const first_accessor first;
582 const SecondOperand second;
588 #define GKO_DEPRECATED_UNARY_RANGE_OPERATION(_operation_deprecated_name, \
590 namespace accessor { \
591 template <typename Operand> \
592 struct GKO_DEPRECATED("Please use " #_operation_name) \
593 _operation_deprecated_name : _operation_name<Operand> {}; \
595 static_assert(true, \
596 "This assert is used to counter the false positive extra " \
597 "semi-colon warnings")
600 #define GKO_ENABLE_UNARY_RANGE_OPERATION(_operation_name, _operator_name, \
602 namespace accessor { \
603 template <typename Operand> \
604 struct _operation_name \
605 : ::gko::detail::implement_unary_operation<Operand, \
606 ::gko::_operator> { \
607 using ::gko::detail::implement_unary_operation< \
608 Operand, ::gko::_operator>::implement_unary_operation; \
611 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name)
614 #define GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, \
616 template <typename Accessor> \
617 GKO_ATTRIBUTES constexpr GKO_INLINE \
618 range<accessor::_operation_name<Accessor>> \
619 _operator_name(const range<Accessor>& operand) \
621 return range<accessor::_operation_name<Accessor>>( \
622 operand.get_accessor()); \
624 static_assert(true, \
625 "This assert is used to counter the false positive extra " \
626 "semi-colon warnings")
629 #define GKO_DEFINE_SIMPLE_UNARY_OPERATION(_name, ...) \
632 template <typename Operand> \
633 GKO_ATTRIBUTES static constexpr auto simple_evaluate_impl( \
634 const Operand& operand) -> decltype(__VA_ARGS__) \
636 return __VA_ARGS__; \
640 template <typename AccessorType, typename... DimensionTypes> \
641 GKO_ATTRIBUTES static constexpr auto evaluate( \
642 const AccessorType& accessor, const DimensionTypes&... dimensions) \
643 -> decltype(simple_evaluate_impl(accessor(dimensions...))) \
645 return simple_evaluate_impl(accessor(dimensions...)); \
// Element-wise unary operations. Each invocation defines an operation type
// (named by the first argument) whose static `evaluate(accessor, dims...)`
// reads the element `accessor(dims...)` and applies the given expression to
// it. Inside the expression, that element is referred to as `operand`.
655 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
unary_plus, +operand);
656 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
unary_minus, -operand);
659 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
logical_not, !operand);
662 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
bitwise_not, ~(operand));
// Expose the element-wise unary operations from accessor::detail as range
// operators. For example, `-r` on a `range<A>` returns a
// `range<accessor::unary_minus<A>>`. That range stores the operand accessor
// and applies the operation only when an element is accessed.
679 GKO_ENABLE_UNARY_RANGE_OPERATION(unary_plus,
operator+,
680 accessor::detail::unary_plus);
681 GKO_ENABLE_UNARY_RANGE_OPERATION(
unary_minus,
operator-,
682 accessor::detail::unary_minus);
685 GKO_ENABLE_UNARY_RANGE_OPERATION(
logical_not,
operator!,
686 accessor::detail::logical_not);
689 GKO_ENABLE_UNARY_RANGE_OPERATION(
bitwise_not,
operator~,
690 accessor::detail::bitwise_not);
695 accessor::detail::zero_operation);
697 accessor::detail::one_operation);
699 accessor::detail::abs_operation);
701 accessor::detail::real_operation);
703 accessor::detail::imag_operation);
705 accessor::detail::conj_operation);
707 accessor::detail::squared_norm_operation);
720 template <
typename Accessor>
722 using accessor = Accessor;
723 static constexpr
size_type dimensionality = accessor::dimensionality;
726 const Accessor& operand)
730 template <
typename FirstDimensionType,
typename SecondDimensionType,
731 typename... DimensionTypes>
732 GKO_ATTRIBUTES constexpr
auto operator()(
733 const FirstDimensionType& first_dim,
734 const SecondDimensionType& second_dim,
735 const DimensionTypes&... dims)
const
736 -> decltype(std::declval<accessor>()(second_dim, first_dim, dims...))
738 return operand(second_dim, first_dim, dims...);
743 return dimension < 2 ? operand.length(dimension ^ 1)
744 : operand.length(dimension);
747 template <
typename OtherAccessor>
748 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
750 const accessor operand;
// Defines `transpose(r)` for any `range<Accessor>`. It returns a range over
// `accessor::transpose_operation<Accessor>`, which swaps the first two
// indices (and the first two lengths) on access.
757 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(transpose_operation,
transpose);
760 #undef GKO_DEPRECATED_UNARY_RANGE_OPERATION
761 #undef GKO_DEFINE_SIMPLE_UNARY_OPERATION
762 #undef GKO_ENABLE_UNARY_RANGE_OPERATION
765 #define GKO_ENABLE_BINARY_RANGE_OPERATION(_operation_name, _operator_name, \
767 namespace accessor { \
768 template <::gko::detail::operation_kind Kind, typename FirstOperand, \
769 typename SecondOperand> \
770 struct _operation_name \
771 : ::gko::detail::implement_binary_operation< \
772 Kind, FirstOperand, SecondOperand, ::gko::_operator> { \
773 using ::gko::detail::implement_binary_operation< \
774 Kind, FirstOperand, SecondOperand, \
775 ::gko::_operator>::implement_binary_operation; \
778 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name); \
779 static_assert(true, \
780 "This assert is used to counter the false positive extra " \
781 "semi-colon warnings")
784 #define GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name) \
785 template <typename Accessor> \
786 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
787 ::gko::detail::operation_kind::range_by_range, Accessor, Accessor>> \
788 _operator_name(const range<Accessor>& first, \
789 const range<Accessor>& second) \
791 return range<accessor::_operation_name< \
792 ::gko::detail::operation_kind::range_by_range, Accessor, \
793 Accessor>>(first.get_accessor(), second.get_accessor()); \
796 template <typename FirstAccessor, typename SecondAccessor> \
797 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
798 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
800 _operator_name(const range<FirstAccessor>& first, \
801 const range<SecondAccessor>& second) \
803 return range<accessor::_operation_name< \
804 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
805 SecondAccessor>>(first.get_accessor(), second.get_accessor()); \
808 template <typename FirstAccessor, typename SecondOperand> \
809 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
810 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
812 _operator_name(const range<FirstAccessor>& first, \
813 const SecondOperand& second) \
815 return range<accessor::_operation_name< \
816 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
817 SecondOperand>>(first.get_accessor(), second); \
820 template <typename FirstOperand, typename SecondAccessor> \
821 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
822 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
824 _operator_name(const FirstOperand& first, \
825 const range<SecondAccessor>& second) \
827 return range<accessor::_operation_name< \
828 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
829 SecondAccessor>>(first, second.get_accessor()); \
831 static_assert(true, \
832 "This assert is used to counter the false positive extra " \
833 "semi-colon warnings")
836 #define GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(_deprecated_name, _name) \
837 struct GKO_DEPRECATED("Please use " #_name) _deprecated_name : _name {}
839 #define GKO_DEFINE_SIMPLE_BINARY_OPERATION(_name, ...) \
842 template <typename FirstOperand, typename SecondOperand> \
843 GKO_ATTRIBUTES constexpr static auto simple_evaluate_impl( \
844 const FirstOperand& first, const SecondOperand& second) \
845 -> decltype(__VA_ARGS__) \
847 return __VA_ARGS__; \
851 template <typename FirstAccessor, typename SecondAccessor, \
852 typename... DimensionTypes> \
853 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_range( \
854 const FirstAccessor& first, const SecondAccessor& second, \
855 const DimensionTypes&... dims) \
856 -> decltype(simple_evaluate_impl(first(dims...), second(dims...))) \
858 return simple_evaluate_impl(first(dims...), second(dims...)); \
861 template <typename FirstOperand, typename SecondAccessor, \
862 typename... DimensionTypes> \
863 GKO_ATTRIBUTES static constexpr auto evaluate_scalar_by_range( \
864 const FirstOperand& first, const SecondAccessor& second, \
865 const DimensionTypes&... dims) \
866 -> decltype(simple_evaluate_impl(first, second(dims...))) \
868 return simple_evaluate_impl(first, second(dims...)); \
871 template <typename FirstAccessor, typename SecondOperand, \
872 typename... DimensionTypes> \
873 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_scalar( \
874 const FirstAccessor& first, const SecondOperand& second, \
875 const DimensionTypes&... dims) \
876 -> decltype(simple_evaluate_impl(first(dims...), second)) \
878 return simple_evaluate_impl(first(dims...), second); \
// Element-wise binary operations. Each invocation defines an operation type
// with static evaluate_range_by_range, evaluate_scalar_by_range and
// evaluate_range_by_scalar functions. These apply the given expression to
// `first` and `second`. Each of those is either the element read at the
// current indices or the scalar operand itself.
888 GKO_DEFINE_SIMPLE_BINARY_OPERATION(add, first + second);
889 GKO_DEFINE_SIMPLE_BINARY_OPERATION(sub, first - second);
890 GKO_DEFINE_SIMPLE_BINARY_OPERATION(mul, first* second);
891 GKO_DEFINE_SIMPLE_BINARY_OPERATION(div, first / second);
892 GKO_DEFINE_SIMPLE_BINARY_OPERATION(mod, first % second);
895 GKO_DEFINE_SIMPLE_BINARY_OPERATION(less, first < second);
896 GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater, first > second);
897 GKO_DEFINE_SIMPLE_BINARY_OPERATION(less_or_equal, first <= second);
898 GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater_or_equal, first >= second);
899 GKO_DEFINE_SIMPLE_BINARY_OPERATION(equal, first == second);
900 GKO_DEFINE_SIMPLE_BINARY_OPERATION(not_equal, first != second);
903 GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_or, first || second);
904 GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_and, first&& second);
907 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_or, first | second);
908 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_and, first& second);
909 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_xor, first ^ second);
910 GKO_DEFINE_SIMPLE_BINARY_OPERATION(left_shift, first << second);
911 GKO_DEFINE_SIMPLE_BINARY_OPERATION(right_shift, first >> second);
// NOTE(review): unqualified max/min calls; presumably these resolve to the
// gko overloads from math.hpp (or via ADL for the element type). Confirm.
914 GKO_DEFINE_SIMPLE_BINARY_OPERATION(max_operation,
max(first, second));
915 GKO_DEFINE_SIMPLE_BINARY_OPERATION(min_operation,
min(first, second));
// Deprecated aliases that keep the historical misspellings `max_operaton` and
// `min_operaton` compiling. Each derives from the correctly spelled operation
// and carries a GKO_DEPRECATED("Please use ...") message that points to it.
917 GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(max_operaton, max_operation);
918 GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(min_operaton, min_operation);
// Expose the element-wise binary operations as range operators. For each
// operator, overloads are generated for three operand combinations:
// range-by-range, range-by-scalar and scalar-by-range. Range-by-range
// operands must have the same dimensionality (static_assert) and equal
// lengths (GKO_ASSERT at construction).
924 GKO_ENABLE_BINARY_RANGE_OPERATION(
add,
operator+, accessor::detail::add);
925 GKO_ENABLE_BINARY_RANGE_OPERATION(
sub,
operator-, accessor::detail::sub);
926 GKO_ENABLE_BINARY_RANGE_OPERATION(
mul,
operator*, accessor::detail::mul);
927 GKO_ENABLE_BINARY_RANGE_OPERATION(
div,
operator/, accessor::detail::div);
928 GKO_ENABLE_BINARY_RANGE_OPERATION(
mod,
operator%, accessor::detail::mod);
// Comparison operators on ranges return a range of element-wise comparison
// results, not a single bool.
931 GKO_ENABLE_BINARY_RANGE_OPERATION(
less,
operator<, accessor::detail::less);
932 GKO_ENABLE_BINARY_RANGE_OPERATION(
greater,
operator>,
933 accessor::detail::greater);
935 accessor::detail::less_or_equal);
937 accessor::detail::greater_or_equal);
// Element-wise equality operators: `r1 == r2` yields a range of per-element
// comparison results, not a single bool.
938 GKO_ENABLE_BINARY_RANGE_OPERATION(
equal,
operator==, accessor::detail::equal);
939 GKO_ENABLE_BINARY_RANGE_OPERATION(
not_equal,
operator!=,
940 accessor::detail::not_equal);
// Element-wise logical operators. There is no short-circuiting, even per
// element: both elements are read and passed to simple_evaluate_impl before
// `||` / `&&` is applied.
943 GKO_ENABLE_BINARY_RANGE_OPERATION(
logical_or,
operator||,
944 accessor::detail::logical_or);
945 GKO_ENABLE_BINARY_RANGE_OPERATION(
logical_and,
operator&&,
946 accessor::detail::logical_and);
// Element-wise bitwise and shift operators.
// NOTE(review): the scalar overloads generated by
// GKO_BIND_RANGE_OPERATION_TO_OPERATOR are unconstrained templates.
// `operator<<` may therefore also match stream expressions such as `os << r`
// found via ADL. Confirm this is intended.
949 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_or,
operator|,
950 accessor::detail::bitwise_or);
951 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_and,
operator&,
952 accessor::detail::bitwise_and);
953 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_xor,
operator^,
954 accessor::detail::bitwise_xor);
955 GKO_ENABLE_BINARY_RANGE_OPERATION(
left_shift,
operator<<,
956 accessor::detail::left_shift);
957 GKO_ENABLE_BINARY_RANGE_OPERATION(
right_shift,
operator>>,
958 accessor::detail::right_shift);
962 accessor::detail::max_operation);
964 accessor::detail::min_operation);
971 template <gko::detail::operation_kind Kind,
typename FirstAccessor,
972 typename SecondAccessor>
974 static_assert(Kind == gko::detail::operation_kind::range_by_range,
975 "Matrix multiplication expects both operands to be ranges");
976 using first_accessor = FirstAccessor;
977 using second_accessor = SecondAccessor;
978 static_assert(first_accessor::dimensionality ==
979 second_accessor::dimensionality,
980 "Both ranges need to have the same number of dimensions");
981 static constexpr
size_type dimensionality = first_accessor::dimensionality;
983 GKO_ATTRIBUTES
explicit mmul_operation(
const FirstAccessor& first,
984 const SecondAccessor& second)
985 : first{first}, second{second}
987 GKO_ASSERT(first.length(1) == second.length(0));
988 GKO_ASSERT(gko::detail::equal_dimensions<2>(first, second));
991 template <
typename FirstDimension,
typename SecondDimension,
992 typename... DimensionTypes>
993 GKO_ATTRIBUTES
auto operator()(
const FirstDimension& row,
994 const SecondDimension& col,
995 const DimensionTypes&... rest)
const
996 -> decltype(std::declval<FirstAccessor>()(row, 0, rest...) *
997 std::declval<SecondAccessor>()(0, col, rest...) +
998 std::declval<FirstAccessor>()(row, 1, rest...) *
999 std::declval<SecondAccessor>()(1, col, rest...))
1002 decltype(first(row, 0, rest...) * second(0, col, rest...) +
1003 first(row, 1, rest...) * second(1, col, rest...));
1004 GKO_ASSERT(first.length(1) == second.length(0));
1005 auto result = zero<result_type>();
1006 const auto size = first.length(1);
1007 for (
auto i =
zero(size); i < size; ++i) {
1008 result += first(row, i, rest...) * second(i, col, rest...);
1015 return dimension == 1 ? second.length(1) : first.length(dimension);
1018 template <
typename OtherAccessor>
1019 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
1021 const first_accessor first;
1022 const second_accessor second;
// Defines `mmul(a, b)`, the matrix product of two ranges: element (row, col)
// is the sum over i of a(row, i, ...) * b(i, col, ...). The macro also
// generates scalar-by-range and range-by-scalar overloads. Those fail to
// compile if used, because mmul_operation static_asserts that both operands
// are ranges.
1029 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(mmul_operation, mmul);
1032 #undef GKO_DEFINE_SIMPLE_BINARY_OPERATION
1033 #undef GKO_ENABLE_BINARY_RANGE_OPERATION
1039 #endif // GKO_PUBLIC_CORE_BASE_RANGE_HPP_