5 #ifndef GKO_PUBLIC_CORE_BASE_RANGE_HPP_
6 #define GKO_PUBLIC_CORE_BASE_RANGE_HPP_
12 #include <ginkgo/core/base/math.hpp>
13 #include <ginkgo/core/base/types.hpp>
14 #include <ginkgo/core/base/utils.hpp>
56 :
span{point, point + 1}
95 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator<(
const span& first,
102 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator<=(
const span& first,
105 return first.end <= second.begin;
109 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator>(
const span& first,
112 return second < first;
116 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator>=(
const span& first,
119 return second <= first;
123 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator==(
const span& first,
126 return first.begin == second.begin && first.end == second.end;
130 GKO_ATTRIBUTES GKO_INLINE constexpr
bool operator!=(
const span& first,
133 return !(first == second);
140 template <
size_type CurrentDimension = 0,
typename FirstRange,
141 typename SecondRange>
142 GKO_ATTRIBUTES constexpr GKO_INLINE
143 std::enable_if_t<(CurrentDimension >=
max(FirstRange::dimensionality,
144 SecondRange::dimensionality)),
146 equal_dimensions(
const FirstRange&,
const SecondRange&)
151 template <
size_type CurrentDimension = 0,
typename FirstRange,
152 typename SecondRange>
153 GKO_ATTRIBUTES constexpr GKO_INLINE
154 std::enable_if_t<(CurrentDimension <
max(FirstRange::dimensionality,
155 SecondRange::dimensionality)),
157 equal_dimensions(
const FirstRange& first,
const SecondRange& second)
159 return first.length(CurrentDimension) == second.length(CurrentDimension) &&
160 equal_dimensions<CurrentDimension + 1>(first, second);
173 template <
class First,
class... Rest>
174 struct head<First, Rest...> {
181 template <
class... T>
182 using head_t =
typename head<T...>::type;
297 template <
typename Accessor>
324 typename... AccessorParams,
325 typename = std::enable_if_t<
326 sizeof...(AccessorParams) != 1 ||
328 range, std::decay<detail::head_t<AccessorParams...>>>::value>>
329 GKO_ATTRIBUTES constexpr
explicit range(AccessorParams&&... params)
330 : accessor_{std::forward<AccessorParams>(params)...}
345 template <
typename... DimensionTypes>
346 GKO_ATTRIBUTES constexpr
auto operator()(DimensionTypes&&... dimensions)
347 const -> decltype(std::declval<accessor>()(
348 std::forward<DimensionTypes>(dimensions)...))
351 "Too many dimensions in range call");
352 return accessor_(std::forward<DimensionTypes>(dimensions)...);
363 template <
typename OtherAccessor>
367 GKO_ASSERT(detail::equal_dimensions(*
this, other));
368 accessor_.copy_from(other);
387 GKO_ASSERT(detail::equal_dimensions(*
this, other));
403 return accessor_.length(dimension);
/**
 * Tells a binary range operation which of its two operands are ranges
 * (accessed element-wise) and which are plain scalar values.
 */
enum class operation_kind {
    /** both operands are ranges */
    range_by_range,
    /** the first operand is a scalar, the second one is a range */
    scalar_by_range,
    /** the first operand is a range, the second one is a scalar */
    range_by_scalar
};
444 template <
typename Accessor,
typename Operation>
445 struct implement_unary_operation {
446 using accessor = Accessor;
447 static constexpr
size_type dimensionality = accessor::dimensionality;
449 GKO_ATTRIBUTES constexpr
explicit implement_unary_operation(
450 const Accessor& operand)
454 template <
typename... DimensionTypes>
455 GKO_ATTRIBUTES constexpr
auto operator()(
456 const DimensionTypes&... dimensions)
const
457 -> decltype(Operation::evaluate(std::declval<accessor>(),
460 return Operation::evaluate(operand, dimensions...);
465 return operand.length(dimension);
468 template <
typename OtherAccessor>
469 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
471 const accessor operand;
475 template <operation_kind Kind,
typename FirstOperand,
typename SecondOperand,
477 struct implement_binary_operation {};
479 template <
typename FirstAccessor,
typename SecondAccessor,
typename Operation>
480 struct implement_binary_operation<operation_kind::range_by_range, FirstAccessor,
481 SecondAccessor, Operation> {
482 using first_accessor = FirstAccessor;
483 using second_accessor = SecondAccessor;
484 static_assert(first_accessor::dimensionality ==
485 second_accessor::dimensionality,
486 "Both ranges need to have the same number of dimensions");
487 static constexpr
size_type dimensionality = first_accessor::dimensionality;
489 GKO_ATTRIBUTES
explicit implement_binary_operation(
490 const FirstAccessor& first,
const SecondAccessor& second)
491 : first{first}, second{second}
493 GKO_ASSERT(gko::detail::equal_dimensions(first, second));
496 template <
typename... DimensionTypes>
497 GKO_ATTRIBUTES constexpr
auto operator()(
498 const DimensionTypes&... dimensions)
const
499 -> decltype(Operation::evaluate_range_by_range(
500 std::declval<first_accessor>(), std::declval<second_accessor>(),
503 return Operation::evaluate_range_by_range(first, second, dimensions...);
508 return first.length(dimension);
511 template <
typename OtherAccessor>
512 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
514 const first_accessor first;
515 const second_accessor second;
518 template <
typename FirstOperand,
typename SecondAccessor,
typename Operation>
519 struct implement_binary_operation<operation_kind::scalar_by_range, FirstOperand,
520 SecondAccessor, Operation> {
521 using second_accessor = SecondAccessor;
522 static constexpr
size_type dimensionality = second_accessor::dimensionality;
524 GKO_ATTRIBUTES constexpr
explicit implement_binary_operation(
525 const FirstOperand& first,
const SecondAccessor& second)
526 : first{first}, second{second}
529 template <
typename... DimensionTypes>
530 GKO_ATTRIBUTES constexpr
auto operator()(
531 const DimensionTypes&... dimensions)
const
532 -> decltype(Operation::evaluate_scalar_by_range(
533 std::declval<FirstOperand>(), std::declval<second_accessor>(),
536 return Operation::evaluate_scalar_by_range(first, second,
542 return second.length(dimension);
545 template <
typename OtherAccessor>
546 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
548 const FirstOperand first;
549 const second_accessor second;
552 template <
typename FirstAccessor,
typename SecondOperand,
typename Operation>
553 struct implement_binary_operation<operation_kind::range_by_scalar,
554 FirstAccessor, SecondOperand, Operation> {
555 using first_accessor = FirstAccessor;
556 static constexpr
size_type dimensionality = first_accessor::dimensionality;
558 GKO_ATTRIBUTES constexpr
explicit implement_binary_operation(
559 const FirstAccessor& first,
const SecondOperand& second)
560 : first{first}, second{second}
563 template <
typename... DimensionTypes>
564 GKO_ATTRIBUTES constexpr
auto operator()(
565 const DimensionTypes&... dimensions)
const
566 -> decltype(Operation::evaluate_range_by_scalar(
567 std::declval<first_accessor>(), std::declval<SecondOperand>(),
570 return Operation::evaluate_range_by_scalar(first, second,
576 return first.length(dimension);
579 template <
typename OtherAccessor>
580 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
582 const first_accessor first;
583 const SecondOperand second;
589 #define GKO_DEPRECATED_UNARY_RANGE_OPERATION(_operation_deprecated_name, \
591 namespace accessor { \
592 template <typename Operand> \
593 struct GKO_DEPRECATED("Please use " #_operation_name) \
594 _operation_deprecated_name : _operation_name<Operand> {}; \
596 static_assert(true, \
597 "This assert is used to counter the false positive extra " \
598 "semi-colon warnings")
601 #define GKO_ENABLE_UNARY_RANGE_OPERATION(_operation_name, _operator_name, \
603 namespace accessor { \
604 template <typename Operand> \
605 struct _operation_name \
606 : ::gko::detail::implement_unary_operation<Operand, \
607 ::gko::_operator> { \
608 using ::gko::detail::implement_unary_operation< \
609 Operand, ::gko::_operator>::implement_unary_operation; \
612 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name)
615 #define GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, \
617 template <typename Accessor> \
618 GKO_ATTRIBUTES constexpr GKO_INLINE \
619 range<accessor::_operation_name<Accessor>> \
620 _operator_name(const range<Accessor>& operand) \
622 return range<accessor::_operation_name<Accessor>>( \
623 operand.get_accessor()); \
625 static_assert(true, \
626 "This assert is used to counter the false positive extra " \
627 "semi-colon warnings")
630 #define GKO_DEFINE_SIMPLE_UNARY_OPERATION(_name, ...) \
633 template <typename Operand> \
634 GKO_ATTRIBUTES static constexpr auto simple_evaluate_impl( \
635 const Operand& operand) -> decltype(__VA_ARGS__) \
637 return __VA_ARGS__; \
641 template <typename AccessorType, typename... DimensionTypes> \
642 GKO_ATTRIBUTES static constexpr auto evaluate( \
643 const AccessorType& accessor, const DimensionTypes&... dimensions) \
644 -> decltype(simple_evaluate_impl(accessor(dimensions...))) \
646 return simple_evaluate_impl(accessor(dimensions...)); \
656 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
unary_plus, +operand);
657 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
unary_minus, -operand);
660 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
logical_not, !operand);
663 GKO_DEFINE_SIMPLE_UNARY_OPERATION(
bitwise_not, ~(operand));
680 GKO_ENABLE_UNARY_RANGE_OPERATION(unary_plus,
operator+,
681 accessor::detail::unary_plus);
682 GKO_ENABLE_UNARY_RANGE_OPERATION(
unary_minus,
operator-,
683 accessor::detail::unary_minus);
686 GKO_ENABLE_UNARY_RANGE_OPERATION(
logical_not,
operator!,
687 accessor::detail::logical_not);
690 GKO_ENABLE_UNARY_RANGE_OPERATION(
bitwise_not,
operator~,
691 accessor::detail::bitwise_not);
696 accessor::detail::zero_operation);
698 accessor::detail::one_operation);
700 accessor::detail::abs_operation);
702 accessor::detail::real_operation);
704 accessor::detail::imag_operation);
706 accessor::detail::conj_operation);
708 accessor::detail::squared_norm_operation);
721 template <
typename Accessor>
723 using accessor = Accessor;
724 static constexpr
size_type dimensionality = accessor::dimensionality;
727 const Accessor& operand)
731 template <
typename FirstDimensionType,
typename SecondDimensionType,
732 typename... DimensionTypes>
733 GKO_ATTRIBUTES constexpr
auto operator()(
734 const FirstDimensionType& first_dim,
735 const SecondDimensionType& second_dim,
736 const DimensionTypes&... dims)
const
737 -> decltype(std::declval<accessor>()(second_dim, first_dim, dims...))
739 return operand(second_dim, first_dim, dims...);
744 return dimension < 2 ? operand.length(dimension ^ 1)
745 : operand.length(dimension);
748 template <
typename OtherAccessor>
749 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
751 const accessor operand;
758 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(transpose_operation,
transpose);
761 #undef GKO_DEPRECATED_UNARY_RANGE_OPERATION
762 #undef GKO_DEFINE_SIMPLE_UNARY_OPERATION
763 #undef GKO_ENABLE_UNARY_RANGE_OPERATION
766 #define GKO_ENABLE_BINARY_RANGE_OPERATION(_operation_name, _operator_name, \
768 namespace accessor { \
769 template <::gko::detail::operation_kind Kind, typename FirstOperand, \
770 typename SecondOperand> \
771 struct _operation_name \
772 : ::gko::detail::implement_binary_operation< \
773 Kind, FirstOperand, SecondOperand, ::gko::_operator> { \
774 using ::gko::detail::implement_binary_operation< \
775 Kind, FirstOperand, SecondOperand, \
776 ::gko::_operator>::implement_binary_operation; \
779 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name); \
780 static_assert(true, \
781 "This assert is used to counter the false positive extra " \
782 "semi-colon warnings")
785 #define GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name) \
786 template <typename Accessor> \
787 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
788 ::gko::detail::operation_kind::range_by_range, Accessor, Accessor>> \
789 _operator_name(const range<Accessor>& first, \
790 const range<Accessor>& second) \
792 return range<accessor::_operation_name< \
793 ::gko::detail::operation_kind::range_by_range, Accessor, \
794 Accessor>>(first.get_accessor(), second.get_accessor()); \
797 template <typename FirstAccessor, typename SecondAccessor> \
798 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
799 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
801 _operator_name(const range<FirstAccessor>& first, \
802 const range<SecondAccessor>& second) \
804 return range<accessor::_operation_name< \
805 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
806 SecondAccessor>>(first.get_accessor(), second.get_accessor()); \
809 template <typename FirstAccessor, typename SecondOperand> \
810 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
811 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
813 _operator_name(const range<FirstAccessor>& first, \
814 const SecondOperand& second) \
816 return range<accessor::_operation_name< \
817 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
818 SecondOperand>>(first.get_accessor(), second); \
821 template <typename FirstOperand, typename SecondAccessor> \
822 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
823 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
825 _operator_name(const FirstOperand& first, \
826 const range<SecondAccessor>& second) \
828 return range<accessor::_operation_name< \
829 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
830 SecondAccessor>>(first, second.get_accessor()); \
832 static_assert(true, \
833 "This assert is used to counter the false positive extra " \
834 "semi-colon warnings")
// Declares `_deprecated_name` as a deprecated alias of the binary operation
// `_name`: an empty struct that publicly inherits everything from `_name` and
// whose deprecation message points users to `_name`.
// The invocation must be terminated with a semicolon by the caller.
#define GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(_deprecated_name, _name) \
    struct GKO_DEPRECATED("Please use " #_name) _deprecated_name          \
        : public _name {}
840 #define GKO_DEFINE_SIMPLE_BINARY_OPERATION(_name, ...) \
843 template <typename FirstOperand, typename SecondOperand> \
844 GKO_ATTRIBUTES constexpr static auto simple_evaluate_impl( \
845 const FirstOperand& first, const SecondOperand& second) \
846 -> decltype(__VA_ARGS__) \
848 return __VA_ARGS__; \
852 template <typename FirstAccessor, typename SecondAccessor, \
853 typename... DimensionTypes> \
854 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_range( \
855 const FirstAccessor& first, const SecondAccessor& second, \
856 const DimensionTypes&... dims) \
857 -> decltype(simple_evaluate_impl(first(dims...), second(dims...))) \
859 return simple_evaluate_impl(first(dims...), second(dims...)); \
862 template <typename FirstOperand, typename SecondAccessor, \
863 typename... DimensionTypes> \
864 GKO_ATTRIBUTES static constexpr auto evaluate_scalar_by_range( \
865 const FirstOperand& first, const SecondAccessor& second, \
866 const DimensionTypes&... dims) \
867 -> decltype(simple_evaluate_impl(first, second(dims...))) \
869 return simple_evaluate_impl(first, second(dims...)); \
872 template <typename FirstAccessor, typename SecondOperand, \
873 typename... DimensionTypes> \
874 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_scalar( \
875 const FirstAccessor& first, const SecondOperand& second, \
876 const DimensionTypes&... dims) \
877 -> decltype(simple_evaluate_impl(first(dims...), second)) \
879 return simple_evaluate_impl(first(dims...), second); \
889 GKO_DEFINE_SIMPLE_BINARY_OPERATION(add, first + second);
890 GKO_DEFINE_SIMPLE_BINARY_OPERATION(sub, first - second);
891 GKO_DEFINE_SIMPLE_BINARY_OPERATION(mul, first* second);
892 GKO_DEFINE_SIMPLE_BINARY_OPERATION(div, first / second);
893 GKO_DEFINE_SIMPLE_BINARY_OPERATION(mod, first % second);
896 GKO_DEFINE_SIMPLE_BINARY_OPERATION(less, first < second);
897 GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater, first > second);
898 GKO_DEFINE_SIMPLE_BINARY_OPERATION(less_or_equal, first <= second);
899 GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater_or_equal, first >= second);
900 GKO_DEFINE_SIMPLE_BINARY_OPERATION(equal, first == second);
901 GKO_DEFINE_SIMPLE_BINARY_OPERATION(not_equal, first != second);
904 GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_or, first || second);
905 GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_and, first&& second);
908 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_or, first | second);
909 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_and, first& second);
910 GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_xor, first ^ second);
911 GKO_DEFINE_SIMPLE_BINARY_OPERATION(left_shift, first << second);
912 GKO_DEFINE_SIMPLE_BINARY_OPERATION(right_shift, first >> second);
915 GKO_DEFINE_SIMPLE_BINARY_OPERATION(max_operation,
max(first, second));
916 GKO_DEFINE_SIMPLE_BINARY_OPERATION(min_operation,
min(first, second));
918 GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(max_operaton, max_operation);
919 GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(min_operaton, min_operation);
925 GKO_ENABLE_BINARY_RANGE_OPERATION(
add,
operator+, accessor::detail::add);
926 GKO_ENABLE_BINARY_RANGE_OPERATION(
sub,
operator-, accessor::detail::sub);
927 GKO_ENABLE_BINARY_RANGE_OPERATION(
mul,
operator*, accessor::detail::mul);
928 GKO_ENABLE_BINARY_RANGE_OPERATION(
div,
operator/, accessor::detail::div);
929 GKO_ENABLE_BINARY_RANGE_OPERATION(
mod,
operator%, accessor::detail::mod);
932 GKO_ENABLE_BINARY_RANGE_OPERATION(
less,
operator<, accessor::detail::less);
933 GKO_ENABLE_BINARY_RANGE_OPERATION(
greater,
operator>,
934 accessor::detail::greater);
936 accessor::detail::less_or_equal);
938 accessor::detail::greater_or_equal);
939 GKO_ENABLE_BINARY_RANGE_OPERATION(
equal,
operator==, accessor::detail::equal);
940 GKO_ENABLE_BINARY_RANGE_OPERATION(
not_equal,
operator!=,
941 accessor::detail::not_equal);
944 GKO_ENABLE_BINARY_RANGE_OPERATION(
logical_or,
operator||,
945 accessor::detail::logical_or);
946 GKO_ENABLE_BINARY_RANGE_OPERATION(
logical_and,
operator&&,
947 accessor::detail::logical_and);
950 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_or,
operator|,
951 accessor::detail::bitwise_or);
952 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_and,
operator&,
953 accessor::detail::bitwise_and);
954 GKO_ENABLE_BINARY_RANGE_OPERATION(
bitwise_xor,
operator^,
955 accessor::detail::bitwise_xor);
956 GKO_ENABLE_BINARY_RANGE_OPERATION(
left_shift,
operator<<,
957 accessor::detail::left_shift);
958 GKO_ENABLE_BINARY_RANGE_OPERATION(
right_shift,
operator>>,
959 accessor::detail::right_shift);
963 accessor::detail::max_operation);
965 accessor::detail::min_operation);
972 template <gko::detail::operation_kind Kind,
typename FirstAccessor,
973 typename SecondAccessor>
975 static_assert(Kind == gko::detail::operation_kind::range_by_range,
976 "Matrix multiplication expects both operands to be ranges");
977 using first_accessor = FirstAccessor;
978 using second_accessor = SecondAccessor;
979 static_assert(first_accessor::dimensionality ==
980 second_accessor::dimensionality,
981 "Both ranges need to have the same number of dimensions");
982 static constexpr
size_type dimensionality = first_accessor::dimensionality;
984 GKO_ATTRIBUTES
explicit mmul_operation(
const FirstAccessor& first,
985 const SecondAccessor& second)
986 : first{first}, second{second}
988 GKO_ASSERT(first.length(1) == second.length(0));
989 GKO_ASSERT(gko::detail::equal_dimensions<2>(first, second));
992 template <
typename FirstDimension,
typename SecondDimension,
993 typename... DimensionTypes>
994 GKO_ATTRIBUTES
auto operator()(
const FirstDimension& row,
995 const SecondDimension& col,
996 const DimensionTypes&... rest)
const
997 -> decltype(std::declval<FirstAccessor>()(row, 0, rest...) *
998 std::declval<SecondAccessor>()(0, col, rest...) +
999 std::declval<FirstAccessor>()(row, 1, rest...) *
1000 std::declval<SecondAccessor>()(1, col, rest...))
1003 decltype(first(row, 0, rest...) * second(0, col, rest...) +
1004 first(row, 1, rest...) * second(1, col, rest...));
1005 GKO_ASSERT(first.length(1) == second.length(0));
1006 auto result = zero<result_type>();
1007 const auto size = first.length(1);
1008 for (
auto i =
zero(size); i < size; ++i) {
1009 result += first(row, i, rest...) * second(i, col, rest...);
1016 return dimension == 1 ? second.length(1) : first.length(dimension);
1019 template <
typename OtherAccessor>
1020 GKO_ATTRIBUTES
void copy_from(
const OtherAccessor& other)
const =
delete;
1022 const first_accessor first;
1023 const second_accessor second;
1030 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(mmul_operation, mmul);
1033 #undef GKO_DEFINE_SIMPLE_BINARY_OPERATION
1034 #undef GKO_ENABLE_BINARY_RANGE_OPERATION
1040 #endif // GKO_PUBLIC_CORE_BASE_RANGE_HPP_