Ginkgo: generated from a branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
batch_lin_op.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
6#define GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13
14#include <ginkgo/core/base/abstract_factory.hpp>
15#include <ginkgo/core/base/batch_multi_vector.hpp>
16#include <ginkgo/core/base/dim.hpp>
17#include <ginkgo/core/base/exception_helpers.hpp>
18#include <ginkgo/core/base/math.hpp>
19#include <ginkgo/core/base/matrix_assembly_data.hpp>
20#include <ginkgo/core/base/matrix_data.hpp>
21#include <ginkgo/core/base/polymorphic_object.hpp>
22#include <ginkgo/core/base/types.hpp>
23#include <ginkgo/core/base/utils.hpp>
24#include <ginkgo/core/log/logger.hpp>
25
26
27namespace gko {
28namespace batch {
29
30
60class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
61public:
68 {
70 }
71
78
84 const batch_dim<2>& get_size() const noexcept { return size_; }
85
91 template <typename ValueType>
94 {
95 GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
96 GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
97
98 GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
99 GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
100 GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
101 }
102
108 template <typename ValueType>
110 const MultiVector<ValueType>* b,
111 const MultiVector<ValueType>* beta,
112 MultiVector<ValueType>* x) const
113 {
114 GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
115 GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
116
117 GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
118 GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
119 GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
120 GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(),
121 gko::dim<2>(1, 1));
122 GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1));
123 }
124
125protected:
131 void set_size(const batch_dim<2>& size) { size_ = size; }
132
139 explicit BatchLinOp(std::shared_ptr<const Executor> exec,
140 const batch_dim<2>& batch_size)
141 : EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
142 {}
143
152 explicit BatchLinOp(std::shared_ptr<const Executor> exec,
153 const size_type num_batch_items = 0,
154 const dim<2>& common_size = dim<2>{})
155 : BatchLinOp{std::move(exec),
156 num_batch_items > 0
157 ? batch_dim<2>(num_batch_items, common_size)
158 : batch_dim<2>{}}
159 {}
160
161private:
162 batch_dim<2> size_{};
163};
164
165
196 : public AbstractFactory<BatchLinOp, std::shared_ptr<const BatchLinOp>> {
197public:
199 std::shared_ptr<const BatchLinOp>>::AbstractFactory;
200
201 std::unique_ptr<BatchLinOp> generate(
202 std::shared_ptr<const BatchLinOp> input) const
203 {
204 this->template log<
205 gko::log::Logger::batch_linop_factory_generate_started>(
206 this, input.get());
207 const auto exec = this->get_executor();
208 std::unique_ptr<BatchLinOp> generated;
209 if (input->get_executor() == exec) {
210 generated = this->AbstractFactory::generate(input);
211 } else {
212 generated =
213 this->AbstractFactory::generate(gko::clone(exec, input));
214 }
215 this->template log<
216 gko::log::Logger::batch_linop_factory_generate_completed>(
217 this, input.get(), generated.get());
218 return generated;
219 }
220};
221
222
250template <typename ConcreteBatchLinOp, typename PolymorphicBase = BatchLinOp>
252 : public EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase>,
253 public EnablePolymorphicAssignment<ConcreteBatchLinOp> {
254public:
255 using EnablePolymorphicObject<ConcreteBatchLinOp,
256 PolymorphicBase>::EnablePolymorphicObject;
257};
258
259
276template <typename ConcreteFactory, typename ConcreteBatchLinOp,
277 typename ParametersType, typename PolymorphicBase = BatchLinOpFactory>
279 EnableDefaultFactory<ConcreteFactory, ConcreteBatchLinOp, ParametersType,
280 PolymorphicBase>;
281
282
// GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name,
//                                 _factory_name)
//
// Meant to be expanded inside the body of the concrete batch operator class
// `_batch_lin_op`. The expansion:
//  - adds a public accessor `get_<_parameters_name>()` that returns a const
//    reference to the stored parameters;
//  - declares a nested factory class `_factory_name` that derives from
//    ::gko::batch::EnableDefaultBatchLinOpFactory. Its two constructors
//    (executor only; executor + parameters) are private to the nested class
//    and are made reachable by befriending the matching
//    ::gko::EnablePolymorphicObject and ::gko::enable_parameters_type
//    specializations;
//  - befriends that EnableDefaultBatchLinOpFactory specialization in the
//    operator class;
//  - adds a private data member `<_parameters_name>_`;
//  - leaves the enclosing class in `public:` access after the expansion.
// The type `<_parameters_name>_type` is not declared by this macro. It is
// expected to be declared in the enclosing class before the expansion.
// The trailing static_assert(true, ...) only exists so that call sites can end
// the invocation with a semicolon without extra-semicolon warnings.
359#define GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name, \
360 _factory_name) \
361public: \
362 const _parameters_name##_type& get_##_parameters_name() const \
363 { \
364 return _parameters_name##_; \
365 } \
366 \
367 class _factory_name \
368 : public ::gko::batch::EnableDefaultBatchLinOpFactory< \
369 _factory_name, _batch_lin_op, _parameters_name##_type> { \
370 friend class ::gko::EnablePolymorphicObject< \
371 _factory_name, ::gko::batch::BatchLinOpFactory>; \
372 friend class ::gko::enable_parameters_type<_parameters_name##_type, \
373 _factory_name>; \
374 explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec) \
375 : ::gko::batch::EnableDefaultBatchLinOpFactory< \
376 _factory_name, _batch_lin_op, _parameters_name##_type>( \
377 std::move(exec)) \
378 {} \
379 explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec, \
380 const _parameters_name##_type& parameters) \
381 : ::gko::batch::EnableDefaultBatchLinOpFactory< \
382 _factory_name, _batch_lin_op, _parameters_name##_type>( \
383 std::move(exec), parameters) \
384 {} \
385 }; \
386 friend ::gko::batch::EnableDefaultBatchLinOpFactory< \
387 _factory_name, _batch_lin_op, _parameters_name##_type>; \
388 \
389 \
390private: \
391 _parameters_name##_type _parameters_name##_; \
392 \
393public: \
394 static_assert(true, \
395 "This assert is used to counter the false positive extra " \
396 "semi-colon warnings")
397
398
399} // namespace batch
400} // namespace gko
401
402
403#endif // GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
The AbstractFactory is a generic interface template that enables easy implementation of the abstract ...
Definition abstract_factory.hpp:47
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:346
This mixin provides a default implementation of a concrete factory.
Definition abstract_factory.hpp:126
This mixin is used to enable a default PolymorphicObject::copy_from() implementation for objects that...
Definition polymorphic_object.hpp:724
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:663
A BatchLinOpFactory represents a higher order mapping which transforms one batch linear operator into...
Definition batch_lin_op.hpp:196
Definition batch_lin_op.hpp:60
const batch_dim< 2 > & get_size() const noexcept
Returns the size of the batch operator.
Definition batch_lin_op.hpp:84
void validate_application_parameters(const MultiVector< ValueType > *b, MultiVector< ValueType > *x) const
Validates the sizes for the apply(b, x) operation in the concrete BatchLinOp.
Definition batch_lin_op.hpp:92
void validate_application_parameters(const MultiVector< ValueType > *alpha, const MultiVector< ValueType > *b, const MultiVector< ValueType > *beta, MultiVector< ValueType > *x) const
Validates the sizes for the apply(alpha, b, beta, x) operation in the concrete BatchLinOp.
Definition batch_lin_op.hpp:109
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_lin_op.hpp:77
size_type get_num_batch_items() const noexcept
Returns the number of items in the batch operator.
Definition batch_lin_op.hpp:67
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition batch_lin_op.hpp:253
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition logger.hpp:41
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:127
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:117
The Ginkgo namespace.
Definition abstract_factory.hpp:20
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:86
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:175
A type representing the dimensions of a multidimensional batch object.
Definition batch_dim.hpp:28
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition batch_dim.hpp:44
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition batch_dim.hpp:37
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:27