Ginkgo Generated from branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
range.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_RANGE_HPP_
6#define GKO_PUBLIC_CORE_BASE_RANGE_HPP_
7
8
9#include <type_traits>
10
11
12#include <ginkgo/core/base/math.hpp>
13#include <ginkgo/core/base/types.hpp>
14#include <ginkgo/core/base/utils.hpp>
15
16
17namespace gko {
18
19
47struct span {
55 GKO_ATTRIBUTES constexpr span(size_type point) noexcept
56 : span{point, point + 1}
57 {}
58
65 GKO_ATTRIBUTES constexpr span(size_type begin, size_type end) noexcept
66 : begin{begin}, end{end}
67 {}
68
74 GKO_ATTRIBUTES constexpr bool is_valid() const { return begin <= end; }
75
81 GKO_ATTRIBUTES constexpr size_type length() const { return end - begin; }
82
87
92};
93
94
95GKO_ATTRIBUTES GKO_INLINE constexpr bool operator<(const span& first,
96 const span& second)
97{
98 return first.end < second.begin;
99}
100
101
102GKO_ATTRIBUTES GKO_INLINE constexpr bool operator<=(const span& first,
103 const span& second)
104{
105 return first.end <= second.begin;
106}
107
108
109GKO_ATTRIBUTES GKO_INLINE constexpr bool operator>(const span& first,
110 const span& second)
111{
112 return second < first;
113}
114
115
116GKO_ATTRIBUTES GKO_INLINE constexpr bool operator>=(const span& first,
117 const span& second)
118{
119 return second <= first;
120}
121
122
123GKO_ATTRIBUTES GKO_INLINE constexpr bool operator==(const span& first,
124 const span& second)
125{
126 return first.begin == second.begin && first.end == second.end;
127}
128
129
130GKO_ATTRIBUTES GKO_INLINE constexpr bool operator!=(const span& first,
131 const span& second)
132{
133 return !(first == second);
134}
135
136
137namespace detail {
138
139
140template <size_type CurrentDimension = 0, typename FirstRange,
141 typename SecondRange>
142GKO_ATTRIBUTES constexpr GKO_INLINE
143 std::enable_if_t<(CurrentDimension >= max(FirstRange::dimensionality,
144 SecondRange::dimensionality)),
145 bool>
146 equal_dimensions(const FirstRange&, const SecondRange&)
147{
148 return true;
149}
150
151template <size_type CurrentDimension = 0, typename FirstRange,
152 typename SecondRange>
153GKO_ATTRIBUTES constexpr GKO_INLINE
154 std::enable_if_t<(CurrentDimension < max(FirstRange::dimensionality,
155 SecondRange::dimensionality)),
156 bool>
157 equal_dimensions(const FirstRange& first, const SecondRange& second)
158{
159 return first.length(CurrentDimension) == second.length(CurrentDimension) &&
160 equal_dimensions<CurrentDimension + 1>(first, second);
161}
162
/**
 * Type trait exposing the first type of a parameter pack.
 *
 * Only declared here; the partial specialization below provides `type` for
 * packs with at least one element, so `head<>` stays incomplete.
 */
template <typename... Types>
struct head;

/**
 * Specialization for a non-empty pack: `type` is the leading element.
 */
template <typename Head, typename... Tail>
struct head<Head, Tail...> {
    using type = Head;
};

/**
 * Shorthand for `typename head<Types...>::type`.
 */
template <typename... Types>
using head_t = typename head<Types...>::type;
183
184
185} // namespace detail
186
187
297template <typename Accessor>
298class range {
299public:
303 using accessor = Accessor;
304
308 static constexpr size_type dimensionality = accessor::dimensionality;
309
313 ~range() = default;
314
323 template <
324 typename... AccessorParams,
325 typename = std::enable_if_t<
326 sizeof...(AccessorParams) != 1 ||
327 !std::is_same<
328 range, std::decay<detail::head_t<AccessorParams...>>>::value>>
329 GKO_ATTRIBUTES constexpr explicit range(AccessorParams&&... params)
330 : accessor_{std::forward<AccessorParams>(params)...}
331 {}
332
345 template <typename... DimensionTypes>
346 GKO_ATTRIBUTES constexpr auto operator()(DimensionTypes&&... dimensions)
347 const -> decltype(std::declval<accessor>()(
348 std::forward<DimensionTypes>(dimensions)...))
349 {
350 static_assert(sizeof...(DimensionTypes) <= dimensionality,
351 "Too many dimensions in range call");
352 return accessor_(std::forward<DimensionTypes>(dimensions)...);
353 }
354
363 template <typename OtherAccessor>
364 GKO_ATTRIBUTES const range& operator=(
365 const range<OtherAccessor>& other) const
366 {
367 GKO_ASSERT(detail::equal_dimensions(*this, other));
368 accessor_.copy_from(other);
369 return *this;
370 }
371
385 GKO_ATTRIBUTES const range& operator=(const range& other) const
386 {
387 GKO_ASSERT(detail::equal_dimensions(*this, other));
388 accessor_.copy_from(other.get_accessor());
389 return *this;
390 }
391
392 range(const range& other) = default;
393
401 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
402 {
403 return accessor_.length(dimension);
404 }
405
413 GKO_ATTRIBUTES constexpr const accessor* operator->() const noexcept
414 {
415 return &accessor_;
416 }
417
423 GKO_ATTRIBUTES constexpr const accessor& get_accessor() const noexcept
424 {
425 return accessor_;
426 }
427
428private:
429 accessor accessor_;
430};
431
432
433// implementation of range operations follows
434// (you probably should not have to look at this unless you're interested in the
435// gory details)
436
437
438namespace detail {
439
440
/**
 * Tag selecting which operands of a binary range operation are ranges.
 */
enum class operation_kind {
    range_by_range,   // both operands are ranges
    scalar_by_range,  // first operand is a scalar, second a range
    range_by_scalar   // first operand is a range, second a scalar
};
442
443
444template <typename Accessor, typename Operation>
445struct implement_unary_operation {
446 using accessor = Accessor;
447 static constexpr size_type dimensionality = accessor::dimensionality;
448
449 GKO_ATTRIBUTES constexpr explicit implement_unary_operation(
450 const Accessor& operand)
451 : operand{operand}
452 {}
453
454 template <typename... DimensionTypes>
455 GKO_ATTRIBUTES constexpr auto operator()(
456 const DimensionTypes&... dimensions) const
457 -> decltype(Operation::evaluate(std::declval<accessor>(),
458 dimensions...))
459 {
460 return Operation::evaluate(operand, dimensions...);
461 }
462
463 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
464 {
465 return operand.length(dimension);
466 }
467
468 template <typename OtherAccessor>
469 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
470
471 const accessor operand;
472};
473
474
// Primary template of the binary-operation accessor. It is intentionally
// empty; the functionality lives in the partial specializations for the three
// `operation_kind` values.
475template <operation_kind Kind, typename FirstOperand, typename SecondOperand,
476 typename Operation>
477struct implement_binary_operation {};
478
479template <typename FirstAccessor, typename SecondAccessor, typename Operation>
480struct implement_binary_operation<operation_kind::range_by_range, FirstAccessor,
481 SecondAccessor, Operation> {
482 using first_accessor = FirstAccessor;
483 using second_accessor = SecondAccessor;
484 static_assert(first_accessor::dimensionality ==
485 second_accessor::dimensionality,
486 "Both ranges need to have the same number of dimensions");
487 static constexpr size_type dimensionality = first_accessor::dimensionality;
488
489 GKO_ATTRIBUTES explicit implement_binary_operation(
490 const FirstAccessor& first, const SecondAccessor& second)
491 : first{first}, second{second}
492 {
493 GKO_ASSERT(gko::detail::equal_dimensions(first, second));
494 }
495
496 template <typename... DimensionTypes>
497 GKO_ATTRIBUTES constexpr auto operator()(
498 const DimensionTypes&... dimensions) const
499 -> decltype(Operation::evaluate_range_by_range(
500 std::declval<first_accessor>(), std::declval<second_accessor>(),
501 dimensions...))
502 {
503 return Operation::evaluate_range_by_range(first, second, dimensions...);
504 }
505
506 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
507 {
508 return first.length(dimension);
509 }
510
511 template <typename OtherAccessor>
512 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
513
514 const first_accessor first;
515 const second_accessor second;
516};
517
518template <typename FirstOperand, typename SecondAccessor, typename Operation>
519struct implement_binary_operation<operation_kind::scalar_by_range, FirstOperand,
520 SecondAccessor, Operation> {
521 using second_accessor = SecondAccessor;
522 static constexpr size_type dimensionality = second_accessor::dimensionality;
523
524 GKO_ATTRIBUTES constexpr explicit implement_binary_operation(
525 const FirstOperand& first, const SecondAccessor& second)
526 : first{first}, second{second}
527 {}
528
529 template <typename... DimensionTypes>
530 GKO_ATTRIBUTES constexpr auto operator()(
531 const DimensionTypes&... dimensions) const
532 -> decltype(Operation::evaluate_scalar_by_range(
533 std::declval<FirstOperand>(), std::declval<second_accessor>(),
534 dimensions...))
535 {
536 return Operation::evaluate_scalar_by_range(first, second,
537 dimensions...);
538 }
539
540 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
541 {
542 return second.length(dimension);
543 }
544
545 template <typename OtherAccessor>
546 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
547
548 const FirstOperand first;
549 const second_accessor second;
550};
551
552template <typename FirstAccessor, typename SecondOperand, typename Operation>
553struct implement_binary_operation<operation_kind::range_by_scalar,
554 FirstAccessor, SecondOperand, Operation> {
555 using first_accessor = FirstAccessor;
556 static constexpr size_type dimensionality = first_accessor::dimensionality;
557
558 GKO_ATTRIBUTES constexpr explicit implement_binary_operation(
559 const FirstAccessor& first, const SecondOperand& second)
560 : first{first}, second{second}
561 {}
562
563 template <typename... DimensionTypes>
564 GKO_ATTRIBUTES constexpr auto operator()(
565 const DimensionTypes&... dimensions) const
566 -> decltype(Operation::evaluate_range_by_scalar(
567 std::declval<first_accessor>(), std::declval<SecondOperand>(),
568 dimensions...))
569 {
570 return Operation::evaluate_range_by_scalar(first, second,
571 dimensions...);
572 }
573
574 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
575 {
576 return first.length(dimension);
577 }
578
579 template <typename OtherAccessor>
580 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
581
582 const first_accessor first;
583 const SecondOperand second;
584};
585
586
587} // namespace detail
588
// Declares, inside namespace `accessor`, a class template named
// `_operation_deprecated_name` that publicly derives from
// `_operation_name<Operand>` and is tagged with
// GKO_DEPRECATED("Please use <_operation_name>").
// The derived struct has an empty body (no constructors are inherited).
// The trailing static_assert exists only so that an invocation can be
// followed by a semicolon without triggering extra-semicolon warnings.
589#define GKO_DEPRECATED_UNARY_RANGE_OPERATION(_operation_deprecated_name, \
590 _operation_name) \
591 namespace accessor { \
592 template <typename Operand> \
593 struct GKO_DEPRECATED("Please use " #_operation_name) \
594 _operation_deprecated_name : _operation_name<Operand> {}; \
595 } \
596 static_assert(true, \
597 "This assert is used to counter the false positive extra " \
598 "semi-colon warnings")
599
600
// Defines a unary range operation and the free function that creates it.
//   _operation_name  name of the accessor class template created in
//                    namespace `accessor`
//   _operator_name   name of the free function / operator to define
//                    (e.g. `operator+` or `abs`)
//   _operator        evaluation functor, looked up as `::gko::_operator`
// The accessor derives from detail::implement_unary_operation and inherits
// its constructors; the free function is generated by
// GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR.
601#define GKO_ENABLE_UNARY_RANGE_OPERATION(_operation_name, _operator_name, \
602 _operator) \
603 namespace accessor { \
604 template <typename Operand> \
605 struct _operation_name \
606 : ::gko::detail::implement_unary_operation<Operand, \
607 ::gko::_operator> { \
608 using ::gko::detail::implement_unary_operation< \
609 Operand, ::gko::_operator>::implement_unary_operation; \
610 }; \
611 } \
612 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name)
613
614
// Defines the function template `_operator_name(const range<Accessor>&)`.
// It returns a `range<accessor::_operation_name<Accessor>>` constructed from
// the accessor of the argument, i.e. a range wrapping the operation accessor.
// The trailing static_assert exists only so that an invocation can be
// followed by a semicolon without triggering extra-semicolon warnings.
615#define GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, \
616 _operator_name) \
617 template <typename Accessor> \
618 GKO_ATTRIBUTES constexpr GKO_INLINE \
619 range<accessor::_operation_name<Accessor>> \
620 _operator_name(const range<Accessor>& operand) \
621 { \
622 return range<accessor::_operation_name<Accessor>>( \
623 operand.get_accessor()); \
624 } \
625 static_assert(true, \
626 "This assert is used to counter the false positive extra " \
627 "semi-colon warnings")
628
629
// Defines a struct `_name` describing an element-wise unary operation.
// The variadic macro argument is the expression to evaluate; it has to refer
// to its input by the identifier `operand` (the parameter name of the private
// helper `simple_evaluate_impl`).
// The public static `evaluate(accessor, dimensions...)` first reads the
// element `accessor(dimensions...)` and then applies the expression to it.
// The invocation supplies the terminating semicolon of the struct.
630#define GKO_DEFINE_SIMPLE_UNARY_OPERATION(_name, ...) \
631 struct _name { \
632 private: \
633 template <typename Operand> \
634 GKO_ATTRIBUTES static constexpr auto simple_evaluate_impl( \
635 const Operand& operand) -> decltype(__VA_ARGS__) \
636 { \
637 return __VA_ARGS__; \
638 } \
639 \
640 public: \
641 template <typename AccessorType, typename... DimensionTypes> \
642 GKO_ATTRIBUTES static constexpr auto evaluate( \
643 const AccessorType& accessor, const DimensionTypes&... dimensions) \
644 -> decltype(simple_evaluate_impl(accessor(dimensions...))) \
645 { \
646 return simple_evaluate_impl(accessor(dimensions...)); \
647 } \
648 }
649
650
// Evaluation functors for the element-wise unary operations. Each line
// creates a struct in gko::accessor::detail whose static `evaluate` applies
// the given expression to one element (named `operand` in the expression).
651namespace accessor {
652namespace detail {
653
654
655// unary arithmetic
656GKO_DEFINE_SIMPLE_UNARY_OPERATION(unary_plus, +operand);
657GKO_DEFINE_SIMPLE_UNARY_OPERATION(unary_minus, -operand);
658
659// unary logical
660GKO_DEFINE_SIMPLE_UNARY_OPERATION(logical_not, !operand);
661
662// unary bitwise
663GKO_DEFINE_SIMPLE_UNARY_OPERATION(bitwise_not, ~(operand));
664
665// common functions
// The function names below are unqualified, so they are resolved at the
// point of use inside the generated structs.
666GKO_DEFINE_SIMPLE_UNARY_OPERATION(zero_operation, zero(operand));
667GKO_DEFINE_SIMPLE_UNARY_OPERATION(one_operation, one(operand));
668GKO_DEFINE_SIMPLE_UNARY_OPERATION(abs_operation, abs(operand));
669GKO_DEFINE_SIMPLE_UNARY_OPERATION(real_operation, real(operand));
670GKO_DEFINE_SIMPLE_UNARY_OPERATION(imag_operation, imag(operand));
671GKO_DEFINE_SIMPLE_UNARY_OPERATION(conj_operation, conj(operand));
672GKO_DEFINE_SIMPLE_UNARY_OPERATION(squared_norm_operation,
673 squared_norm(operand));
674
675} // namespace detail
676} // namespace accessor
677
678
// Each invocation below creates the accessor class template named by the
// first argument (in namespace `accessor`) and a free function / operator,
// named by the second argument, that takes a range and returns a range
// wrapping that accessor. The third argument is the evaluation functor.
679// unary arithmetic
680GKO_ENABLE_UNARY_RANGE_OPERATION(unary_plus, operator+,
681 accessor::detail::unary_plus);
682GKO_ENABLE_UNARY_RANGE_OPERATION(unary_minus, operator-,
683 accessor::detail::unary_minus);
684
685// unary logical
686GKO_ENABLE_UNARY_RANGE_OPERATION(logical_not, operator!,
687 accessor::detail::logical_not);
688
689// unary bitwise
690GKO_ENABLE_UNARY_RANGE_OPERATION(bitwise_not, operator~,
691 accessor::detail::bitwise_not);
692
693// common unary functions
694
695GKO_ENABLE_UNARY_RANGE_OPERATION(zero_operation, zero,
696 accessor::detail::zero_operation);
697GKO_ENABLE_UNARY_RANGE_OPERATION(one_operation, one,
698 accessor::detail::one_operation);
699GKO_ENABLE_UNARY_RANGE_OPERATION(abs_operation, abs,
700 accessor::detail::abs_operation);
701GKO_ENABLE_UNARY_RANGE_OPERATION(real_operation, real,
702 accessor::detail::real_operation);
703GKO_ENABLE_UNARY_RANGE_OPERATION(imag_operation, imag,
704 accessor::detail::imag_operation);
705GKO_ENABLE_UNARY_RANGE_OPERATION(conj_operation, conj,
706 accessor::detail::conj_operation);
707GKO_ENABLE_UNARY_RANGE_OPERATION(squared_norm_operation, squared_norm,
708 accessor::detail::squared_norm_operation);
709
710GKO_DEPRECATED_UNARY_RANGE_OPERATION(one_operaton, one_operation);
711GKO_DEPRECATED_UNARY_RANGE_OPERATION(abs_operaton, abs_operation);
712GKO_DEPRECATED_UNARY_RANGE_OPERATION(real_operaton, real_operation);
713GKO_DEPRECATED_UNARY_RANGE_OPERATION(imag_operaton, imag_operation);
714GKO_DEPRECATED_UNARY_RANGE_OPERATION(conj_operaton, conj_operation);
715GKO_DEPRECATED_UNARY_RANGE_OPERATION(squared_norm_operaton,
717
718namespace accessor {
719
720
721template <typename Accessor>
723 using accessor = Accessor;
724 static constexpr size_type dimensionality = accessor::dimensionality;
725
726 GKO_ATTRIBUTES constexpr explicit transpose_operation(
727 const Accessor& operand)
728 : operand{operand}
729 {}
730
731 template <typename FirstDimensionType, typename SecondDimensionType,
732 typename... DimensionTypes>
733 GKO_ATTRIBUTES constexpr auto operator()(
734 const FirstDimensionType& first_dim,
735 const SecondDimensionType& second_dim,
736 const DimensionTypes&... dims) const
737 -> decltype(std::declval<accessor>()(second_dim, first_dim, dims...))
738 {
739 return operand(second_dim, first_dim, dims...);
740 }
741
742 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
743 {
744 return dimension < 2 ? operand.length(dimension ^ 1)
745 : operand.length(dimension);
746 }
747
748 template <typename OtherAccessor>
749 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
750
751 const accessor operand;
752};
753
754
755} // namespace accessor
756
757
// Defines the free function `transpose(const range<Accessor>&)`, which
// returns a range wrapping accessor::transpose_operation<Accessor>.
758GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(transpose_operation, transpose);
759
760
// The helper macros for unary operations are no longer needed from here on.
// Note that GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR is not undefined here.
761#undef GKO_DEPRECATED_UNARY_RANGE_OPERATION
762#undef GKO_DEFINE_SIMPLE_UNARY_OPERATION
763#undef GKO_ENABLE_UNARY_RANGE_OPERATION
764
765
// Defines a binary range operation and the free functions that create it.
//   _operation_name  name of the accessor class template created in
//                    namespace `accessor`; it is parameterized by the
//                    operation kind and the two operand types
//   _operator_name   name of the free function / operator to define
//   _operator        evaluation functor, looked up as `::gko::_operator`
// The accessor derives from detail::implement_binary_operation and inherits
// its constructors; the free functions are generated by
// GKO_BIND_RANGE_OPERATION_TO_OPERATOR.
766#define GKO_ENABLE_BINARY_RANGE_OPERATION(_operation_name, _operator_name, \
767 _operator) \
768 namespace accessor { \
769 template <::gko::detail::operation_kind Kind, typename FirstOperand, \
770 typename SecondOperand> \
771 struct _operation_name \
772 : ::gko::detail::implement_binary_operation< \
773 Kind, FirstOperand, SecondOperand, ::gko::_operator> { \
774 using ::gko::detail::implement_binary_operation< \
775 Kind, FirstOperand, SecondOperand, \
776 ::gko::_operator>::implement_binary_operation; \
777 }; \
778 } \
779 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name); \
780 static_assert(true, \
781 "This assert is used to counter the false positive extra " \
782 "semi-colon warnings")
783
784
// Defines four overloads of `_operator_name`, each returning a range that
// wraps `accessor::_operation_name<kind, ...>`:
//   1. (range<Accessor>, range<Accessor>)             -> range_by_range
//   2. (range<FirstAccessor>, range<SecondAccessor>)  -> range_by_range
//   3. (range<FirstAccessor>, SecondOperand)          -> range_by_scalar
//   4. (FirstOperand, range<SecondAccessor>)          -> scalar_by_range
// Range arguments contribute their accessor (get_accessor()); non-range
// arguments are passed through as they are.
// The trailing static_assert exists only so that an invocation can be
// followed by a semicolon without triggering extra-semicolon warnings.
785#define GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name) \
786 template <typename Accessor> \
787 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
788 ::gko::detail::operation_kind::range_by_range, Accessor, Accessor>> \
789 _operator_name(const range<Accessor>& first, \
790 const range<Accessor>& second) \
791 { \
792 return range<accessor::_operation_name< \
793 ::gko::detail::operation_kind::range_by_range, Accessor, \
794 Accessor>>(first.get_accessor(), second.get_accessor()); \
795 } \
796 \
797 template <typename FirstAccessor, typename SecondAccessor> \
798 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
799 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
800 SecondAccessor>> \
801 _operator_name(const range<FirstAccessor>& first, \
802 const range<SecondAccessor>& second) \
803 { \
804 return range<accessor::_operation_name< \
805 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
806 SecondAccessor>>(first.get_accessor(), second.get_accessor()); \
807 } \
808 \
809 template <typename FirstAccessor, typename SecondOperand> \
810 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
811 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
812 SecondOperand>> \
813 _operator_name(const range<FirstAccessor>& first, \
814 const SecondOperand& second) \
815 { \
816 return range<accessor::_operation_name< \
817 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
818 SecondOperand>>(first.get_accessor(), second); \
819 } \
820 \
821 template <typename FirstOperand, typename SecondAccessor> \
822 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
823 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
824 SecondAccessor>> \
825 _operator_name(const FirstOperand& first, \
826 const range<SecondAccessor>& second) \
827 { \
828 return range<accessor::_operation_name< \
829 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
830 SecondAccessor>>(first, second.get_accessor()); \
831 } \
832 static_assert(true, \
833 "This assert is used to counter the false positive extra " \
834 "semi-colon warnings")
835
836
// Declares a struct `_deprecated_name` that derives from `_name` and is
// tagged with GKO_DEPRECATED("Please use <_name>"). The invocation supplies
// the terminating semicolon.
837#define GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(_deprecated_name, _name) \
838 struct GKO_DEPRECATED("Please use " #_name) _deprecated_name : _name {}
839
// Defines a struct `_name` describing an element-wise binary operation.
// The variadic macro argument is the expression to evaluate; it has to refer
// to its inputs by the identifiers `first` and `second` (the parameter names
// of the private helper `simple_evaluate_impl`).
// Three public static entry points feed the helper:
//   evaluate_range_by_range   reads first(dims...) and second(dims...)
//   evaluate_scalar_by_range  uses first as is and reads second(dims...)
//   evaluate_range_by_scalar  reads first(dims...) and uses second as is
// The invocation supplies the terminating semicolon of the struct.
840#define GKO_DEFINE_SIMPLE_BINARY_OPERATION(_name, ...) \
841 struct _name { \
842 private: \
843 template <typename FirstOperand, typename SecondOperand> \
844 GKO_ATTRIBUTES constexpr static auto simple_evaluate_impl( \
845 const FirstOperand& first, const SecondOperand& second) \
846 -> decltype(__VA_ARGS__) \
847 { \
848 return __VA_ARGS__; \
849 } \
850 \
851 public: \
852 template <typename FirstAccessor, typename SecondAccessor, \
853 typename... DimensionTypes> \
854 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_range( \
855 const FirstAccessor& first, const SecondAccessor& second, \
856 const DimensionTypes&... dims) \
857 -> decltype(simple_evaluate_impl(first(dims...), second(dims...))) \
858 { \
859 return simple_evaluate_impl(first(dims...), second(dims...)); \
860 } \
861 \
862 template <typename FirstOperand, typename SecondAccessor, \
863 typename... DimensionTypes> \
864 GKO_ATTRIBUTES static constexpr auto evaluate_scalar_by_range( \
865 const FirstOperand& first, const SecondAccessor& second, \
866 const DimensionTypes&... dims) \
867 -> decltype(simple_evaluate_impl(first, second(dims...))) \
868 { \
869 return simple_evaluate_impl(first, second(dims...)); \
870 } \
871 \
872 template <typename FirstAccessor, typename SecondOperand, \
873 typename... DimensionTypes> \
874 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_scalar( \
875 const FirstAccessor& first, const SecondOperand& second, \
876 const DimensionTypes&... dims) \
877 -> decltype(simple_evaluate_impl(first(dims...), second)) \
878 { \
879 return simple_evaluate_impl(first(dims...), second); \
880 } \
881 }
882
883
// Evaluation functors for the element-wise binary operations. Each line
// creates a struct in gko::accessor::detail whose static evaluate_* functions
// apply the given expression to a pair of values (named `first` and `second`
// in the expression).
884namespace accessor {
885namespace detail {
886
887
888// binary arithmetic
889GKO_DEFINE_SIMPLE_BINARY_OPERATION(add, first + second);
890GKO_DEFINE_SIMPLE_BINARY_OPERATION(sub, first - second);
891GKO_DEFINE_SIMPLE_BINARY_OPERATION(mul, first* second);
892GKO_DEFINE_SIMPLE_BINARY_OPERATION(div, first / second);
893GKO_DEFINE_SIMPLE_BINARY_OPERATION(mod, first % second);
894
895// relational
896GKO_DEFINE_SIMPLE_BINARY_OPERATION(less, first < second);
897GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater, first > second);
898GKO_DEFINE_SIMPLE_BINARY_OPERATION(less_or_equal, first <= second);
899GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater_or_equal, first >= second);
900GKO_DEFINE_SIMPLE_BINARY_OPERATION(equal, first == second);
901GKO_DEFINE_SIMPLE_BINARY_OPERATION(not_equal, first != second);
902
903// binary logical
904GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_or, first || second);
905GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_and, first&& second);
906
907// binary bitwise
908GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_or, first | second);
909GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_and, first& second);
910GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_xor, first ^ second);
911GKO_DEFINE_SIMPLE_BINARY_OPERATION(left_shift, first << second);
912GKO_DEFINE_SIMPLE_BINARY_OPERATION(right_shift, first >> second);
913
914// common binary functions
915GKO_DEFINE_SIMPLE_BINARY_OPERATION(max_operation, max(first, second));
916GKO_DEFINE_SIMPLE_BINARY_OPERATION(min_operation, min(first, second));
917
// Deprecated functor names kept for the misspelled `*_operaton` variants.
918GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(max_operaton, max_operation);
919GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(min_operaton, min_operation);
920} // namespace detail
921} // namespace accessor
922
923
// Each invocation below creates the accessor class template named by the
// first argument (in namespace `accessor`) and the range/range, range/scalar
// and scalar/range overloads of the free function / operator named by the
// second argument. The third argument is the evaluation functor.
924// binary arithmetic
925GKO_ENABLE_BINARY_RANGE_OPERATION(add, operator+, accessor::detail::add);
926GKO_ENABLE_BINARY_RANGE_OPERATION(sub, operator-, accessor::detail::sub);
927GKO_ENABLE_BINARY_RANGE_OPERATION(mul, operator*, accessor::detail::mul);
928GKO_ENABLE_BINARY_RANGE_OPERATION(div, operator/, accessor::detail::div);
929GKO_ENABLE_BINARY_RANGE_OPERATION(mod, operator%, accessor::detail::mod);
930
931// relational
932GKO_ENABLE_BINARY_RANGE_OPERATION(less, operator<, accessor::detail::less);
933GKO_ENABLE_BINARY_RANGE_OPERATION(greater, operator>,
934 accessor::detail::greater);
935GKO_ENABLE_BINARY_RANGE_OPERATION(less_or_equal, operator<=,
936 accessor::detail::less_or_equal);
937GKO_ENABLE_BINARY_RANGE_OPERATION(greater_or_equal, operator>=,
938 accessor::detail::greater_or_equal);
939GKO_ENABLE_BINARY_RANGE_OPERATION(equal, operator==, accessor::detail::equal);
940GKO_ENABLE_BINARY_RANGE_OPERATION(not_equal, operator!=,
941 accessor::detail::not_equal);
942
943// binary logical
944GKO_ENABLE_BINARY_RANGE_OPERATION(logical_or, operator||,
945 accessor::detail::logical_or);
946GKO_ENABLE_BINARY_RANGE_OPERATION(logical_and, operator&&,
947 accessor::detail::logical_and);
948
949// binary bitwise
950GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_or, operator|,
951 accessor::detail::bitwise_or);
952GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_and, operator&,
953 accessor::detail::bitwise_and);
954GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_xor, operator^,
955 accessor::detail::bitwise_xor);
956GKO_ENABLE_BINARY_RANGE_OPERATION(left_shift, operator<<,
957 accessor::detail::left_shift);
958GKO_ENABLE_BINARY_RANGE_OPERATION(right_shift, operator>>,
959 accessor::detail::right_shift);
960
961// common binary functions
962GKO_ENABLE_BINARY_RANGE_OPERATION(max_operation, max,
963 accessor::detail::max_operation);
964GKO_ENABLE_BINARY_RANGE_OPERATION(min_operation, min,
965 accessor::detail::min_operation);
966
967
968// special binary range functions
969namespace accessor {
970
971
972template <gko::detail::operation_kind Kind, typename FirstAccessor,
973 typename SecondAccessor>
975 static_assert(Kind == gko::detail::operation_kind::range_by_range,
976 "Matrix multiplication expects both operands to be ranges");
977 using first_accessor = FirstAccessor;
978 using second_accessor = SecondAccessor;
979 static_assert(first_accessor::dimensionality ==
980 second_accessor::dimensionality,
981 "Both ranges need to have the same number of dimensions");
982 static constexpr size_type dimensionality = first_accessor::dimensionality;
983
984 GKO_ATTRIBUTES explicit mmul_operation(const FirstAccessor& first,
985 const SecondAccessor& second)
986 : first{first}, second{second}
987 {
988 GKO_ASSERT(first.length(1) == second.length(0));
989 GKO_ASSERT(gko::detail::equal_dimensions<2>(first, second));
990 }
991
992 template <typename FirstDimension, typename SecondDimension,
993 typename... DimensionTypes>
994 GKO_ATTRIBUTES auto operator()(const FirstDimension& row,
995 const SecondDimension& col,
996 const DimensionTypes&... rest) const
997 -> decltype(std::declval<FirstAccessor>()(row, 0, rest...) *
998 std::declval<SecondAccessor>()(0, col, rest...) +
999 std::declval<FirstAccessor>()(row, 1, rest...) *
1000 std::declval<SecondAccessor>()(1, col, rest...))
1001 {
1002 using result_type =
1003 decltype(first(row, 0, rest...) * second(0, col, rest...) +
1004 first(row, 1, rest...) * second(1, col, rest...));
1005 GKO_ASSERT(first.length(1) == second.length(0));
1006 auto result = zero<result_type>();
1007 const auto size = first.length(1);
1008 for (auto i = zero(size); i < size; ++i) {
1009 result += first(row, i, rest...) * second(i, col, rest...);
1010 }
1011 return result;
1012 }
1013
1014 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
1015 {
1016 return dimension == 1 ? second.length(1) : first.length(dimension);
1017 }
1018
1019 template <typename OtherAccessor>
1020 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
1021
1022 const first_accessor first;
1023 const second_accessor second;
1024};
1025
1026
1027} // namespace accessor
1028
1029
// Defines the overloads of the free function `mmul`, which return a range
// wrapping accessor::mmul_operation. mmul_operation itself rejects every
// operation kind other than range_by_range with a static_assert.
1030GKO_BIND_RANGE_OPERATION_TO_OPERATOR(mmul_operation, mmul);
1031
1032
// The helper macros for binary operations are no longer needed from here on.
// Note that GKO_BIND_RANGE_OPERATION_TO_OPERATOR and
// GKO_DEPRECATED_SIMPLE_BINARY_OPERATION are not undefined here.
1033#undef GKO_DEFINE_SIMPLE_BINARY_OPERATION
1034#undef GKO_ENABLE_BINARY_RANGE_OPERATION
1035
1036
1037} // namespace gko
1038
1039
1040#endif // GKO_PUBLIC_CORE_BASE_RANGE_HPP_
A range is a multidimensional view of the memory.
Definition range.hpp:298
Accessor accessor
The type of the underlying accessor.
Definition range.hpp:303
constexpr auto operator()(DimensionTypes &&... dimensions) const -> decltype(std::declval< accessor >()(std::forward< DimensionTypes >(dimensions)...))
Returns a value (or a sub-range) with the specified indexes.
Definition range.hpp:346
constexpr size_type length(size_type dimension) const
Returns the length of the specified dimension of the range.
Definition range.hpp:401
constexpr const accessor * operator->() const noexcept
Returns a pointer to the accessor.
Definition range.hpp:413
static constexpr size_type dimensionality
The number of dimensions of the range.
Definition range.hpp:308
const range & operator=(const range &other) const
Assigns another range to this range.
Definition range.hpp:385
constexpr const accessor & get_accessor() const noexcept
Returns a reference to the accessor.
Definition range.hpp:423
~range()=default
Use the default destructor.
const range & operator=(const range< OtherAccessor > &other) const
Definition range.hpp:364
constexpr range(AccessorParams &&... params)
Creates a new range.
Definition range.hpp:329
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:775
constexpr T zero()
Returns the additive identity for T.
Definition math.hpp:747
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:1029
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:86
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:863
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:120
constexpr xstd::enable_if_t<!is_complex_s< T >::value, T > abs(const T &x)
Returns the absolute value of the object.
Definition math.hpp:1076
constexpr auto squared_norm(const T &x) -> decltype(real(conj(x) *x))
Returns the squared norm of the object.
Definition math.hpp:1057
constexpr auto conj(const T &x)
Returns the conjugate of an object.
Definition math.hpp:1043
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:845
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:1013
Definition range.hpp:700
Definition range.hpp:711
Definition range.hpp:925
Definition range.hpp:953
Definition range.hpp:691
Definition range.hpp:951
Definition range.hpp:955
Definition range.hpp:706
Definition range.hpp:714
Definition range.hpp:928
Definition range.hpp:939
Definition range.hpp:938
Definition range.hpp:934
Definition range.hpp:704
Definition range.hpp:713
Definition range.hpp:957
Definition range.hpp:936
Definition range.hpp:932
Definition range.hpp:947
Definition range.hpp:687
Definition range.hpp:945
Definition range.hpp:963
Definition range.hpp:965
Definition range.hpp:974
Definition range.hpp:929
Definition range.hpp:927
Definition range.hpp:941
Definition range.hpp:698
Definition range.hpp:710
Definition range.hpp:702
Definition range.hpp:712
Definition range.hpp:959
Definition range.hpp:716
Definition range.hpp:926
Definition range.hpp:722
Definition range.hpp:683
Definition range.hpp:681
Definition range.hpp:696
A span is a lightweight structure used to create sub-ranges from other ranges.
Definition range.hpp:47
constexpr span(size_type begin, size_type end) noexcept
Creates a span.
Definition range.hpp:65
constexpr span(size_type point) noexcept
Creates a span representing the point `point`.
Definition range.hpp:55
constexpr bool is_valid() const
Checks if a span is valid.
Definition range.hpp:74
constexpr size_type length() const
Returns the length of a span.
Definition range.hpp:81
const size_type begin
Beginning of the span.
Definition range.hpp:86
const size_type end
End of the span.
Definition range.hpp:91