5 #ifndef GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
6 #define GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
12 #include <ginkgo/core/matrix/dense.hpp>
// Stores a freshly constructed array<ValueType> in this object and hands back
// a reference to it.
//
// A new concrete_container<ValueType> is built from (exec, size) and moved
// into data_, which destroys whatever container was held before.
//   exec: executor forwarded to the array's constructor.
//   size: size forwarded to the array's constructor.
// NOTE(review): this excerpt is missing the opening brace, the return
// statement and the closing brace of this definition; presumably it ends with
// `return arr;` -- confirm against the full file.
25 template <
typename ValueType>
26 array<ValueType>& init(std::shared_ptr<const Executor> exec,
size_type size)
28 auto container = std::make_unique<concrete_container<ValueType>>(
29 std::move(exec), size);
// Take the reference before `container` is moved from: afterwards the local
// unique_ptr is null, but the array itself lives on inside data_.
30 auto& arr = container->arr;
31 data_ = std::move(container);
35 bool empty()
const {
return data_.get() ==
nullptr; }
// Checks whether the stored container holds an array of exactly ValueType.
// dynamic_cast yields nullptr both when data_ is empty and when it holds a
// container for a different value type; that pointer is what gets returned
// (converted to the function's return type).
// NOTE(review): the function signature line is not part of this excerpt; it
// is invoked as `contains<ValueType>()` elsewhere, presumably returning bool.
37 template <
typename ValueType>
40 return dynamic_cast<const concrete_container<ValueType>*
>(data_.get());
// Returns the stored array. It must have been created as array<ValueType>.
// NOTE(review): the dynamic_cast result is dereferenced without a check. The
// GKO_ASSERT is the only guard against a type mismatch, on which the cast
// returns nullptr. If GKO_ASSERT is compiled out in some build
// configurations, a mismatch is undefined behavior -- confirm.
// NOTE(review): the braces of this definition are not part of this excerpt.
43 template <
typename ValueType>
44 array<ValueType>& get()
46 GKO_ASSERT(this->
template contains<ValueType>());
47 return dynamic_cast<concrete_container<ValueType>*
>(data_.get())->arr;
// Const overload of get(). It returns the stored array read-only and has the
// same precondition: the stored value type must be ValueType, which is
// asserted.
// NOTE(review): this excerpt ends mid-statement. The member access that
// completes the return expression and the braces are not visible here.
50 template <
typename ValueType>
51 const array<ValueType>& get()
const
53 GKO_ASSERT(this->
template contains<ValueType>());
54 return dynamic_cast<const concrete_container<ValueType>*
>(data_.get())
58 void clear() { data_.reset(); }
// Polymorphic base for the type-erased storage. The virtual destructor lets
// data_ (a std::unique_ptr<generic_container>) destroy the derived
// concrete_container correctly. It also makes the type polymorphic, so
// dynamic_cast can recover the concrete container.
// NOTE(review): the closing `};` is not part of this excerpt.
61 struct generic_container {
62 virtual ~generic_container() =
default;
// Holds an array<ValueType>. The variadic constructor forwards all of its
// arguments into the brace-initialization of the `arr` member; init() passes
// an executor and a size.
// NOTE(review): the constructor body, the declaration of `arr` and the
// closing `};` are not part of this excerpt.
// NOTE(review): the forwarding constructor is not `explicit`, so a
// single-argument call can act as an implicit conversion. Confirm that this
// is intended.
65 template <
typename ValueType>
66 struct concrete_container : generic_container {
67 template <
typename... Args>
68 concrete_container(Args&&... args) : arr{std::forward<Args>(args)...}
// Type-erased owner of the (at most one) stored array; null when empty().
74 std::unique_ptr<generic_container> data_;
80 workspace(std::shared_ptr<const Executor> exec) : exec_{std::move(exec)} {}
82 workspace(
const workspace& other) : workspace{other.get_executor()} {}
// Move construction takes over only `other`'s executor, via the delegating
// constructor. The initializer transfers no operators or arrays.
// NOTE(review): the constructor body is not part of this excerpt (it may
// clear `other`). The constructor is not noexcept, so std::vector<workspace>
// would fall back to the copy constructor on reallocation. Confirm whether
// that matters.
84 workspace(workspace&& other) : workspace{other.get_executor()}
89 workspace& operator=(
const workspace& other) {
return *
this; }
// Move assignment; its body is not part of this excerpt.
// NOTE(review): the copy assignment in this class leaves the target
// untouched. Confirm what this overload does with `other`.
91 workspace& operator=(workspace&& other)
// Returns the operator stored in slot op_id. It is (re)created with
// `create()` when the slot is empty, when the stored object's dynamic type is
// not exactly `expected_type`, or when its size/stride differ from the
// requested ones. A newly created operator replaces, and thereby destroys,
// the previous one.
//
// NOTE(review): several lines of this definition are missing from the
// excerpt: the remaining parameters (`size`, `stride`), the declaration of
// `op`, the uses of `new_op`, the return statements and the braces.
// NOTE(review): `op_id < operators_.size()` mixes int and size_t. It is
// correct here because `op_id >= 0` is checked first, but it may trigger
// -Wsign-compare warnings.
97 template <
typename LinOpType,
typename CreateOperation>
98 LinOpType* create_or_get_op(
int op_id, CreateOperation create,
99 const std::type_info& expected_type,
102 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
105 auto stored_op = operators_[op_id].get();
// `!stored_op` must be tested first: typeid(*stored_op) on a null
// polymorphic pointer would throw std::bad_typeid.
107 if (!stored_op ||
typeid(*stored_op) != expected_type) {
108 auto new_op = create();
110 operators_[op_id] = std::move(new_op);
// NOTE(review): the cast result is used below with no null check visible.
// This relies on the caller passing an expected_type compatible with
// LinOpType.
114 op = dynamic_cast<LinOpType*>(operators_[op_id].get());
116 if (op->get_size() != size || op->get_stride() != stride) {
117 auto new_op = create();
119 operators_[op_id] = std::move(new_op);
// Read-only access to the operator in slot op_id. Returns a non-owning
// pointer, which is null if the slot currently holds no operator (slots added
// by std::vector::resize hold empty unique_ptrs).
// NOTE(review): the braces of this definition are not part of this excerpt.
124 const LinOp* get_op(
int op_id)
const
126 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
127 return operators_[op_id].get();
// Returns the array in slot array_id as array<ValueType>. An uninitialized
// slot is given an empty (size 0) array on this workspace's executor. The
// guarding condition is not visible here; it is presumably `array.empty()`.
// An already-initialized slot must hold ValueType, which is asserted.
// NOTE(review): the condition guarding the init() call, that branch's return
// and the braces are not part of this excerpt.
// NOTE(review): the local reference `array` shadows the `array` class
// template for the rest of the body. Renaming it (e.g. `slot`) would avoid
// confusion.
130 template <
typename ValueType>
131 array<ValueType>& init_or_get_array(
int array_id)
133 GKO_ASSERT(array_id >= 0 && array_id < arrays_.size());
134 auto&
array = arrays_[array_id];
137 array.template init<ValueType>(this->get_executor(), 0);
141 GKO_ASSERT(
array.template contains<ValueType>());
142 return array.template get<ValueType>();
// Like init_or_get_array(), but also ensures the array has exactly `size`
// elements. It calls resize_and_reset() whenever the current size differs.
// NOTE(review): resize_and_reset() presumably discards the old contents, so
// callers should not expect data to survive a size change; confirm.
// NOTE(review): the return statement and the closing braces are not part of
// this excerpt.
145 template <
typename ValueType>
146 array<ValueType>& create_or_get_array(
int array_id,
size_type size)
148 auto& result = init_or_get_array<ValueType>(array_id);
149 if (result.get_size() != size) {
150 result.resize_and_reset(size);
155 std::shared_ptr<const Executor> get_executor()
const {
return exec_; }
// Resizes the operator and array slot tables. Growing appends empty slots;
// shrinking destroys the operators and arrays held in the removed slots.
// NOTE(review): std::vector::resize converts the int arguments to size_t, so
// a negative count becomes a huge size. Callers must pass non-negative
// values.
// NOTE(review): the braces of this definition are not part of this excerpt.
157 void set_size(
int num_operators,
int num_arrays)
159 operators_.resize(num_operators);
160 arrays_.resize(num_arrays);
// Walks every operator slot and every array slot.
// NOTE(review): the enclosing function header and both loop bodies are not
// part of this excerpt. Presumably each operator is reset and each array
// cleared; confirm.
165 for (
auto& op : operators_) {
168 for (
auto& array : arrays_) {
// Executor handed to arrays created by init_or_get_array().
174 std::shared_ptr<const Executor> exec_;
// One owning slot per operator id; an empty unique_ptr means no operator.
175 std::vector<std::unique_ptr<LinOp>> operators_;
// One type-erased array slot per array id.
176 std::vector<any_array> arrays_;
184 #endif // GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_