5 #ifndef GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
6 #define GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
11 #include <ginkgo/core/matrix/dense.hpp>
/**
 * Stores a newly constructed array of the given value type in this object.
 *
 * A concrete_container<ValueType> is created from `exec` and `size` and then
 * assigned to data_, replacing whatever container was held before.
 *
 * @tparam ValueType  element type of the array to create
 * @param exec  executor forwarded to the container's constructor
 * @param size  size forwarded to the container's constructor
 */
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; presumably the reference `arr` is what gets returned -
// confirm against the full file.
24 template <
typename ValueType>
25 array<ValueType>& init(std::shared_ptr<const Executor> exec,
size_type size)
27 auto container = std::make_unique<concrete_container<ValueType>>(
28 std::move(exec), size);
// Bind the reference before ownership moves: moving the unique_ptr transfers
// the pointer only, so `arr` still refers to the same heap-allocated member.
29 auto& arr = container->arr;
30 data_ = std::move(container);
34 bool empty()
const {
return data_.get() ==
nullptr; }
// NOTE(review): the signature line of this member is not visible in this
// excerpt; presumably it is the `contains<ValueType>()` query that the get()
// accessors assert on - confirm against the full file.
36 template <
typename ValueType>
// The cast yields a non-null pointer only when data_ points to a
// concrete_container<ValueType>; a null data_ or a container of another value
// type yields nullptr.
39 return dynamic_cast<const concrete_container<ValueType>*
>(data_.get());
/**
 * Mutable access to the stored array, interpreted with the given value type.
 *
 * @tparam ValueType  value type the stored container is expected to have
 * @return reference to the `arr` member of the stored container
 */
// NOTE(review): the enclosing braces are not visible in this excerpt.
42 template <
typename ValueType>
43 array<ValueType>& get()
// The type match is only checked through GKO_ASSERT. If that check does not
// fire on a mismatch, the dynamic_cast below yields nullptr and `->arr`
// dereferences it.
45 GKO_ASSERT(this->
template contains<ValueType>());
46 return dynamic_cast<concrete_container<ValueType>*
>(data_.get())->arr;
/**
 * Read-only counterpart of the mutable get() accessor.
 *
 * @tparam ValueType  value type the stored container is expected to have
 */
// NOTE(review): the enclosing braces and the end of the return expression are
// not visible in this excerpt; presumably it continues with `->arr;` like the
// non-const overload - confirm against the full file.
49 template <
typename ValueType>
50 const array<ValueType>& get()
const
// The type match is only checked through GKO_ASSERT before the cast is used.
52 GKO_ASSERT(this->
template contains<ValueType>());
53 return dynamic_cast<const concrete_container<ValueType>*
>(data_.get())
57 void clear() { data_.reset(); }
// Polymorphic base of all stored containers. The virtual destructor lets a
// std::unique_ptr<generic_container> destroy the derived container correctly
// and makes the type usable with dynamic_cast.
// NOTE(review): the closing brace of this struct is not visible in this
// excerpt.
60 struct generic_container {
61 virtual ~generic_container() =
default;
// Container holding the array for one specific ValueType, derived from
// generic_container.
// NOTE(review): the constructor body, the declaration of the `arr` member and
// the closing brace are not visible in this excerpt.
64 template <
typename ValueType>
65 struct concrete_container : generic_container {
// Perfect-forwards all constructor arguments into a brace-initialization of
// `arr`.
66 template <
typename... Args>
67 concrete_container(Args&&... args) : arr{std::forward<Args>(args)...}
// Owning pointer to the currently stored container, held through its
// polymorphic base type; null when nothing is stored.
73 std::unique_ptr<generic_container> data_;
79 workspace(std::shared_ptr<const Executor> exec) : exec_{std::move(exec)} {}
81 workspace(
const workspace& other) : workspace{other.get_executor()} {}
// Move constructor: delegates to the executor constructor with the executor
// of `other`.
// NOTE(review): the constructor body is not visible in this excerpt.
83 workspace(workspace&& other) : workspace{other.get_executor()}
88 workspace& operator=(
const workspace& other) {
return *
this; }
// Move assignment operator.
// NOTE(review): the body is not visible in this excerpt, so its effect on
// either workspace cannot be stated from here.
90 workspace& operator=(workspace&& other)
/**
 * Returns the operator stored in slot `op_id`, replacing it with the result
 * of `create()` when the stored one does not fit.
 *
 * @tparam LinOpType  type the stored operator is cast to
 * @tparam CreateOperation  callable whose result is assignable to a slot of
 *                          operators_
 * @param op_id  index into operators_
 * @param create  factory invoked whenever a new operator is needed
 * @param expected_type  dynamic type the stored operator must have
 */
// NOTE(review): the rest of the parameter list (where `size` and `stride`
// presumably come from), the declaration of `op`, the braces and the return
// statements are not visible in this excerpt.
96 template <
typename LinOpType,
typename CreateOperation>
97 LinOpType* create_or_get_op(
int op_id, CreateOperation create,
98 const std::type_info& expected_type,
// The slot index is only validated through GKO_ASSERT. The upper-bound check
// compares a signed int against the unsigned result of size().
101 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
104 auto stored_op = operators_[op_id].get();
// Case 1: the slot is empty or holds an object of a different dynamic type,
// so a new operator is created and stored.
106 if (!stored_op ||
typeid(*stored_op) != expected_type) {
107 auto new_op = create();
109 operators_[op_id] = std::move(new_op);
// Case 2: the dynamic type matched; check whether size and stride still fit.
113 op = dynamic_cast<LinOpType*>(operators_[op_id].get());
115 if (op->get_size() != size || op->get_stride() != stride) {
116 auto new_op = create();
118 operators_[op_id] = std::move(new_op);
/**
 * Read-only access to the operator stored in slot `op_id`.
 *
 * @param op_id  index into operators_; only validated through GKO_ASSERT
 * @return non-owning pointer to the stored operator, null if the slot is empty
 */
// NOTE(review): the enclosing braces are not visible in this excerpt.
123 const LinOp* get_op(
int op_id)
const
125 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
126 return operators_[op_id].get();
/**
 * Returns the array stored in slot `array_id`.
 *
 * @tparam ValueType  value type of the requested array
 * @param array_id  index into arrays_; only validated through GKO_ASSERT
 */
// NOTE(review): the braces and the condition guarding the init call are not
// visible in this excerpt; presumably the array is initialized only when the
// slot is still empty - confirm against the full file.
129 template <
typename ValueType>
130 array<ValueType>& init_or_get_array(
int array_id)
132 GKO_ASSERT(array_id >= 0 && array_id < arrays_.size());
133 auto&
array = arrays_[array_id];
// Initializes the slot with a size-0 array on this workspace's executor.
136 array.template init<ValueType>(this->get_executor(), 0);
// Asserts that the slot holds an array of the requested value type before it
// is returned.
140 GKO_ASSERT(
array.template contains<ValueType>());
141 return array.template get<ValueType>();
/**
 * Obtains the array in slot `array_id` via init_or_get_array and makes sure it
 * has the requested size.
 *
 * @tparam ValueType  value type of the requested array
 * @param array_id  index of the array slot
 * @param size  required size of the array
 */
// NOTE(review): the braces and the return statement are not visible in this
// excerpt; presumably `result` is returned - confirm against the full file.
144 template <
typename ValueType>
145 array<ValueType>& create_or_get_array(
int array_id,
size_type size)
147 auto& result = init_or_get_array<ValueType>(array_id);
// resize_and_reset is only called when the current size differs.
148 if (result.get_size() != size) {
149 result.resize_and_reset(size);
154 std::shared_ptr<const Executor> get_executor()
const {
return exec_; }
/**
 * Resizes the operator and array storage to the given numbers of slots.
 *
 * std::vector::resize value-initializes added slots and destroys slots beyond
 * the new size.
 *
 * @param num_operators  new number of operator slots
 * @param num_arrays  new number of array slots
 */
// NOTE(review): both counts are signed ints converted to the vectors'
// unsigned size type; a negative value would turn into a very large size.
// The enclosing braces are not visible in this excerpt.
156 void set_size(
int num_operators,
int num_arrays)
158 operators_.resize(num_operators);
159 arrays_.resize(num_arrays);
// NOTE(review): the signature of the enclosing member and both loop bodies
// are not visible in this excerpt; presumably this is the workspace's clear()
// and each element gets released - confirm against the full file.
// Visits every operator slot by mutable reference.
164 for (
auto& op : operators_) {
// Visits every array slot by mutable reference.
167 for (
auto& array : arrays_) {
// Executor handed out by get_executor() and passed on when arrays are
// initialized.
173 std::shared_ptr<const Executor> exec_;
// Owned operators, addressed by the integer op_id.
174 std::vector<std::unique_ptr<LinOp>> operators_;
// Type-erased arrays, addressed by the integer array_id.
175 std::vector<any_array> arrays_;
183 #endif // GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_