Ginkgo Generated from branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
math.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_MATH_HPP_
6#define GKO_PUBLIC_CORE_BASE_MATH_HPP_
7
8
9#include <cmath>
10#include <complex>
11#include <cstdlib>
12#include <limits>
13#include <type_traits>
14#include <utility>
15
16
17#include <ginkgo/config.hpp>
18#include <ginkgo/core/base/types.hpp>
19#include <ginkgo/core/base/utils.hpp>
20
21
22namespace gko {
23
24
// HIP must not see `std::abs` / `std::sqrt` because it relies on the custom
// implementations. The standard-library overloads are therefore made visible,
// through using-declarations, only inside the namespaces listed below.
namespace kernels {

namespace reference {
using std::sqrt;
using std::abs;
}  // namespace reference

namespace omp {
using std::sqrt;
using std::abs;
}  // namespace omp

namespace cuda {
using std::sqrt;
using std::abs;
}  // namespace cuda

namespace dpcpp {
using std::sqrt;
using std::abs;
}  // namespace dpcpp

}  // namespace kernels


// The test namespace sees the standard-library overloads as well.
namespace test {
using std::sqrt;
using std::abs;
}  // namespace test
93
94
95// type manipulations
96
97
103namespace detail {
104
105
/**
 * Type trait mapping T to its real counterpart.
 *
 * Non-complex types map to themselves.
 *
 * @tparam T  the type to map
 */
template <typename T>
struct remove_complex_impl {
    using type = T;
};

/**
 * std::complex<RealType> maps to its component type RealType.
 *
 * @tparam RealType  the component type of the complex number
 */
template <typename RealType>
struct remove_complex_impl<std::complex<RealType>> {
    using type = RealType;
};
121
122
/**
 * Type trait mapping T to the std::complex type built on top of it.
 *
 * @tparam T  a real type
 */
template <typename T>
struct to_complex_impl {
    using type = std::complex<T>;
};

/**
 * A type that is already complex is left unchanged (no nested complex).
 *
 * @tparam RealType  the component type of the complex number
 */
template <typename RealType>
struct to_complex_impl<std::complex<RealType>> {
    using type = std::complex<RealType>;
};
142
143
/**
 * Compile-time flag whose `value` is true iff T is a specialization of
 * std::complex.
 *
 * @tparam T  the type to check
 */
template <typename T>
struct is_complex_impl : public std::false_type {};

template <typename RealType>
struct is_complex_impl<std::complex<RealType>> : public std::true_type {};
150
151
/**
 * Compile-time flag whose `value` is true iff T is a scalar type (as defined
 * by std::is_scalar), or a std::complex whose component type is scalar.
 *
 * @tparam T  the type to check
 */
template <typename T>
struct is_complex_or_scalar_impl : std::is_scalar<T> {};

// for std::complex, the component type decides
template <typename Component>
struct is_complex_or_scalar_impl<std::complex<Component>>
    : std::is_scalar<Component> {};
157
158
/**
 * Applies a unary type-level converter to every template argument of a class
 * template instantiation.
 *
 * The primary template is intentionally empty. Using it on a type that is not
 * a class template instantiation therefore yields no `type` member.
 *
 * @tparam converter  unary trait exposing a nested `type`
 * @tparam T  the type whose template arguments are converted
 */
template <template <typename> class converter, typename T>
struct template_converter {};

/**
 * Specialization for T<Args...>: yields T<converter<Args>::type...>.
 */
template <template <typename> class converter, template <typename...> class T,
          typename... Args>
struct template_converter<converter, T<Args...>> {
    using type = T<typename converter<Args>::type...>;
};
183
184
185template <typename T, typename = void>
186struct remove_complex_s {};
187
194template <typename T>
195struct remove_complex_s<T,
196 std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
197 using type = typename detail::remove_complex_impl<T>::type;
198};
199
206template <typename T>
207struct remove_complex_s<
208 T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
209 using type =
210 typename detail::template_converter<detail::remove_complex_impl,
211 T>::type;
212};
213
214
215template <typename T, typename = void>
216struct to_complex_s {};
217
224template <typename T>
225struct to_complex_s<T, std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
226 using type = typename detail::to_complex_impl<T>::type;
227};
228
235template <typename T>
236struct to_complex_s<T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
237 using type =
238 typename detail::template_converter<detail::to_complex_impl, T>::type;
239};
240
241
242} // namespace detail
243
244
/**
 * Access the underlying real type of a complex number.
 *
 * For non-complex types, `type` is T itself.
 *
 * @tparam T  the type being checked
 */
template <typename T>
struct cpx_real_type {
    /** The type. When T is not complex, this is T itself. */
    using type = T;
};

/**
 * Specialization for std::complex: `type` is the complex number's
 * `value_type`.
 *
 * @tparam T  the component type of the complex number
 */
template <typename T>
struct cpx_real_type<std::complex<T>> {
    /** The real type of the complex number. */
    using type = typename std::complex<T>::value_type;
};
266
267
276template <typename T>
277using is_complex_s = detail::is_complex_impl<T>;
278
286template <typename T>
287GKO_INLINE GKO_ATTRIBUTES constexpr bool is_complex()
288{
289 return detail::is_complex_impl<T>::value;
290}
291
292
300template <typename T>
301using is_complex_or_scalar_s = detail::is_complex_or_scalar_impl<T>;
302
310template <typename T>
311GKO_INLINE GKO_ATTRIBUTES constexpr bool is_complex_or_scalar()
312{
313 return detail::is_complex_or_scalar_impl<T>::value;
314}
315
316
325template <typename T>
326using remove_complex = typename detail::remove_complex_s<T>::type;
327
328
344template <typename T>
345using to_complex = typename detail::to_complex_s<T>::type;
346
347
353template <typename T>
355
356
357namespace detail {
358
359
// Singly linked list of all supported precisions. Each precision points to
// the next one, and the last one wraps around to the first. Unsupported types
// get no `type` member because the primary template is empty.
template <typename T>
struct next_precision_impl {};

template <>
struct next_precision_impl<double> {
    using type = float;
};

template <>
struct next_precision_impl<float> {
    using type = double;
};

// complex types follow the list of their component type
template <typename RealType>
struct next_precision_impl<std::complex<RealType>> {
    using type = std::complex<typename next_precision_impl<RealType>::type>;
};
378
379
380template <typename T>
381struct reduce_precision_impl {
382 using type = T;
383};
384
385template <typename T>
386struct reduce_precision_impl<std::complex<T>> {
387 using type = std::complex<typename reduce_precision_impl<T>::type>;
388};
389
390template <>
391struct reduce_precision_impl<double> {
392 using type = float;
393};
394
395template <>
396struct reduce_precision_impl<float> {
397 using type = half;
398};
399
400
401template <typename T>
402struct increase_precision_impl {
403 using type = T;
404};
405
406template <typename T>
407struct increase_precision_impl<std::complex<T>> {
408 using type = std::complex<typename increase_precision_impl<T>::type>;
409};
410
411template <>
412struct increase_precision_impl<float> {
413 using type = double;
414};
415
416template <>
417struct increase_precision_impl<half> {
418 using type = float;
419};
420
421
/**
 * Holds positive infinity for T as a compile-time constant.
 *
 * CUDA code cannot call std::numeric_limits functions, so the value is
 * evaluated once here and stored in a static constexpr member instead.
 *
 * @tparam T  a floating-point type
 */
template <typename T>
struct infinity_impl {
    static constexpr T value = std::numeric_limits<T>::infinity();
};
428
429
/**
 * Determines the common arithmetic type of T1 and T2, i.e. the type of the
 * expression `T1{} + T2{}`.
 */
template <typename T1, typename T2>
struct highest_precision_impl {
    using type = decltype(T1{} + T2{});
};

/**
 * For two complex types, the common type is determined component-wise.
 */
template <typename Real1, typename Real2>
struct highest_precision_impl<std::complex<Real1>, std::complex<Real2>> {
    using type =
        std::complex<typename highest_precision_impl<Real1, Real2>::type>;
};

/**
 * Right fold of highest_precision_impl over a non-empty list of types.
 */
template <typename First, typename... Rest>
struct highest_precision_variadic {
    using type = typename highest_precision_impl<
        First, typename highest_precision_variadic<Rest...>::type>::type;
};

// recursion anchor: a single type is its own highest precision
template <typename Only>
struct highest_precision_variadic<Only> {
    using type = Only;
};
453
454
455} // namespace detail
456
457
461template <typename T>
462using next_precision = typename detail::next_precision_impl<T>::type;
463
464
471template <typename T>
473
474
478template <typename T>
479using reduce_precision = typename detail::reduce_precision_impl<T>::type;
480
481
485template <typename T>
486using increase_precision = typename detail::increase_precision_impl<T>::type;
487
488
500template <typename... Ts>
502 typename detail::highest_precision_variadic<Ts...>::type;
503
504
514template <typename T>
515GKO_INLINE GKO_ATTRIBUTES constexpr reduce_precision<T> round_down(T val)
516{
517 return static_cast<reduce_precision<T>>(val);
518}
519
520
530template <typename T>
531GKO_INLINE GKO_ATTRIBUTES constexpr increase_precision<T> round_up(T val)
532{
533 return static_cast<increase_precision<T>>(val);
534}
535
536
537template <typename FloatType, size_type NumComponents, size_type ComponentId>
539
540
541namespace detail {
542
543
/**
 * Maps a type to its truncated counterpart (used by truncate_type).
 *
 * By default, T is split into two truncated components, and the component
 * with index 0 is selected.
 *
 * @tparam T  the type to truncate
 */
544template <typename T>
545struct truncate_type_impl {
546 using type = truncated<T, 2, 0>;
547};
548
// Specialization for a type that is already truncated (component 0 of
// `Components` parts).
// NOTE(review): the member declaration of this specialization (its
// `using type = ...;` line) is missing from this listing. As shown, the
// specialization defines no `type`; restore it from the original header.
549template <typename T, size_type Components>
550struct truncate_type_impl<truncated<T, Components, 0>> {
552};
553
// Complex types are truncated component-wise.
554template <typename T>
555struct truncate_type_impl<std::complex<T>> {
556 using type = std::complex<typename truncate_type_impl<T>::type>;
557};
558
559
560template <typename T>
561struct type_size_impl {
562 static constexpr auto value = sizeof(T) * byte_size;
563};
564
565template <typename T>
566struct type_size_impl<std::complex<T>> {
567 static constexpr auto value = sizeof(T) * byte_size;
568};
569
570
571} // namespace detail
572
573
578template <typename T, size_type Limit = sizeof(uint16) * byte_size>
580 std::conditional_t<detail::type_size_impl<T>::value >= 2 * Limit,
582
583
/**
 * Used to convert objects of type S to objects of type R using static_cast.
 *
 * NOTE(review): the line declaring this struct (its name and opening brace)
 * is missing from this listing between the template header and operator().
 * Restore it from the original header before compiling.
 *
 * @tparam S  source type
 * @tparam R  result type
 */
590template <typename S, typename R>
    /**
     * Converts the object to result type.
     *
     * @param val  the object to convert
     *
     * @return `static_cast<R>(val)`
     */
598 GKO_ATTRIBUTES R operator()(S val) { return static_cast<R>(val); }
599};
600
601
602// mathematical functions
603
604
613GKO_INLINE GKO_ATTRIBUTES constexpr int64 ceildiv(int64 num, int64 den)
614{
615 return (num + den - 1) / den;
616}
617
618
// zero() and one() come in two variants. When compiling with HIP on the HCC
// platform, separate __host__ and __device__ overloads are defined, and the
// argument-less __device__ overloads are removed from overload resolution
// (via enable_if) when T is std::complex. All other builds use a single
// GKO_ATTRIBUTES overload set.
619#if defined(__HIPCC__) && GINKGO_HIP_PLATFORM_HCC
620
621
/**
 * Returns the additive identity for T (host variant).
 *
 * @tparam T  type of the value
 *
 * @return `T{}`
 */
627template <typename T>
628GKO_INLINE __host__ constexpr T zero()
629{
630 return T{};
631}
632
633
/**
 * Returns the additive identity for T (host variant).
 *
 * The argument is unused; it only allows T to be deduced.
 *
 * @return `zero<T>()`
 */
643template <typename T>
644GKO_INLINE __host__ constexpr T zero(const T&)
645{
646 return zero<T>();
647}
648
649
/**
 * Returns the multiplicative identity for T (host variant).
 *
 * @tparam T  type of the value
 *
 * @return `T(1)`
 */
655template <typename T>
656GKO_INLINE __host__ constexpr T one()
657{
658 return T(1);
659}
660
661
/**
 * Returns the multiplicative identity for T (host variant).
 *
 * The argument is unused; it only allows T to be deduced.
 *
 * @return `one<T>()`
 */
671template <typename T>
672GKO_INLINE __host__ constexpr T one(const T&)
673{
674 return one<T>();
675}
676
677
/**
 * Returns the additive identity for T (device variant).
 *
 * Only participates in overload resolution when T is not std::complex.
 * NOTE(review): no complex __device__ overload is visible in this listing;
 * presumably another header provides one. Confirm.
 */
683template <typename T>
684GKO_INLINE __device__ constexpr std::enable_if_t<
685 !std::is_same<T, std::complex<remove_complex<T>>>::value, T>
686zero()
687{
688 return T{};
689}
690
691
/**
 * Returns the additive identity for T (device variant).
 *
 * The argument is unused; it only allows T to be deduced.
 */
701template <typename T>
702GKO_INLINE __device__ constexpr T zero(const T&)
703{
704 return zero<T>();
705}
706
707
/**
 * Returns the multiplicative identity for T (device variant).
 *
 * Only participates in overload resolution when T is not std::complex.
 */
713template <typename T>
714GKO_INLINE __device__ constexpr std::enable_if_t<
715 !std::is_same<T, std::complex<remove_complex<T>>>::value, T>
716one()
717{
718 return T(1);
719}
720
721
/**
 * Returns the multiplicative identity for T (device variant).
 *
 * The argument is unused; it only allows T to be deduced.
 */
731template <typename T>
732GKO_INLINE __device__ constexpr T one(const T&)
733{
734 return one<T>();
735}
736
737
738#else
739
740
/**
 * Returns the additive identity for T.
 *
 * @tparam T  type of the value
 *
 * @return `T{}`
 */
746template <typename T>
747GKO_INLINE GKO_ATTRIBUTES constexpr T zero()
748{
749 return T{};
750}
751
752
/**
 * Returns the additive identity for T.
 *
 * The argument is unused; it only allows T to be deduced.
 *
 * @return `zero<T>()`
 */
762template <typename T>
763GKO_INLINE GKO_ATTRIBUTES constexpr T zero(const T&)
764{
765 return zero<T>();
766}
767
768
/**
 * Returns the multiplicative identity for T.
 *
 * @tparam T  type of the value
 *
 * @return `T(1)`
 */
774template <typename T>
775GKO_INLINE GKO_ATTRIBUTES constexpr T one()
776{
777 return T(1);
778}
779
780
/**
 * Returns the multiplicative identity for T.
 *
 * The argument is unused; it only allows T to be deduced.
 *
 * @return `one<T>()`
 */
790template <typename T>
791GKO_INLINE GKO_ATTRIBUTES constexpr T one(const T&)
792{
793 return one<T>();
794}
795
796
797#endif // defined(__HIPCC__) && GINKGO_HIP_PLATFORM_HCC
798
799
// NOTE(review): GKO_BIND_ZERO_ONE is not defined anywhere in this listing.
// This #undef only has an effect if another header defines it; confirm it is
// still needed.
800#undef GKO_BIND_ZERO_ONE
801
802
811template <typename T>
812GKO_INLINE GKO_ATTRIBUTES constexpr bool is_zero(T value)
813{
814 return value == zero<T>();
815}
816
817
826template <typename T>
827GKO_INLINE GKO_ATTRIBUTES constexpr bool is_nonzero(T value)
828{
829 return value != zero<T>();
830}
831
832
844template <typename T>
845GKO_INLINE GKO_ATTRIBUTES constexpr T max(const T& x, const T& y)
846{
847 return x >= y ? x : y;
848}
849
850
862template <typename T>
863GKO_INLINE GKO_ATTRIBUTES constexpr T min(const T& x, const T& y)
864{
865 return x <= y ? x : y;
866}
867
868
869namespace detail {
870
871
881template <typename Ref, typename Dummy = xstd::void_t<>>
882struct has_to_arithmetic_type : std::false_type {
883 static_assert(std::is_same<Dummy, void>::value,
884 "Do not modify the Dummy value!");
885 using type = Ref;
886};
887
888template <typename Ref>
889struct has_to_arithmetic_type<
890 Ref, xstd::void_t<decltype(std::declval<Ref>().to_arithmetic_type())>>
891 : std::true_type {
892 using type = decltype(std::declval<Ref>().to_arithmetic_type());
893};
894
895
900template <typename Ref, typename Dummy = xstd::void_t<>>
901struct has_arithmetic_type : std::false_type {
902 static_assert(std::is_same<Dummy, void>::value,
903 "Do not modify the Dummy value!");
904};
905
906template <typename Ref>
907struct has_arithmetic_type<Ref, xstd::void_t<typename Ref::arithmetic_type>>
908 : std::true_type {};
909
910
922template <typename Ref>
923constexpr GKO_ATTRIBUTES
924 std::enable_if_t<has_to_arithmetic_type<Ref>::value,
925 typename has_to_arithmetic_type<Ref>::type>
926 to_arithmetic_type(const Ref& ref)
927{
928 return ref.to_arithmetic_type();
929}
930
931template <typename Ref>
932constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
933 has_arithmetic_type<Ref>::value,
934 typename Ref::arithmetic_type>
935to_arithmetic_type(const Ref& ref)
936{
937 return ref;
938}
939
940template <typename Ref>
941constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
942 !has_arithmetic_type<Ref>::value,
943 Ref>
944to_arithmetic_type(const Ref& ref)
945{
946 return ref;
947}
948
949
950// Note: All functions have postfix `impl` so they are not considered for
951// overload resolution (in case a class / function also is in the namespace
952// `detail`)
953template <typename T>
954GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
955real_impl(const T& x)
956{
957 return x;
958}
959
960template <typename T>
961GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
963real_impl(const T& x)
964{
965 return x.real();
966}
967
968
969template <typename T>
970GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
971imag_impl(const T&)
972{
973 return T{};
974}
975
976template <typename T>
977GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
979imag_impl(const T& x)
980{
981 return x.imag();
982}
983
984
985template <typename T>
986GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
987conj_impl(const T& x)
988{
989 return x;
990}
991
992template <typename T>
993GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
994conj_impl(const T& x)
995{
996 return T{real_impl(x), -imag_impl(x)};
997}
998
999
1000} // namespace detail
1001
1002
1012template <typename T>
1013GKO_ATTRIBUTES GKO_INLINE constexpr auto real(const T& x)
1014{
1015 return detail::real_impl(detail::to_arithmetic_type(x));
1016}
1017
1018
1028template <typename T>
1029GKO_ATTRIBUTES GKO_INLINE constexpr auto imag(const T& x)
1030{
1031 return detail::imag_impl(detail::to_arithmetic_type(x));
1032}
1033
1034
1042template <typename T>
1043GKO_ATTRIBUTES GKO_INLINE constexpr auto conj(const T& x)
1044{
1045 return detail::conj_impl(detail::to_arithmetic_type(x));
1046}
1047
1048
1056template <typename T>
1057GKO_INLINE GKO_ATTRIBUTES constexpr auto squared_norm(const T& x)
1058 -> decltype(real(conj(x) * x))
1059{
1060 return real(conj(x) * x);
1061}
1062
1063
1073template <typename T>
1074GKO_INLINE
1075 GKO_ATTRIBUTES constexpr xstd::enable_if_t<!is_complex_s<T>::value, T>
1076 abs(const T& x)
1077{
1078 return x >= zero<T>() ? x : -x;
1079}
1080
1081
1082template <typename T>
1083GKO_INLINE GKO_ATTRIBUTES constexpr xstd::enable_if_t<is_complex_s<T>::value,
1085abs(const T& x)
1086{
1087 return sqrt(squared_norm(x));
1088}
1089
1090
1096template <typename T>
1097GKO_INLINE GKO_ATTRIBUTES constexpr T pi()
1098{
1099 return static_cast<T>(3.1415926535897932384626433);
1100}
1101
1102
1111template <typename T>
1112GKO_INLINE GKO_ATTRIBUTES constexpr std::complex<remove_complex<T>> unit_root(
1113 int64 n, int64 k = 1)
1114{
1115 return std::polar(one<remove_complex<T>>(),
1116 remove_complex<T>{2} * pi<remove_complex<T>>() * k / n);
1117}
1118
1119
1132template <typename T>
1133constexpr uint32 get_significant_bit(const T& n, uint32 hint = 0u) noexcept
1134{
1135 return (T{1} << (hint + 1)) > n ? hint : get_significant_bit(n, hint + 1u);
1136}
1137
1138
/**
 * Returns the smallest power of `base` not smaller than `limit`.
 *
 * Starting at `hint`, the candidate is multiplied by `base` until it reaches
 * `limit`.
 *
 * @note `base` must be greater than 1 whenever `hint < limit`; otherwise the
 *       candidate never grows and the recursion does not terminate. For
 *       `limit` close to the maximum of T, `hint * base` may overflow.
 *
 * @tparam T  a numeric type
 *
 * @param base  the base of the power
 * @param limit  the lower bound for the result
 * @param hint  a power of `base` to start from (defaults to 1)
 *
 * @return the smallest `hint * base^k` (k >= 0) that is `>= limit`
 */
template <typename T>
constexpr T get_superior_power(const T& base, const T& limit,
                               const T& hint = T{1}) noexcept
{
    if (hint >= limit) {
        return hint;
    }
    return get_superior_power(base, limit, hint * base);
}
1156
1157
1169template <typename T>
1170GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value, bool>
1171is_finite(const T& value)
1172{
1173 constexpr T infinity{detail::infinity_impl<T>::value};
1174 return abs(value) < infinity;
1175}
1176
1177
1189template <typename T>
1190GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value, bool>
1191is_finite(const T& value)
1192{
1193 return is_finite(value.real()) && is_finite(value.imag());
1194}
1195
1196
1208template <typename T>
1209GKO_INLINE GKO_ATTRIBUTES T safe_divide(T a, T b)
1210{
1211 return b == zero<T>() ? zero<T>() : a / b;
1212}
1213
1214
1224template <typename T>
1225GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value, bool>
1226is_nan(const T& value)
1227{
1228 return std::isnan(value);
1229}
1230
1231
1241template <typename T>
1242GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value, bool> is_nan(
1243 const T& value)
1244{
1245 return std::isnan(value.real()) || std::isnan(value.imag());
1246}
1247
1248
1256template <typename T>
1257GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<!is_complex_s<T>::value, T>
1259{
1260 return std::numeric_limits<T>::quiet_NaN();
1261}
1262
1263
1271template <typename T>
1272GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<is_complex_s<T>::value, T>
1274{
1276}
1277
1278
1279} // namespace gko
1280
1281
1282#endif // GKO_PUBLIC_CORE_BASE_MATH_HPP_
Definition math.hpp:538
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:775
std::enable_if_t<!is_complex_s< T >::value, bool > is_finite(const T &value)
Checks if a floating point number is finite, meaning it is neither +/- infinity nor NaN.
Definition math.hpp:1171
constexpr T pi()
Returns the value of pi.
Definition math.hpp:1097
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition math.hpp:326
constexpr increase_precision< T > round_up(T val)
Increases the precision of the input parameter.
Definition math.hpp:531
std::conditional_t< detail::type_size_impl< T >::value >=2 *Limit, typename detail::truncate_type_impl< T >::type, T > truncate_type
Truncates the type by half (by dropping bits), but ensures that it is at least Limit bits wide.
Definition math.hpp:579
typename detail::next_precision_impl< T >::type next_precision
Obtains the next type in the singly-linked precision list.
Definition math.hpp:462
typename detail::highest_precision_variadic< Ts... >::type highest_precision
Obtains the smallest arithmetic type that is able to store elements of all template parameter types e...
Definition math.hpp:501
typename detail::to_complex_s< T >::type to_complex
Obtain the type which adds the complex of complex/scalar type or the template parameter of class by a...
Definition math.hpp:345
constexpr uint32 get_significant_bit(const T &n, uint32 hint=0u) noexcept
Returns the position of the most significant bit of the number.
Definition math.hpp:1133
constexpr bool is_complex_or_scalar()
Checks if T is a complex/scalar type.
Definition math.hpp:311
std::enable_if_t<!is_complex_s< T >::value, bool > is_nan(const T &value)
Checks if a floating point number is NaN.
Definition math.hpp:1226
detail::is_complex_impl< T > is_complex_s
Allows checking at compile time whether T is a complex type, by accessing the `value` attribute of this struct.
Definition math.hpp:277
constexpr std::enable_if_t<!is_complex_s< T >::value, T > nan()
Returns a quiet NaN of the given type.
Definition math.hpp:1258
constexpr T zero()
Returns the additive identity for T.
Definition math.hpp:747
constexpr bool is_zero(T value)
Returns true if and only if the given value is zero.
Definition math.hpp:812
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:1029
std::uint32_t uint32
32-bit unsigned integral type.
Definition types.hpp:126
constexpr std::complex< remove_complex< T > > unit_root(int64 n, int64 k=1)
Returns the value of exp(2 * pi * i * k / n), i.e. the k-th power of the principal n-th root of unity.
Definition math.hpp:1112
typename detail::reduce_precision_impl< T >::type reduce_precision
Obtains the next type in the hierarchy with lower precision than T.
Definition math.hpp:479
constexpr reduce_precision< T > round_down(T val)
Reduces the precision of the input parameter.
Definition math.hpp:515
std::int64_t int64
64-bit signed integral type.
Definition types.hpp:109
constexpr int64 ceildiv(int64 num, int64 den)
Performs integer division with rounding up.
Definition math.hpp:613
constexpr bool is_complex()
Checks if T is a complex type.
Definition math.hpp:287
T safe_divide(T a, T b)
Computes the quotient of the given parameters, guarding against division by zero.
Definition math.hpp:1209
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:863
constexpr xstd::enable_if_t<!is_complex_s< T >::value, T > abs(const T &x)
Returns the absolute value of the object.
Definition math.hpp:1076
constexpr auto squared_norm(const T &x) -> decltype(real(conj(x) *x))
Returns the squared norm of the object.
Definition math.hpp:1057
next_precision< T > previous_precision
Obtains the previous type in the singly-linked precision list.
Definition math.hpp:472
detail::is_complex_or_scalar_impl< T > is_complex_or_scalar_s
Allows checking at compile time whether T is a complex or scalar type, by accessing the `value` attribute of this struct.
Definition math.hpp:301
constexpr size_type byte_size
Number of bits in a byte.
Definition types.hpp:177
constexpr T get_superior_power(const T &base, const T &limit, const T &hint=T{1}) noexcept
Returns the smallest power of base not smaller than limit.
Definition math.hpp:1151
typename detail::increase_precision_impl< T >::type increase_precision
Obtains the next type in the hierarchy with higher precision than T.
Definition math.hpp:486
remove_complex< T > to_real
to_real is alias of remove_complex
Definition math.hpp:354
constexpr auto conj(const T &x)
Returns the conjugate of an object.
Definition math.hpp:1043
constexpr bool is_nonzero(T value)
Returns true if and only if the given value is not zero.
Definition math.hpp:827
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:845
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:1013
Access the underlying real type of a complex number.
Definition math.hpp:251
T type
The type.
Definition math.hpp:253
Used to convert objects of type S to objects of type R using static_cast.
Definition math.hpp:591
R operator()(S val)
Converts the object to result type.
Definition math.hpp:598