5 #ifndef GKO_PUBLIC_CORE_BASE_ABSTRACT_FACTORY_HPP_
6 #define GKO_PUBLIC_CORE_BASE_ABSTRACT_FACTORY_HPP_
10 #include <unordered_map>
12 #include <ginkgo/core/base/polymorphic_object.hpp>
/**
 * Abstract factory: a polymorphic object whose generate() creates products of
 * type AbstractProductType from arguments packed into a ComponentsType.
 *
 * NOTE(review): lines of the class head are missing from this excerpt; the
 * `AbstractFactory<AbstractProductType, ComponentsType>>` fragment below is
 * the tail of a base-class template argument list that names this class
 * itself.
 *
 * @tparam AbstractProductType  type of the objects produced by generate()
 * @tparam ComponentsType  argument bundle handed to generate_impl()
 */
44 template <
typename AbstractProductType,
typename ComponentsType>
47 AbstractFactory<AbstractProductType, ComponentsType>> {
// Public aliases for the two template parameters.
49 using abstract_product_type = AbstractProductType;
50 using components_type = ComponentsType;
/**
 * Creates a product.
 *
 * All arguments are forwarded into one brace-initialized components_type,
 * which is passed to generate_impl(). Afterwards every logger stored in
 * this->loggers_ is attached to the product via add_logger().
 *
 * NOTE(review): the declaration of the local `product` and the return
 * statement are on lines not visible in this excerpt; presumably `product`
 * holds the generate_impl() result and is returned -- confirm.
 *
 * @param args  values used to brace-initialize a components_type
 * @return the newly created product, owned by the caller (unique_ptr)
 */
66 template <
typename... Args>
67 std::unique_ptr<abstract_product_type>
generate(Args&&... args)
const
70 this->generate_impl(components_type{std::forward<Args>(args)...});
71 for (
auto logger : this->loggers_) {
72 product->add_logger(logger);
/**
 * Hook implemented by concrete factories: builds the product from the
 * argument bundle, which is received by value.
 */
94 virtual std::unique_ptr<abstract_product_type> generate_impl(
95 ComponentsType args)
const = 0;
/**
 * Default implementation of a concrete factory.
 *
 * Stores a copy of a ParametersType in parameters_ and implements
 * generate_impl() by constructing `new product_type(self(), args)`. It also
 * provides a generate() that returns std::unique_ptr<product_type> instead of
 * a pointer to the abstract product type.
 *
 * NOTE(review): the class head (class name and base list) is not visible in
 * this excerpt. PolymorphicBase is expected to provide abstract_product_type,
 * components_type and generate(), as used below.
 *
 * @tparam ConcreteFactory  the derived factory type (passed to GKO_ENABLE_SELF
 *                          and EnablePolymorphicObject)
 * @tparam ProductType  concrete type of the products created
 * @tparam ParametersType  type of the stored parameters
 * @tparam PolymorphicBase  abstract factory base this class implements
 */
122 template <
typename ConcreteFactory,
typename ProductType,
123 typename ParametersType,
typename PolymorphicBase>
// Aliases for the template parameters and for the base's associated types.
130 using product_type = ProductType;
131 using parameters_type = ParametersType;
132 using polymorphic_base = PolymorphicBase;
133 using abstract_product_type =
134 typename PolymorphicBase::abstract_product_type;
135 using components_type =
typename PolymorphicBase::components_type;
/**
 * Creates a product and returns it with its concrete type.
 *
 * Forwards to polymorphic_base::generate() and converts the result with an
 * unchecked static_cast to product_type*. This relies on generate_impl()
 * below always creating a product_type.
 *
 * NOTE(review): the rest of this expression (taking ownership from the base
 * result) and the return statement are on lines not shown here.
 */
137 template <
typename... Args>
138 std::unique_ptr<product_type> generate(Args&&... args)
const
140 auto product = std::unique_ptr<product_type>(static_cast<product_type*>(
141 this->polymorphic_base::generate(std::forward<Args>(args)...)
// Returns a value-initialized parameters_type (`return {};`).
168 static parameters_type
create() {
return {}; }
/**
 * Constructor fragment: the leading part of the signature is not visible
 * here. `parameters` defaults to a value-initialized parameters_type and is
 * copied into parameters_. The EnablePolymorphicObject base is initialized
 * with arguments on lines not shown.
 */
178 const parameters_type& parameters = {})
179 : EnablePolymorphicObject<ConcreteFactory, PolymorphicBase>(
181 parameters_{parameters}
// Builds the product as `new product_type(self(), args)`: the product's
// constructor receives this factory (via self()) and the argument bundle.
// Ownership is returned as a pointer to the abstract product type.
184 std::unique_ptr<abstract_product_type> generate_impl(
185 components_type args)
const override
187 return std::unique_ptr<abstract_product_type>(
188 new product_type(
self(), args));
// Presumably provides the self() accessor used above (macro defined
// elsewhere) -- confirm.
192 GKO_ENABLE_SELF(ConcreteFactory);
// Copy of the parameters passed to the constructor.
194 ParametersType parameters_;
/**
 * Base for factory parameter structs.
 *
 * ConcreteParametersType is the deriving parameters struct itself, as in
 * GKO_CREATE_FACTORY_PARAMETERS. Factory is the factory type that on()
 * creates from these parameters.
 *
 * NOTE(review): the class head and several member signatures are not
 * visible in this excerpt.
 */
210 template <
typename ConcreteParametersType,
typename Factory>
// Factory type produced by on().
213 using factory = Factory;
// Fragment of a setter whose signature is not visible here: it replaces
// `loggers` with the forwarded arguments.
219 template <
typename... Args>
222 this->loggers = {std::forward<Args>(_value)...};
/**
 * Creates a Factory on `exec` from a copy of these parameters.
 *
 * Each callback stored in deferred_factories is first run on the copy. This
 * lets callbacks fill in factory members for `exec` without modifying *this,
 * which is const here. The factory is then constructed as
 * `new Factory(exec, copy)`, and each entry of `loggers` is attached to it.
 *
 * std::unordered_map iteration order is unspecified, so the callbacks run in
 * no particular order.
 *
 * NOTE(review): the return statement is on a line not visible here.
 */
233 std::unique_ptr<Factory>
on(std::shared_ptr<const Executor> exec)
const
235 ConcreteParametersType copy = *
self();
236 for (
const auto& item : deferred_factories) {
237 item.second(exec, copy);
239 auto factory = std::unique_ptr<Factory>(
new Factory(exec, copy));
240 for (
auto& logger : loggers) {
241 factory->add_logger(logger);
// Presumably provides the self() accessor used above (macro defined
// elsewhere) -- confirm.
247 GKO_ENABLE_SELF(ConcreteParametersType);
// Loggers attached to every factory created by on(); starts empty.
252 std::vector<std::shared_ptr<const log::Logger>> loggers{};
// Callbacks run by on() on the parameter copy, keyed by string. The member
// name is on a line not visible here; on() refers to it as
// deferred_factories. The GKO_DEFERRED_FACTORY_*PARAMETER macros key their
// entries by the stringized parameter name.
260 std::unordered_map<std::string,
261 std::function<void(std::shared_ptr<const Executor> exec,
262 ConcreteParametersType&)>>
// Declares the parameters struct for a factory. The visible part expands to a
// forward declaration of class `_factory_name`, followed by
// `struct _parameters_name##_type`. That struct derives publicly from
// ::gko::enable_parameters_type with itself as the first template argument.
// NOTE(review): the rest of this macro (remaining base-class arguments and the
// struct body) is on lines not visible in this excerpt.
280 #define GKO_CREATE_FACTORY_PARAMETERS(_parameters_name, _factory_name) \
282 class _factory_name; \
283 struct _parameters_name##_type \
284 : public ::gko::enable_parameters_type<_parameters_name##_type, \
/**
 * Type trait whose `value` is true iff a `From*` is implicitly convertible to
 * a `To*` (for example, a derived-to-base pointer conversion).
 *
 * Referred to as detail::is_pointer_convertible when constraining the
 * deferred_factory_parameter constructors, so that they only accept factories
 * usable as the target FactoryType.
 */
293 template <
typename From,
typename To>
294 struct is_pointer_convertible : std::is_convertible<From*, To*> {};
/**
 * Holds a recipe for obtaining a FactoryType on a given executor. The recipe
 * either returns an already existing factory or calls on(exec) on a stored
 * parameters object each time it runs.
 *
 * The recipe is stored as a std::function (generator_). An object without a
 * stored callable is "empty" (see is_empty()).
 *
 * NOTE(review): the class head and the constructor signatures are partially
 * missing from this excerpt; parameter names below are taken from the visible
 * bodies.
 */
308 template <
typename FactoryType>
// Constructor fragment (signature not visible here): stores a generator that
// returns nullptr for every executor.
317 generator_ = [](std::shared_ptr<const Executor>) {
return nullptr; };
// Enabled only when ConcreteFactoryType* converts to FactoryType*.
// Converts `factory` to std::shared_ptr<FactoryType> once and returns that
// same factory for every executor; the executor argument is ignored.
// NOTE(review): the signature line is not visible; presumably `factory` is a
// std::shared_ptr<ConcreteFactoryType> -- confirm.
324 template <
typename ConcreteFactoryType,
325 std::enable_if_t<detail::is_pointer_convertible<
326 ConcreteFactoryType, FactoryType>::value>* =
nullptr>
329 generator_ = [factory =
330 std::shared_ptr<FactoryType>(std::move(factory))](
331 std::shared_ptr<const Executor>) {
return factory; };
// Takes ownership of a std::unique_ptr<ConcreteFactoryType, Deleter>
// (enabled under the same convertibility condition). The shared_ptr built
// from it keeps using Deleter. The same factory is returned for every
// executor.
338 template <
typename ConcreteFactoryType,
typename Deleter,
339 std::enable_if_t<detail::is_pointer_convertible<
340 ConcreteFactoryType, FactoryType>::value>* =
nullptr>
342 std::unique_ptr<ConcreteFactoryType, Deleter> factory)
344 generator_ = [factory =
345 std::shared_ptr<FactoryType>(std::move(factory))](
346 std::shared_ptr<const Executor>) {
return factory; };
// Enabled only when ParametersType has an on(std::shared_ptr<const Executor>)
// member whose result's element_type converts (by pointer) to FactoryType.
// The generator keeps a copy of `parameters` and calls parameters.on(exec)
// every time it runs, so each call to on() builds a new factory for the
// requested executor.
354 template <
typename ParametersType,
355 typename U = decltype(std::declval<ParametersType>().
on(
356 std::shared_ptr<const Executor>{})),
357 std::enable_if_t<detail::is_pointer_convertible<
358 typename U::element_type, FactoryType>::value>* =
nullptr>
361 generator_ = [parameters](std::shared_ptr<const Executor> exec)
362 -> std::shared_ptr<FactoryType> {
return parameters.on(exec); };
// Returns the factory for `exec` by invoking the stored generator.
// NOTE(review): the condition guarding GKO_NOT_SUPPORTED is on a line not
// visible here -- presumably is_empty(); confirm.
369 std::shared_ptr<FactoryType>
on(std::shared_ptr<const Executor> exec)
const
372 GKO_NOT_SUPPORTED(*
this);
374 return generator_(exec);
// True iff no generator callable is stored.
378 bool is_empty()
const {
return !bool(generator_); }
// Type of the stored generator. The member name (generator_, as used above)
// is on a line not visible here.
381 std::function<std::shared_ptr<FactoryType>(std::shared_ptr<const Executor>)>
// Defines `static auto build()` returning _factory_name::create().
// The expansion deliberately ends with a static_assert that has no trailing
// semicolon. The macro is therefore invoked as a statement ending in `;`,
// which avoids extra-semicolon warnings (as the assert message says).
394 #define GKO_ENABLE_BUILD_METHOD(_factory_name) \
395 static auto build()->decltype(_factory_name::create()) \
397 return _factory_name::create(); \
399 static_assert(true, \
400 "This assert is used to counter the false positive extra " \
401 "semi-colon warnings")
// Definitions used when neither __CUDACC__ nor __HIPCC__ is defined.
404 #if !(defined(__CUDACC__) || defined(__HIPCC__))
// Declares a factory parameter and its fluent setter.
// The expansion starts with `_name{__VA_ARGS__};`. The member's type must
// therefore be written right before the macro invocation, and __VA_ARGS__ is
// its default initializer. Next comes `with_<_name>(...)`, which replaces the
// member with `decltype(_name){forwarded args...}` and returns a reference to
// *this->self(), allowing chained calls.
417 #define GKO_FACTORY_PARAMETER(_name, ...) \
418 _name{__VA_ARGS__}; \
420 template <typename... Args> \
421 auto with_##_name(Args&&... _value) \
422 ->std::decay_t<decltype(*(this->self()))>& \
424 using type = decltype(this->_name); \
425 this->_name = type{std::forward<Args>(_value)...}; \
426 return *(this->self()); \
428 static_assert(true, \
429 "This assert is used to counter the false positive extra " \
430 "semi-colon warnings")
// Scalar parameter: identical to GKO_FACTORY_PARAMETER in this branch.
445 #define GKO_FACTORY_PARAMETER_SCALAR(_name, _default) \
446 GKO_FACTORY_PARAMETER(_name, _default)
// Vector parameter: identical to GKO_FACTORY_PARAMETER in this branch.
461 #define GKO_FACTORY_PARAMETER_VECTOR(_name, ...) \
462 GKO_FACTORY_PARAMETER(_name, __VA_ARGS__)
463 #else // defined(__CUDACC__) || defined(__HIPCC__)
// Definitions used when __CUDACC__ or __HIPCC__ is defined.
// NOTE(review): this excerpt does not say why these compilers get different
// definitions; confirm before changing.
// Generic parameter: the member declaration is the same as in the other
// branch, but the variadic setter only contains GKO_NOT_IMPLEMENTED before
// returning *this->self().
468 #define GKO_FACTORY_PARAMETER(_name, ...) \
469 _name{__VA_ARGS__}; \
471 template <typename... Args> \
472 auto with_##_name(Args&&... _value) \
473 ->std::decay_t<decltype(*(this->self()))>& \
475 GKO_NOT_IMPLEMENTED; \
476 return *(this->self()); \
478 static_assert(true, \
479 "This assert is used to counter the false positive extra " \
480 "semi-colon warnings")
// Scalar parameter: the setter takes exactly one argument, assigns
// `decltype(_name){std::forward<Arg>(_value)}` and returns *this->self().
// NOTE(review): part of this macro's expansion is not visible in this
// excerpt; presumably it declares the member `_name` from `_default`.
482 #define GKO_FACTORY_PARAMETER_SCALAR(_name, _default) \
485 template <typename Arg> \
486 auto with_##_name(Arg&& _value)->std::decay_t<decltype(*(this->self()))>& \
488 using type = decltype(this->_name); \
489 this->_name = type{std::forward<Arg>(_value)}; \
490 return *(this->self()); \
492 static_assert(true, \
493 "This assert is used to counter the false positive extra " \
494 "semi-colon warnings")
// Vector parameter: member declaration plus a variadic setter that replaces
// the member with `decltype(_name){forwarded args...}`. This is the same
// expansion GKO_FACTORY_PARAMETER has in the other branch.
496 #define GKO_FACTORY_PARAMETER_VECTOR(_name, ...) \
497 _name{__VA_ARGS__}; \
499 template <typename... Args> \
500 auto with_##_name(Args&&... _value) \
501 ->std::decay_t<decltype(*(this->self()))>& \
503 using type = decltype(this->_name); \
504 this->_name = type{std::forward<Args>(_value)...}; \
505 return *(this->self()); \
507 static_assert(true, \
508 "This assert is used to counter the false positive extra " \
509 "semi-colon warnings")
510 #endif // defined(__CUDACC__) || defined(__HIPCC__)
// Declares a factory parameter whose value can be supplied as a
// ::gko::deferred_factory_parameter and is only resolved later.
// Visible parts of the expansion (some lines are not shown in this excerpt):
//  - `_name##_type`: element_type of decltype(_name). The member `_name` is
//    therefore expected to be pointer-like (e.g. a std::shared_ptr); its
//    declaration is not visible here.
//  - `with_<_name>(factory)`: stores `factory` in `_name##_generator_` and
//    registers deferred_factories["<_name>"]. That callback, if the generator
//    is non-empty, sets `params._name = params._name##_generator_.on(exec)`.
//    Calling the setter again overwrites the same map entry. The setter
//    returns *this->self().
//  - `_name##_generator_`: the stored deferred_factory_parameter.
// Like the other macros, the expansion ends in a static_assert without a
// semicolon, so the invocation supplies the `;`.
521 #define GKO_DEFERRED_FACTORY_PARAMETER(_name) \
525 using _name##_type = typename decltype(_name)::element_type; \
528 auto with_##_name(::gko::deferred_factory_parameter<_name##_type> factory) \
529 ->std::decay_t<decltype(*(this->self()))>& \
531 this->_name##_generator_ = std::move(factory); \
532 this->deferred_factories[#_name] = [](const auto& exec, \
534 if (!params._name##_generator_.is_empty()) { \
535 params._name = params._name##_generator_.on(exec); \
538 return *(this->self()); \
542 ::gko::deferred_factory_parameter<_name##_type> _name##_generator_; \
545 static_assert(true, \
546 "This assert is used to counter the false positive extra " \
547 "semi-colon warnings")
// Vector counterpart of GKO_DEFERRED_FACTORY_PARAMETER. The member `_name`
// is a container of pointer-like elements: `_name##_type` is
// decltype(_name)::value_type::element_type, and the callbacks use clear()
// and push_back().
// Visible parts of the expansion (some lines are not shown in this excerpt):
//  - variadic `with_<_name>(factories...)`: enabled only when every argument
//    converts to deferred_factory_parameter<_name##_type>. It replaces
//    `_name##_generator_` with one generator per argument, in order.
//  - `with_<_name>(const std::vector<FactoryType>&)`: enabled when
//    FactoryType converts to that type. It rebuilds `_name##_generator_`
//    from the vector's elements.
//  - Both setters register deferred_factories["<_name>"]. If the generator
//    list is non-empty, that callback clears params._name and appends
//    generator.on(exec) for each generator in order. If the list is empty
//    (e.g. the variadic setter was called with no arguments), params._name
//    is left unchanged. Both setters return *this->self().
//  - `_name##_generator_`: a std::vector of deferred_factory_parameter.
521 #define GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(_name) \
623 #endif // GKO_PUBLIC_CORE_BASE_ABSTRACT_FACTORY_HPP_