5 #ifndef GKO_PUBLIC_CORE_LOG_LOGGER_HPP_
6 #define GKO_PUBLIC_CORE_LOG_LOGGER_HPP_
12 #include <type_traits>
15 #include <ginkgo/core/base/types.hpp>
16 #include <ginkgo/core/base/utils_helper.hpp>
22 template <
typename ValueType>
27 class PolymorphicObject;
29 class stopping_status;
36 class BatchLinOpFactory;
38 template <
typename ValueType>
110 #define GKO_LOGGER_REGISTER_EVENT(_id, _event_name, ...) \
112 virtual void on_##_event_name(__VA_ARGS__) const {} \
115 template <size_type Event, typename... Params> \
116 std::enable_if_t<Event == _id && (_id < event_count_max)> on( \
117 Params&&... params) const \
119 if (enabled_events_ & (mask_type{1} << _id)) { \
120 this->on_##_event_name(std::forward<Params>(params)...); \
123 static constexpr size_type _event_name{_id}; \
124 static constexpr mask_type _event_name##_mask{mask_type{1} << _id};
132 GKO_LOGGER_REGISTER_EVENT(0, allocation_started,
const Executor* exec,
142 GKO_LOGGER_REGISTER_EVENT(1, allocation_completed,
const Executor* exec,
152 GKO_LOGGER_REGISTER_EVENT(2, free_started,
const Executor* exec,
161 GKO_LOGGER_REGISTER_EVENT(3, free_completed,
const Executor* exec,
173 GKO_LOGGER_REGISTER_EVENT(4, copy_started,
const Executor* exec_from,
186 GKO_LOGGER_REGISTER_EVENT(5, copy_completed,
const Executor* exec_from,
196 GKO_LOGGER_REGISTER_EVENT(6, operation_launched,
const Executor* exec,
210 GKO_LOGGER_REGISTER_EVENT(7, operation_completed,
const Executor* exec,
219 GKO_LOGGER_REGISTER_EVENT(8, polymorphic_object_create_started,
229 GKO_LOGGER_REGISTER_EVENT(9, polymorphic_object_create_completed,
241 GKO_LOGGER_REGISTER_EVENT(10, polymorphic_object_copy_started,
253 GKO_LOGGER_REGISTER_EVENT(11, polymorphic_object_copy_completed,
264 GKO_LOGGER_REGISTER_EVENT(12, polymorphic_object_deleted,
274 GKO_LOGGER_REGISTER_EVENT(13, linop_apply_started,
const LinOp* A,
284 GKO_LOGGER_REGISTER_EVENT(14, linop_apply_completed,
const LinOp* A,
296 GKO_LOGGER_REGISTER_EVENT(15, linop_advanced_apply_started,
const LinOp* A,
309 GKO_LOGGER_REGISTER_EVENT(16, linop_advanced_apply_completed,
320 GKO_LOGGER_REGISTER_EVENT(17, linop_factory_generate_started,
331 GKO_LOGGER_REGISTER_EVENT(18, linop_factory_generate_completed,
346 GKO_LOGGER_REGISTER_EVENT(19, criterion_check_started,
350 const uint8& stopping_id,
351 const bool& set_finalized)
373 GKO_LOGGER_REGISTER_EVENT(
376 const uint8& stopping_id,
const bool& set_finalized,
378 const bool& all_converged)
397 virtual void on_criterion_check_completed(
400 const uint8& stopping_id,
const bool& set_finalized,
402 const bool& all_converged)
const
404 this->on_criterion_check_completed(
criterion, it, r, tau, x,
405 stopping_id, set_finalized, status,
406 one_changed, all_converged);
// Event id 21 (iteration_complete) and its bit within enabled_events_.
// These are declared by hand rather than through GKO_LOGGER_REGISTER_EVENT.
// The class provides several on_iteration_complete overloads, plus a
// hand-written on<Event>() dispatcher that tests bit 21 before forwarding.
410 static constexpr
size_type iteration_complete{21};
411 static constexpr mask_type iteration_complete_mask{mask_type{1} << 21};
413 template <
size_type Event,
typename... Params>
415 Params&&... params)
const
417 if (enabled_events_ & (mask_type{1} << 21)) {
418 this->on_iteration_complete(std::forward<Params>(params)...);
436 "Please use the version with the additional stopping "
438 virtual
void on_iteration_complete(const LinOp*
solver, const
size_type& it,
439 const LinOp* r, const LinOp* x =
nullptr,
440 const LinOp* tau =
nullptr)
const
457 "Please use the version with the additional stopping "
459 virtual
void on_iteration_complete(const LinOp*
solver, const
size_type& it,
460 const LinOp* r, const LinOp* x,
462 const LinOp* implicit_tau_sq)
const
464 GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
465 this->on_iteration_complete(
solver, it, r, x, tau);
466 GKO_END_DISABLE_DEPRECATION_WARNINGS
484 virtual void on_iteration_complete(
const LinOp*
solver,
const LinOp* b,
486 const LinOp* r,
const LinOp* tau,
487 const LinOp* implicit_tau_sq,
488 const array<stopping_status>* status,
491 GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
492 this->on_iteration_complete(
solver, it, r, x, tau, implicit_tau_sq);
493 GKO_END_DISABLE_DEPRECATION_WARNINGS
// Event 22. Expands to:
// - a no-op virtual hook on_polymorphic_object_move_started(exec, input, output);
// - an on<22>() dispatcher that forwards only when bit 22 of enabled_events_ is set;
// - the constants polymorphic_object_move_started / polymorphic_object_move_started_mask.
// The mask is part of the polymorphic-object event group mask.
504 GKO_LOGGER_REGISTER_EVENT(22, polymorphic_object_move_started,
505 const Executor* exec,
506 const PolymorphicObject* input,
507 const PolymorphicObject* output)
// Event 23. Expands to:
// - a no-op virtual hook on_polymorphic_object_move_completed(exec, input, output);
// - an on<23>() dispatcher gated on bit 23 of enabled_events_;
// - the constants polymorphic_object_move_completed / polymorphic_object_move_completed_mask.
// The mask is part of the polymorphic-object event group mask.
516 GKO_LOGGER_REGISTER_EVENT(23, polymorphic_object_move_completed,
517 const Executor* exec,
518 const PolymorphicObject* input,
519 const PolymorphicObject* output)
// Event 24. Expands to:
// - a no-op virtual hook on_batch_linop_factory_generate_started(factory, input);
// - an on<24>() dispatcher gated on bit 24 of enabled_events_;
// - the constants batch_linop_factory_generate_started /
//   batch_linop_factory_generate_started_mask.
// The mask is part of the batch LinOp-factory event group mask.
528 GKO_LOGGER_REGISTER_EVENT(24, batch_linop_factory_generate_started,
529 const batch::BatchLinOpFactory*
factory,
530 const batch::BatchLinOp* input)
// Event 25. Expands to:
// - a no-op virtual hook on_batch_linop_factory_generate_completed(factory, input, output);
// - an on<25>() dispatcher gated on bit 25 of enabled_events_;
// - the constants batch_linop_factory_generate_completed /
//   batch_linop_factory_generate_completed_mask.
// The mask is part of the batch LinOp-factory event group mask.
540 GKO_LOGGER_REGISTER_EVENT(25, batch_linop_factory_generate_completed,
541 const batch::BatchLinOpFactory*
factory,
542 const batch::BatchLinOp* input,
543 const batch::BatchLinOp* output)
// Event id 26 (batch_solver_completed) and its bit within enabled_events_.
// These are declared by hand rather than through GKO_LOGGER_REGISTER_EVENT.
// on_batch_solver_completed is overloaded on the residual-norm array type:
// double, float, and, when enabled at build time, float16 / bfloat16.
// A hand-written on<Event>() dispatcher tests this mask before forwarding.
546 static constexpr
size_type batch_solver_completed{26};
547 static constexpr mask_type batch_solver_completed_mask{mask_type{1} << 26};
549 template <
size_type Event,
typename... Params>
551 Params&&... params)
const
553 if (enabled_events_ & batch_solver_completed_mask) {
554 this->on_batch_solver_completed(std::forward<Params>(params)...);
566 virtual void on_batch_solver_completed(
567 const array<int>& iters,
const array<double>& residual_norms)
const
577 virtual void on_batch_solver_completed(
578 const array<int>& iters,
const array<float>& residual_norms)
const
582 #if GINKGO_ENABLE_HALF
592 virtual void on_batch_solver_completed(
593 const array<int>& iters,
594 const array<gko::float16>& residual_norms)
const
601 #if GINKGO_ENABLE_BFLOAT16
611 virtual void on_batch_solver_completed(
612 const array<int>& iters,
613 const array<gko::bfloat16>& residual_norms)
const
621 #undef GKO_LOGGER_REGISTER_EVENT
627 allocation_started_mask | allocation_completed_mask |
628 free_started_mask | free_completed_mask | copy_started_mask |
635 operation_launched_mask | operation_completed_mask;
641 polymorphic_object_create_started_mask |
642 polymorphic_object_create_completed_mask |
643 polymorphic_object_copy_started_mask |
644 polymorphic_object_copy_completed_mask |
645 polymorphic_object_move_started_mask |
646 polymorphic_object_move_completed_mask |
647 polymorphic_object_deleted_mask;
653 linop_apply_started_mask | linop_apply_completed_mask |
654 linop_advanced_apply_started_mask | linop_advanced_apply_completed_mask;
660 linop_factory_generate_started_mask |
661 linop_factory_generate_completed_mask;
667 batch_linop_factory_generate_started_mask |
668 batch_linop_factory_generate_completed_mask;
674 criterion_check_started_mask | criterion_check_completed_mask;
// Defaulted virtual destructor. Logger is a polymorphic base whose on_* hooks
// are overridden by concrete loggers.
682 virtual ~
Logger() =
default;
699 GKO_DEPRECATED(
"use single-parameter constructor")
720 : enabled_events_{enabled_events}
// Bit set of enabled event ids. Each on<Event>() dispatcher forwards to its
// on_* hook only when bit `Event` is set here.
724 mask_type enabled_events_;
// Attaches `logger` to this object so it receives this object's events.
// Pure virtual: concrete loggable classes provide the storage.
742 virtual void add_logger(std::shared_ptr<const Logger> logger) = 0;
765 virtual const std::vector<std::shared_ptr<const Logger>>&
get_loggers()
785 template <
typename ConcreteLoggable,
typename PolymorphicBase = Loggable>
788 void add_logger(std::shared_ptr<const Logger> logger)
override
790 loggers_.push_back(logger);
793 void remove_logger(
const Logger* logger)
override
796 find_if(begin(loggers_), end(loggers_),
797 [&logger](
const auto& l) {
return l.get() == logger; });
798 if (idx != end(loggers_)) {
808 remove_logger(logger.
get());
811 const std::vector<std::shared_ptr<const Logger>>& get_loggers()
// Detaches every logger stored directly on this object (empties loggers_).
// Loggers registered on the executor, which log() may still reach through
// propagation, are not affected.
817 void clear_loggers()
override { loggers_.clear(); }
827 template <
size_type Event,
typename ConcreteLoggableT,
typename =
void>
828 struct propagate_log_helper {
829 template <
typename... Args>
830 static void propagate_log(
const ConcreteLoggableT*, Args&&...)
834 template <
size_type Event,
typename ConcreteLoggableT>
835 struct propagate_log_helper<
836 Event, ConcreteLoggableT,
838 decltype(std::declval<ConcreteLoggableT>().get_executor())>> {
839 template <
typename... Args>
840 static void propagate_log(
const ConcreteLoggableT* loggable,
843 const auto exec = loggable->get_executor();
844 if (exec->should_propagate_log()) {
845 for (
auto& logger : exec->get_loggers()) {
846 if (logger->needs_propagation()) {
847 logger->template on<Event>(std::forward<Args>(args)...);
855 template <
size_type Event,
typename... Params>
856 void log(Params&&... params)
const
858 propagate_log_helper<Event, ConcreteLoggable>::propagate_log(
859 static_cast<const ConcreteLoggable*>(
this),
860 std::forward<Params>(params)...);
861 for (
auto& logger : loggers_) {
862 logger->template on<Event>(std::forward<Params>(params)...);
// Loggers attached to this object, kept in insertion order. log() calls
// on<Event>() on each of them after the executor-propagation step.
866 std::vector<std::shared_ptr<const Logger>> loggers_;
874 #endif // GKO_PUBLIC_CORE_LOG_LOGGER_HPP_