Ginkgo Generated from branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
solver_base.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13
14#include <ginkgo/core/base/lin_op.hpp>
15#include <ginkgo/core/base/math.hpp>
16#include <ginkgo/core/log/logger.hpp>
17#include <ginkgo/core/matrix/dense.hpp>
18#include <ginkgo/core/matrix/identity.hpp>
19#include <ginkgo/core/solver/workspace.hpp>
20#include <ginkgo/core/stop/combined.hpp>
21#include <ginkgo/core/stop/criterion.hpp>
22
23
24GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
25
26
27namespace gko {
28namespace solver {
29
30
38 zero,
42 rhs,
47};
48
49
// Forward declaration only: MultigridState is named as a friend by the
// solver interfaces in this header, so it only needs to be declared here.
namespace multigrid {
namespace detail {
class MultigridState;
}  // namespace detail
}  // namespace multigrid
59
60
67protected:
68 friend class multigrid::detail::MultigridState;
69
83 virtual void apply_with_initial_guess(const LinOp* b, LinOp* x,
84 initial_guess_mode guess) const = 0;
85
86 void apply_with_initial_guess(ptr_param<const LinOp> b, ptr_param<LinOp> x,
87 initial_guess_mode guess) const
88 {
89 apply_with_initial_guess(b.get(), x.get(), guess);
90 }
91
104 virtual void apply_with_initial_guess(const LinOp* alpha, const LinOp* b,
105 const LinOp* beta, LinOp* x,
106 initial_guess_mode guess) const = 0;
107
108
109 void apply_with_initial_guess(ptr_param<const LinOp> alpha,
113 initial_guess_mode guess) const
114 {
115 apply_with_initial_guess(alpha.get(), b.get(), beta.get(), x.get(),
116 guess);
117 }
118
124 initial_guess_mode get_default_initial_guess() const { return guess_; }
125
132 explicit ApplyWithInitialGuess(
134 : guess_(guess)
135 {}
136
142 void set_default_initial_guess(initial_guess_mode guess) { guess_ = guess; }
143
144private:
145 initial_guess_mode guess_;
146};
147
148
161template <typename DerivedType>
163protected:
164 friend class multigrid::detail::MultigridState;
165
168 : ApplyWithInitialGuess(guess)
169 {}
170
// Validating front end of the simple initial-guess apply: logs the start of
// the apply, checks sizes, forwards to apply_with_initial_guess_impl using
// temporary clones of b and x on this object's executor, then logs
// completion.
175 void apply_with_initial_guess(const LinOp* b, LinOp* x,
176 initial_guess_mode guess) const override
177 {
178 self()->template log<log::Logger::linop_apply_started>(self(), b, x);
179 auto exec = self()->get_executor();
// Size checks (per the macro names): b conformant with this operator, x with
// as many rows as this operator, and b and x with equal column counts.
// NOTE(review): the GKO_ASSERT_* macros may stringify their arguments into
// the error message -- confirm before renaming any argument expression.
180 GKO_ASSERT_CONFORMANT(self(), b);
181 GKO_ASSERT_EQUAL_ROWS(self(), x);
182 GKO_ASSERT_EQUAL_COLS(b, x);
// The temporary clones are destroyed at the end of this full-expression, so
// they are released before the completion event below is logged.
183 this->apply_with_initial_guess_impl(make_temporary_clone(exec, b).get(),
184 make_temporary_clone(exec, x).get(),
185 guess);
186 self()->template log<log::Logger::linop_apply_completed>(self(), b, x);
187 }
188
// Validating front end of the advanced initial-guess apply: logs the start,
// checks sizes (alpha and beta must be 1x1), forwards to
// apply_with_initial_guess_impl using temporary clones of all four operands
// on this object's executor, then logs completion.
193 void apply_with_initial_guess(const LinOp* alpha, const LinOp* b,
194 const LinOp* beta, LinOp* x,
195 initial_guess_mode guess) const override
196 {
197 self()->template log<log::Logger::linop_advanced_apply_started>(
198 self(), alpha, b, beta, x);
199 auto exec = self()->get_executor();
// Same b/x checks as the simple variant, plus alpha and beta must both have
// dimensions 1x1.
// NOTE(review): the GKO_ASSERT_* macros may stringify their arguments into
// the error message -- confirm before renaming any argument expression.
200 GKO_ASSERT_CONFORMANT(self(), b);
201 GKO_ASSERT_EQUAL_ROWS(self(), x);
202 GKO_ASSERT_EQUAL_COLS(b, x);
203 GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1));
204 GKO_ASSERT_EQUAL_DIMENSIONS(beta, dim<2>(1, 1));
// The temporary clones are destroyed at the end of this full-expression,
// i.e. before the completion event below is logged.
205 this->apply_with_initial_guess_impl(
206 make_temporary_clone(exec, alpha).get(),
207 make_temporary_clone(exec, b).get(),
208 make_temporary_clone(exec, beta).get(),
209 make_temporary_clone(exec, x).get(), guess);
210 self()->template log<log::Logger::linop_advanced_apply_completed>(
211 self(), alpha, b, beta, x);
212 }
213
214 // TODO: should we provide the default implementation?
219 virtual void apply_with_initial_guess_impl(
220 const LinOp* b, LinOp* x, initial_guess_mode guess) const = 0;
221
226 virtual void apply_with_initial_guess_impl(
227 const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x,
228 initial_guess_mode guess) const = 0;
229
230 GKO_ENABLE_SELF(DerivedType);
231};
232
233
238template <typename Solver>
240 // number of vectors used by this workspace
241 static int num_vectors(const Solver&) { return 0; }
242 // number of arrays used by this workspace
243 static int num_arrays(const Solver&) { return 0; }
244 // array containing the num_vectors names for the workspace vectors
245 static std::vector<std::string> op_names(const Solver&) { return {}; }
246 // array containing the num_arrays names for the workspace vectors
247 static std::vector<std::string> array_names(const Solver&) { return {}; }
248 // array containing all scalar vectors (independent of problem size)
249 static std::vector<int> scalars(const Solver&) { return {}; }
250 // array containing all vectors (dependent on problem size)
251 static std::vector<int> vectors(const Solver&) { return {}; }
252};
253
254
270template <typename DerivedType>
272public:
279 void set_preconditioner(std::shared_ptr<const LinOp> new_precond) override
280 {
281 auto exec = self()->get_executor();
282 if (new_precond) {
283 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_precond);
284 GKO_ASSERT_IS_SQUARE_MATRIX(new_precond);
285 if (new_precond->get_executor() != exec) {
286 new_precond = gko::clone(exec, new_precond);
287 }
288 }
290 }
291
297 {
298 if (&other != this) {
300 }
301 return *this;
302 }
303
310 {
311 if (&other != this) {
312 set_preconditioner(other.get_preconditioner());
313 other.set_preconditioner(nullptr);
314 }
315 return *this;
316 }
317
318 EnablePreconditionable() = default;
319
320 EnablePreconditionable(std::shared_ptr<const LinOp> preconditioner)
321 {
322 set_preconditioner(std::move(preconditioner));
323 }
324
329 {
330 *this = other;
331 }
332
338 {
339 *this = std::move(other);
340 }
341
342private:
343 DerivedType* self() { return static_cast<DerivedType*>(this); }
344
345 const DerivedType* self() const
346 {
347 return static_cast<const DerivedType*>(this);
348 }
349};
350
351
352namespace detail {
353
354
363class SolverBaseLinOp {
364public:
365 SolverBaseLinOp(std::shared_ptr<const Executor> exec)
366 : workspace_{std::move(exec)}
367 {}
368
369 virtual ~SolverBaseLinOp() = default;
370
376 std::shared_ptr<const LinOp> get_system_matrix() const
377 {
378 return system_matrix_;
379 }
380
381 const LinOp* get_workspace_op(int vector_id) const
382 {
383 return workspace_.get_op(vector_id);
384 }
385
386 virtual int get_num_workspace_ops() const { return 0; }
387
388 virtual std::vector<std::string> get_workspace_op_names() const
389 {
390 return {};
391 }
392
397 virtual std::vector<int> get_workspace_scalars() const { return {}; }
398
403 virtual std::vector<int> get_workspace_vectors() const { return {}; }
404
405protected:
406 void set_system_matrix_base(std::shared_ptr<const LinOp> system_matrix)
407 {
408 system_matrix_ = std::move(system_matrix);
409 }
410
411 void set_workspace_size(int num_operators, int num_arrays) const
412 {
413 workspace_.set_size(num_operators, num_arrays);
414 }
415
416 template <typename LinOpType>
417 LinOpType* create_workspace_op(int vector_id, gko::dim<2> size) const
418 {
419 return workspace_.template create_or_get_op<LinOpType>(
420 vector_id,
421 [&] {
422 return LinOpType::create(this->workspace_.get_executor(), size);
423 },
424 typeid(LinOpType), size, size[1]);
425 }
426
427 template <typename LinOpType>
428 LinOpType* create_workspace_op_with_config_of(int vector_id,
429 const LinOpType* vec) const
430 {
431 return workspace_.template create_or_get_op<LinOpType>(
432 vector_id, [&] { return LinOpType::create_with_config_of(vec); },
433 typeid(*vec), vec->get_size(), vec->get_stride());
434 }
435
436 template <typename LinOpType>
437 LinOpType* create_workspace_op_with_type_of(int vector_id,
438 const LinOpType* vec,
439 dim<2> size) const
440 {
441 return workspace_.template create_or_get_op<LinOpType>(
442 vector_id,
443 [&] {
444 return LinOpType::create_with_type_of(
445 vec, workspace_.get_executor(), size, size[1]);
446 },
447 typeid(*vec), size, size[1]);
448 }
449
450 template <typename LinOpType>
451 LinOpType* create_workspace_op_with_type_of(int vector_id,
452 const LinOpType* vec,
453 dim<2> global_size,
454 dim<2> local_size) const
455 {
456 return workspace_.template create_or_get_op<LinOpType>(
457 vector_id,
458 [&] {
459 return LinOpType::create_with_type_of(
460 vec, workspace_.get_executor(), global_size, local_size,
461 local_size[1]);
462 },
463 typeid(*vec), global_size, local_size[1]);
464 }
465
466 template <typename ValueType>
467 matrix::Dense<ValueType>* create_workspace_scalar(int vector_id,
468 size_type size) const
469 {
470 return workspace_.template create_or_get_op<matrix::Dense<ValueType>>(
471 vector_id,
472 [&] {
474 workspace_.get_executor(), dim<2>{1, size});
475 },
476 typeid(matrix::Dense<ValueType>), gko::dim<2>{1, size}, size);
477 }
478
479 template <typename ValueType>
480 array<ValueType>& create_workspace_array(int array_id, size_type size) const
481 {
482 return workspace_.template create_or_get_array<ValueType>(array_id,
483 size);
484 }
485
486 template <typename ValueType>
487 array<ValueType>& create_workspace_array(int array_id) const
488 {
489 return workspace_.template init_or_get_array<ValueType>(array_id);
490 }
491
492private:
// Mutable so that the const helpers (set_workspace_size, create_workspace_*)
// can call into the workspace.
493 mutable detail::workspace workspace_;
494
// Set through set_system_matrix_base(); default-constructed (null) until then.
495 std::shared_ptr<const LinOp> system_matrix_;
496};
497
498
499} // namespace detail
500
501
502template <typename MatrixType>
503class
504 // clang-format off
505 GKO_DEPRECATED("This class will be replaced by the template-less detail::SolverBaseLinOp in a future release") SolverBase
506 // clang-format on
507 : public detail::SolverBaseLinOp {
508public:
509 using detail::SolverBaseLinOp::SolverBaseLinOp;
510
518 std::shared_ptr<const MatrixType> get_system_matrix() const
519 {
520 return std::dynamic_pointer_cast<const MatrixType>(
521 SolverBaseLinOp::get_system_matrix());
522 }
523
524protected:
525 void set_system_matrix_base(std::shared_ptr<const MatrixType> system_matrix)
526 {
527 SolverBaseLinOp::set_system_matrix_base(std::move(system_matrix));
528 }
529};
530
531
541template <typename DerivedType, typename MatrixType = LinOp>
542class EnableSolverBase : public SolverBase<MatrixType> {
543public:
549 {
550 if (&other != this) {
551 set_system_matrix(other.get_system_matrix());
552 }
553 return *this;
554 }
555
561 {
562 if (&other != this) {
563 set_system_matrix(other.get_system_matrix());
564 other.set_system_matrix(nullptr);
565 }
566 return *this;
567 }
568
569 EnableSolverBase() : SolverBase<MatrixType>{self()->get_executor()} {}
570
571 EnableSolverBase(std::shared_ptr<const MatrixType> system_matrix)
572 : SolverBase<MatrixType>{self()->get_executor()}
573 {
574 set_system_matrix(std::move(system_matrix));
575 }
576
581 : SolverBase<MatrixType>{other.self()->get_executor()}
582 {
583 *this = other;
584 }
585
591 : SolverBase<MatrixType>{other.self()->get_executor()}
592 {
593 *this = std::move(other);
594 }
595
596 int get_num_workspace_ops() const override
597 {
598 using traits = workspace_traits<DerivedType>;
599 return traits::num_vectors(*self());
600 }
601
602 std::vector<std::string> get_workspace_op_names() const override
603 {
604 using traits = workspace_traits<DerivedType>;
605 return traits::op_names(*self());
606 }
607
612 std::vector<int> get_workspace_scalars() const override
613 {
614 using traits = workspace_traits<DerivedType>;
615 return traits::scalars(*self());
616 }
617
622 std::vector<int> get_workspace_vectors() const override
623 {
624 using traits = workspace_traits<DerivedType>;
625 return traits::vectors(*self());
626 }
627
628protected:
629 void set_system_matrix(std::shared_ptr<const MatrixType> new_system_matrix)
630 {
631 auto exec = self()->get_executor();
632 if (new_system_matrix) {
633 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_system_matrix);
634 GKO_ASSERT_IS_SQUARE_MATRIX(new_system_matrix);
635 if (new_system_matrix->get_executor() != exec) {
636 new_system_matrix = gko::clone(exec, new_system_matrix);
637 }
638 }
639 this->set_system_matrix_base(new_system_matrix);
640 }
641
642 void setup_workspace() const
643 {
644 using traits = workspace_traits<DerivedType>;
645 this->set_workspace_size(traits::num_vectors(*self()),
646 traits::num_arrays(*self()));
647 }
648
649private:
650 DerivedType* self() { return static_cast<DerivedType*>(this); }
651
652 const DerivedType* self() const
653 {
654 return static_cast<const DerivedType*>(this);
655 }
656};
657
658
666public:
672 std::shared_ptr<const stop::CriterionFactory> get_stop_criterion_factory()
673 const
674 {
675 return stop_factory_;
676 }
677
684 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
685 {
686 stop_factory_ = new_stop_factory;
687 }
688
689private:
// Stopping criterion factory returned by get_stop_criterion_factory().
690 std::shared_ptr<const stop::CriterionFactory> stop_factory_;
691};
692
693
703template <typename DerivedType>
705public:
711 {
712 if (&other != this) {
714 }
715 return *this;
716 }
717
724 {
725 if (&other != this) {
726 set_stop_criterion_factory(other.get_stop_criterion_factory());
727 other.set_stop_criterion_factory(nullptr);
728 }
729 return *this;
730 }
731
// Default-constructs the mixin; no stopping criterion factory is set here.
732 EnableIterativeBase() = default;
733
735 std::shared_ptr<const stop::CriterionFactory> stop_factory)
736 {
737 set_stop_criterion_factory(std::move(stop_factory));
738 }
739
743 EnableIterativeBase(const EnableIterativeBase& other) { *this = other; }
744
750 {
751 *this = std::move(other);
752 }
753
755 std::shared_ptr<const stop::CriterionFactory> new_stop_factory) override
756 {
757 auto exec = self()->get_executor();
758 if (new_stop_factory && new_stop_factory->get_executor() != exec) {
759 new_stop_factory = gko::clone(exec, new_stop_factory);
760 }
762 }
763
764private:
765 DerivedType* self() { return static_cast<DerivedType*>(this); }
766
767 const DerivedType* self() const
768 {
769 return static_cast<const DerivedType*>(this);
770 }
771};
772
773
784template <typename ValueType, typename DerivedType>
786 : public EnableSolverBase<DerivedType>,
787 public EnableIterativeBase<DerivedType>,
788 public EnablePreconditionable<DerivedType> {
789public:
791
793 std::shared_ptr<const LinOp> system_matrix,
794 std::shared_ptr<const stop::CriterionFactory> stop_factory,
795 std::shared_ptr<const LinOp> preconditioner)
796 : EnableSolverBase<DerivedType>(std::move(system_matrix)),
797 EnableIterativeBase<DerivedType>{std::move(stop_factory)},
798 EnablePreconditionable<DerivedType>{std::move(preconditioner)}
799 {}
800
801 template <typename FactoryParameters>
803 std::shared_ptr<const LinOp> system_matrix,
804 const FactoryParameters& params)
806 system_matrix, stop::combine(params.criteria),
807 generate_preconditioner(system_matrix, params)}
808 {}
809
810private:
811 template <typename FactoryParameters>
812 static std::shared_ptr<const LinOp> generate_preconditioner(
813 std::shared_ptr<const LinOp> system_matrix,
814 const FactoryParameters& params)
815 {
816 if (params.generated_preconditioner) {
817 return params.generated_preconditioner;
818 } else if (params.preconditioner) {
819 return params.preconditioner->generate(system_matrix);
820 } else {
822 system_matrix->get_executor(), system_matrix->get_size());
823 }
824 }
825};
826
827
828template <typename Parameters, typename Factory>
830 : enable_parameters_type<Parameters, Factory> {
834 std::vector<std::shared_ptr<const stop::CriterionFactory>>
835 GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(criteria);
836};
837
838
839template <typename Parameters, typename Factory>
841 : enable_iterative_solver_factory_parameters<Parameters, Factory> {
846 std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
848
853 std::shared_ptr<const LinOp> GKO_FACTORY_PARAMETER_SCALAR(
855};
856
857
858} // namespace solver
859} // namespace gko
860
861
862GKO_END_DISABLE_DEPRECATION_WARNINGS
863
864
865#endif // GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
Definition lin_op.hpp:118
A LinOp implementing this interface can be preconditioned.
Definition lin_op.hpp:683
virtual void set_preconditioner(std::shared_ptr< const LinOp > new_precond)
Sets the preconditioner operator used by the Preconditionable.
Definition lin_op.hpp:703
virtual std::shared_ptr< const LinOp > get_preconditioner() const
Returns the preconditioner operator used by the Preconditionable.
Definition lin_op.hpp:692
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition abstract_factory.hpp:211
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
Creates an uninitialized Dense matrix of the specified size.
static std::unique_ptr< Identity > create(std::shared_ptr< const Executor > exec, dim< 2 > size)
Creates an Identity matrix of the specified size.
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:43
ApplyWithInitialGuess provides a way to give the input guess for the apply function.
Definition solver_base.hpp:66
EnableApplyWithInitialGuess provides the default operation for ApplyWithInitialGuess with correct validation and logging.
Definition solver_base.hpp:162
A LinOp deriving from this CRTP class stores a stopping criterion factory and allows applying with a ...
Definition solver_base.hpp:704
EnableIterativeBase & operator=(EnableIterativeBase &&other)
Moves the provided stopping criterion and clones it onto this executor if the executors don't match.
Definition solver_base.hpp:723
void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory) override
Sets the stopping criterion of the solver.
Definition solver_base.hpp:754
EnableIterativeBase(EnableIterativeBase &&other)
Moves the provided stopping criterion.
Definition solver_base.hpp:749
EnableIterativeBase(const EnableIterativeBase &other)
Creates a shallow copy of the provided stopping criterion.
Definition solver_base.hpp:743
EnableIterativeBase & operator=(const EnableIterativeBase &other)
Creates a shallow copy of the provided stopping criterion and clones it onto this executor if the executors don't match.
Definition solver_base.hpp:710
Mixin providing default operation for Preconditionable with correct value semantics.
Definition solver_base.hpp:271
EnablePreconditionable(const EnablePreconditionable &other)
Creates a shallow copy of the provided preconditioner.
Definition solver_base.hpp:328
EnablePreconditionable & operator=(EnablePreconditionable &&other)
Moves the provided preconditioner and clones it onto this executor if the executors don't match.
Definition solver_base.hpp:309
EnablePreconditionable(EnablePreconditionable &&other)
Moves the provided preconditioner.
Definition solver_base.hpp:337
EnablePreconditionable & operator=(const EnablePreconditionable &other)
Creates a shallow copy of the provided preconditioner and clones it onto this executor if the executors don't match.
Definition solver_base.hpp:296
void set_preconditioner(std::shared_ptr< const LinOp > new_precond) override
Sets the preconditioner operator used by the Preconditionable.
Definition solver_base.hpp:279
A LinOp implementing this interface stores a system matrix and stopping criterion factory.
Definition solver_base.hpp:788
A LinOp deriving from this CRTP class stores a system matrix.
Definition solver_base.hpp:542
EnableSolverBase(EnableSolverBase &&other)
Moves the provided system matrix.
Definition solver_base.hpp:590
std::vector< int > get_workspace_vectors() const override
Returns the IDs of all vectors (workspace vectors with system dimension-dependent size,...
Definition solver_base.hpp:622
std::vector< int > get_workspace_scalars() const override
Returns the IDs of all scalars (workspace vectors with system dimension-independent size,...
Definition solver_base.hpp:612
EnableSolverBase(const EnableSolverBase &other)
Creates a shallow copy of the provided system matrix.
Definition solver_base.hpp:580
EnableSolverBase & operator=(EnableSolverBase &&other)
Moves the provided system matrix and clones it onto this executor if the executors don't match.
Definition solver_base.hpp:560
EnableSolverBase & operator=(const EnableSolverBase &other)
Creates a shallow copy of the provided system matrix and clones it onto this executor if the executors don't match.
Definition solver_base.hpp:548
A LinOp implementing this interface stores a stopping criterion factory.
Definition solver_base.hpp:665
std::shared_ptr< const stop::CriterionFactory > get_stop_criterion_factory() const
Gets the stopping criterion factory of the solver.
Definition solver_base.hpp:672
virtual void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory)
Sets the stopping criterion of the solver.
Definition solver_base.hpp:683
Definition solver_base.hpp:507
std::shared_ptr< const MatrixType > get_system_matrix() const
Returns the system matrix, with its concrete type, used by the solver.
Definition solver_base.hpp:518
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:445
std::shared_ptr< const CriterionFactory > combine(FactoryContainer &&factories)
Combines multiple criterion factories into a single combined criterion factory.
Definition combined.hpp:110
initial_guess_mode
Gives an initial guess mode for the input of the apply method.
Definition solver_base.hpp:34
@ provided
the input is provided
@ rhs
the input is right hand side
The Ginkgo namespace.
Definition abstract_factory.hpp:20
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:86
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:175
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition temporary_clone.hpp:209
@ array
The matrix should be written as a dense matrix in column-major order.
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:27
std::vector< std::shared_ptr< const stop::CriterionFactory > > criteria
Stopping criteria to be used by the solver.
Definition solver_base.hpp:835
std::shared_ptr< const LinOp > generated_preconditioner
Already generated preconditioner.
Definition solver_base.hpp:854
std::shared_ptr< const LinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition solver_base.hpp:847
Traits class providing information on the type and location of workspace vectors inside a solver.
Definition solver_base.hpp:239