Ginkgo Generated from branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
ic.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
6#define GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
7
8
9#include <memory>
10#include <type_traits>
11
12
13#include <ginkgo/core/base/abstract_factory.hpp>
14#include <ginkgo/core/base/composition.hpp>
15#include <ginkgo/core/base/exception.hpp>
16#include <ginkgo/core/base/exception_helpers.hpp>
17#include <ginkgo/core/base/lin_op.hpp>
18#include <ginkgo/core/base/precision_dispatch.hpp>
19#include <ginkgo/core/base/std_extensions.hpp>
20#include <ginkgo/core/config/config.hpp>
21#include <ginkgo/core/config/registry.hpp>
22#include <ginkgo/core/factorization/par_ic.hpp>
23#include <ginkgo/core/matrix/dense.hpp>
24#include <ginkgo/core/preconditioner/isai.hpp>
25#include <ginkgo/core/preconditioner/utils.hpp>
26#include <ginkgo/core/solver/gmres.hpp>
27#include <ginkgo/core/solver/ir.hpp>
28#include <ginkgo/core/solver/solver_traits.hpp>
29#include <ginkgo/core/solver/triangular.hpp>
30#include <ginkgo/core/stop/combined.hpp>
31#include <ginkgo/core/stop/iteration.hpp>
32#include <ginkgo/core/stop/residual_norm.hpp>
33
34
35namespace gko {
36namespace preconditioner {
37namespace detail {
38
39
40template <typename Type>
41constexpr bool support_ic_parse =
42 is_instantiation_of<Type, solver::LowerTrs>::value ||
43 is_instantiation_of<Type, solver::Ir>::value ||
44 is_instantiation_of<Type, solver::Gmres>::value ||
45 is_instantiation_of<Type, preconditioner::LowerIsai>::value;
46
47
48template <
49 typename Ic,
50 std::enable_if_t<!support_ic_parse<typename Ic::l_solver_type>>* = nullptr>
51typename Ic::parameters_type ic_parse(
52 const config::pnode& config, const config::registry& context,
53 const config::type_descriptor& td_for_child)
54{
55 GKO_INVALID_STATE(
56 "preconditioner::Ic only supports limited type for parse.");
57}
58
59template <
60 typename Ic,
61 std::enable_if_t<support_ic_parse<typename Ic::l_solver_type>>* = nullptr>
62typename Ic::parameters_type ic_parse(
63 const config::pnode& config, const config::registry& context,
64 const config::type_descriptor& td_for_child);
65
66
67} // namespace detail
68
114template <typename LSolverType = solver::LowerTrs<>, typename IndexType = int32>
115class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {
116 friend class EnableLinOp<Ic>;
117 friend class EnablePolymorphicObject<Ic, LinOp>;
118
119public:
120 static_assert(
121 std::is_same<typename LSolverType::transposed_type::transposed_type,
122 LSolverType>::value,
123 "LSolverType::transposed_type must be symmetric");
124 using value_type = typename LSolverType::value_type;
125 using l_solver_type = LSolverType;
126 using lh_solver_type = typename LSolverType::transposed_type;
127 using index_type = IndexType;
129
130 class Factory;
131
133 : public enable_parameters_type<parameters_type, Factory> {
137 std::shared_ptr<const typename l_solver_type::Factory>
139
143 std::shared_ptr<const LinOpFactory> factorization_factory{};
144
145 GKO_DEPRECATED("use with_l_solver instead")
146 parameters_type& with_l_solver_factory(
147 deferred_factory_parameter<const typename l_solver_type::Factory>
148 solver)
149 {
150 return with_l_solver(std::move(solver));
151 }
152
153 parameters_type& with_l_solver(
155 solver)
156 {
157 this->l_solver_generator = std::move(solver);
158 this->deferred_factories["l_solver"] = [](const auto& exec,
159 auto& params) {
160 if (!params.l_solver_generator.is_empty()) {
161 params.l_solver_factory =
162 params.l_solver_generator.on(exec);
163 }
164 };
165 return *this;
166 }
167
168 GKO_DEPRECATED("use with_factorization instead")
169 parameters_type& with_factorization_factory(
170 deferred_factory_parameter<const LinOpFactory> factorization)
171 {
172 return with_factorization(std::move(factorization));
173 }
174
175 parameters_type& with_factorization(
176 deferred_factory_parameter<const LinOpFactory> factorization)
177 {
178 this->factorization_generator = std::move(factorization);
179 this->deferred_factories["factorization"] = [](const auto& exec,
180 auto& params) {
181 if (!params.factorization_generator.is_empty()) {
182 params.factorization_factory =
183 params.factorization_generator.on(exec);
184 }
185 };
186 return *this;
187 }
188
189 private:
190 deferred_factory_parameter<const typename l_solver_type::Factory>
191 l_solver_generator;
192
193 deferred_factory_parameter<const LinOpFactory> factorization_generator;
194 };
195
198
216 const config::pnode& config, const config::registry& context,
217 const config::type_descriptor& td_for_child =
218 config::make_type_descriptor<value_type, index_type>())
219 {
220 return detail::ic_parse<Ic>(config, context, td_for_child);
221 }
222
228 std::shared_ptr<const l_solver_type> get_l_solver() const
229 {
230 return l_solver_;
231 }
232
238 std::shared_ptr<const lh_solver_type> get_lh_solver() const
239 {
240 return lh_solver_;
241 }
242
243 std::unique_ptr<LinOp> transpose() const override
244 {
245 std::unique_ptr<transposed_type> transposed{
246 new transposed_type{this->get_executor()}};
247 transposed->set_size(gko::transpose(this->get_size()));
248 transposed->l_solver_ =
250 this->get_lh_solver()->transpose()));
251 transposed->lh_solver_ =
253 this->get_l_solver()->transpose()));
254
255 return std::move(transposed);
256 }
257
258 std::unique_ptr<LinOp> conj_transpose() const override
259 {
260 std::unique_ptr<transposed_type> transposed{
261 new transposed_type{this->get_executor()}};
262 transposed->set_size(gko::transpose(this->get_size()));
263 transposed->l_solver_ =
265 this->get_lh_solver()->conj_transpose()));
266 transposed->lh_solver_ =
268 this->get_l_solver()->conj_transpose()));
269
270 return std::move(transposed);
271 }
272
278 Ic& operator=(const Ic& other)
279 {
280 if (&other != this) {
282 auto exec = this->get_executor();
283 l_solver_ = other.l_solver_;
284 lh_solver_ = other.lh_solver_;
285 parameters_ = other.parameters_;
286 if (other.get_executor() != exec) {
287 l_solver_ = gko::clone(exec, l_solver_);
288 lh_solver_ = gko::clone(exec, lh_solver_);
289 }
290 }
291 return *this;
292 }
293
300 Ic& operator=(Ic&& other)
301 {
302 if (&other != this) {
304 auto exec = this->get_executor();
305 l_solver_ = std::move(other.l_solver_);
306 lh_solver_ = std::move(other.lh_solver_);
307 parameters_ = std::exchange(other.parameters_, parameters_type{});
308 if (other.get_executor() != exec) {
309 l_solver_ = gko::clone(exec, l_solver_);
310 lh_solver_ = gko::clone(exec, lh_solver_);
311 }
312 }
313 return *this;
314 }
315
320 Ic(const Ic& other) : Ic{other.get_executor()} { *this = other; }
321
327 Ic(Ic&& other) : Ic{other.get_executor()} { *this = std::move(other); }
328
329protected:
330 void apply_impl(const LinOp* b, LinOp* x) const override
331 {
332 // take care of real-to-complex apply
334 [&](auto dense_b, auto dense_x) {
335 this->set_cache_to(dense_b);
336 l_solver_->apply(dense_b, cache_.intermediate);
337 if (lh_solver_->apply_uses_initial_guess()) {
338 dense_x->copy_from(cache_.intermediate);
339 }
340 lh_solver_->apply(cache_.intermediate, dense_x);
341 },
342 b, x);
343 }
344
345 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
346 LinOp* x) const override
347 {
349 [&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
350 this->set_cache_to(dense_b);
351 l_solver_->apply(dense_b, cache_.intermediate);
352 lh_solver_->apply(dense_alpha, cache_.intermediate, dense_beta,
353 dense_x);
354 },
355 alpha, b, beta, x);
356 }
357
358 explicit Ic(std::shared_ptr<const Executor> exec)
359 : EnableLinOp<Ic>(std::move(exec))
360 {}
361
362 explicit Ic(const Factory* factory, std::shared_ptr<const LinOp> lin_op)
363 : EnableLinOp<Ic>(factory->get_executor(), lin_op->get_size()),
364 parameters_{factory->get_parameters()}
365 {
366 auto comp =
367 std::dynamic_pointer_cast<const Composition<value_type>>(lin_op);
368 std::shared_ptr<const LinOp> l_factor;
369
370 // build factorization if we weren't passed a composition
371 if (!comp) {
372 auto exec = lin_op->get_executor();
373 if (!parameters_.factorization_factory) {
374 parameters_.factorization_factory =
375 factorization::ParIc<value_type, index_type>::build()
376 .with_both_factors(false)
377 .on(exec);
378 }
379 auto fact = std::shared_ptr<const LinOp>(
380 parameters_.factorization_factory->generate(lin_op));
381 // ensure that the result is a composition
382 comp =
383 std::dynamic_pointer_cast<const Composition<value_type>>(fact);
384 if (!comp) {
385 GKO_NOT_SUPPORTED(comp);
386 }
387 }
388 // comp must contain one or two factors
389 if (comp->get_operators().size() > 2 || comp->get_operators().empty()) {
390 GKO_NOT_SUPPORTED(comp);
391 }
392 l_factor = comp->get_operators()[0];
393 GKO_ASSERT_IS_SQUARE_MATRIX(l_factor);
394
395 auto exec = this->get_executor();
396
397 // If no factories are provided, generate default ones
398 if (!parameters_.l_solver_factory) {
399 l_solver_ = generate_default_solver<l_solver_type>(exec, l_factor);
400 // If comp contains both factors: use the transposed factor to avoid
401 // transposing twice
402 if (comp->get_operators().size() == 2) {
403 auto lh_factor = comp->get_operators()[1];
404 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, lh_factor);
405 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
406 } else {
407 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
408 }
409 } else {
410 l_solver_ = parameters_.l_solver_factory->generate(l_factor);
411 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
412 }
413 }
414
422 void set_cache_to(const LinOp* b) const
423 {
424 if (cache_.intermediate == nullptr) {
425 cache_.intermediate =
427 }
428 // Use b as the initial guess for the first triangular solve
429 cache_.intermediate->copy_from(b);
430 }
431
432
440 template <typename SolverType>
441 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
442 std::unique_ptr<SolverType>>
443 generate_default_solver(const std::shared_ptr<const Executor>& exec,
444 const std::shared_ptr<const LinOp>& mtx)
445 {
446 constexpr gko::remove_complex<value_type> default_reduce_residual{1e-4};
447 const unsigned int default_max_iters{
448 static_cast<unsigned int>(mtx->get_size()[0])};
449
450 return SolverType::build()
451 .with_criteria(
452 gko::stop::Iteration::build().with_max_iters(default_max_iters),
454 .with_reduction_factor(default_reduce_residual))
455 .on(exec)
456 ->generate(mtx);
457 }
458
462 template <typename SolverType>
463 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
464 std::unique_ptr<SolverType>>
465 generate_default_solver(const std::shared_ptr<const Executor>& exec,
466 const std::shared_ptr<const LinOp>& mtx)
467 {
468 return SolverType::build().on(exec)->generate(mtx);
469 }
470
471private:
472 std::shared_ptr<const l_solver_type> l_solver_{};
473 std::shared_ptr<const lh_solver_type> lh_solver_{};
484 mutable struct cache_struct {
485 cache_struct() = default;
486 ~cache_struct() = default;
487 cache_struct(const cache_struct&) {}
488 cache_struct(cache_struct&&) {}
489 cache_struct& operator=(const cache_struct&) { return *this; }
490 cache_struct& operator=(cache_struct&&) { return *this; }
491 std::unique_ptr<LinOp> intermediate{};
492 } cache_;
493};
494
495
496} // namespace preconditioner
497} // namespace gko
498
499
500#endif // GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
The EnableLinOp mixin can be used to provide sensible default implementations of the majority of the ...
Definition lin_op.hpp:880
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:663
Definition lin_op.hpp:118
LinOp(const LinOp &)=default
Copy-constructs a LinOp.
const dim< 2 > & get_size() const noexcept
Returns the size of the operator.
Definition lin_op.hpp:211
LinOp & operator=(const LinOp &)=default
Copy-assigns a LinOp.
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:235
Linear operators which support transposition should implement the Transposable interface.
Definition lin_op.hpp:434
pnode describes a tree of properties.
Definition property_tree.hpp:28
This class stores additional context for creating Ginkgo objects from configuration files.
Definition registry.hpp:168
This class describes the value and index types to be used when building a Ginkgo type from a configur...
Definition type_descriptor.hpp:37
Represents a factory parameter of factory type that can either be initialized by a pre-existing factory ...
Definition abstract_factory.hpp:309
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition abstract_factory.hpp:211
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
Creates an uninitialized Dense matrix of the specified size.
The Incomplete Cholesky (IC) preconditioner solves the equation $L L^H x = b$ for a given lower triangular matrix ...
Definition ic.hpp:115
std::shared_ptr< const lh_solver_type > get_lh_solver() const
Returns the solver which is used for the L^H matrix.
Definition ic.hpp:238
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
Definition ic.hpp:243
Ic(const Ic &other)
Copy-constructs an IC preconditioner.
Definition ic.hpp:320
Ic & operator=(Ic &&other)
Move-assigns an IC preconditioner.
Definition ic.hpp:300
Ic(Ic &&other)
Move-constructs an IC preconditioner.
Definition ic.hpp:327
static parameters_type parse(const config::pnode &config, const config::registry &context, const config::type_descriptor &td_for_child=config::make_type_descriptor< value_type, index_type >())
Creates the parameters from the property tree.
Definition ic.hpp:215
std::shared_ptr< const l_solver_type > get_l_solver() const
Returns the solver which is used for the provided L matrix.
Definition ic.hpp:228
Ic & operator=(const Ic &other)
Copy-assigns an IC preconditioner.
Definition ic.hpp:278
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
Definition ic.hpp:258
The ResidualNorm class is a stopping criterion which stops the iteration process when the actual resi...
Definition residual_norm.hpp:110
#define GKO_ENABLE_BUILD_METHOD(_factory_name)
Defines a build method for the factory, simplifying its construction by removing the repetitive typin...
Definition abstract_factory.hpp:394
#define GKO_ENABLE_LIN_OP_FACTORY(_lin_op, _parameters_name, _factory_name)
This macro will generate a default implementation of a LinOpFactory for the LinOp subclass it is defi...
Definition lin_op.hpp:1018
@ factory
LinOpFactory events.
The Ginkgo namespace.
Definition abstract_factory.hpp:20
typename detail::remove_complex_s< T >::type remove_complex
Obtains the type with the complex part removed from a complex/scalar type, or from the template parameter of class b...
Definition math.hpp:326
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to matrix::Dense<ValueType>* as ...
Definition precision_dispatch.hpp:94
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:175
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:120
std::decay_t< T > * as(U *obj)
Performs polymorphic type conversion.
Definition utils_helper.hpp:309
detail::shared_type< OwningPointer > share(OwningPointer &&p)
Marks the object pointed to by p as shared.
Definition utils_helper.hpp:226
std::shared_ptr< const typename l_solver_type::Factory > l_solver_factory
Factory for the L solver.
Definition ic.hpp:138
std::shared_ptr< const LinOpFactory > factorization_factory
Factory for the factorization.
Definition ic.hpp:143