5 #ifndef GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
6 #define GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
9 #include <ginkgo/core/base/array.hpp>
10 #include <ginkgo/core/base/lin_op.hpp>
17 template <
typename ValueType,
typename IndexType>
20 template <
typename ValueType>
39 template <
typename ValueType = default_precision>
41 :
public EnableLinOp<Diagonal<ValueType>>,
42 public ConvertibleTo<Csr<ValueType, int32>>,
43 public ConvertibleTo<Csr<ValueType, int64>>,
44 public ConvertibleTo<Diagonal<next_precision<ValueType>>>,
46 public WritableToMatrixData<ValueType, int32>,
47 public WritableToMatrixData<ValueType, int64>,
48 public ReadableFromMatrixData<ValueType, int32>,
49 public ReadableFromMatrixData<ValueType, int64>,
50 public EnableAbsoluteComputation<remove_complex<Diagonal<ValueType>>> {
// Friend declarations: EnablePolymorphicObject<Diagonal, LinOp> and the Csr
// instantiations for both index types (int32 and int64) are granted access
// to the non-public members of this class.
// NOTE(review): presumably required so these classes can reach non-public
// constructors or storage of Diagonal -- confirm against the rest of the file.
51 friend class EnablePolymorphicObject<Diagonal,
LinOp>;
52 friend class Csr<ValueType,
int32>;
53 friend class Csr<ValueType,
int64>;
// Bring the convert_to / move_to overload sets of every ConvertibleTo base
// (Csr with int32, Csr with int64, and Diagonal in the next precision) into
// the scope of this class. This class declares its own convert_to / move_to
// members, which would otherwise hide any base-class overloads of the same
// name that are not redeclared here.
59 using ConvertibleTo<Csr<ValueType, int32>>::convert_to;
60 using ConvertibleTo<Csr<ValueType, int32>>::move_to;
61 using ConvertibleTo<Csr<ValueType, int64>>::convert_to;
62 using ConvertibleTo<Csr<ValueType, int64>>::move_to;
63 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::convert_to;
64 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::move_to;
// Public type aliases of the class.
// value_type is the ValueType template parameter.
66 using value_type = ValueType;
// index_type is fixed to int64: the visible template header of this class
// only has a ValueType parameter, no index-type parameter.
67 using index_type =
int64;
// matrix_data aliases for both index widths; the unsuffixed alias uses int64,
// the "32" alias uses int32. They are the parameter types of the
// matrix_data read()/write() overloads declared in this class.
68 using mat_data = matrix_data<ValueType, int64>;
69 using mat_data32 = matrix_data<ValueType, int32>;
// device_matrix_data aliases, following the same int64 / int32 naming scheme.
70 using device_mat_data = device_matrix_data<ValueType, int64>;
71 using device_mat_data32 = device_matrix_data<ValueType, int32>;
// absolute_type is remove_complex applied to this Diagonal type.
72 using absolute_type = remove_complex<Diagonal>;
// Overrides a virtual transpose() inherited from a base class that is not
// visible in this chunk. Const member function; the result is handed back as
// an owning std::unique_ptr<LinOp>. The definition lives outside this header
// section.
76 std::unique_ptr<LinOp>
transpose()
const override;
// ConvertibleTo<Diagonal<next_precision<ValueType>>> overrides.
// result: destination Diagonal using the next_precision value type.
// convert_to is const, so this object is left untouched.
80 void convert_to(Diagonal<next_precision<ValueType>>* result)
const override;
// move_to is non-const and may therefore modify this object.
// NOTE(review): presumably it gives up this object's data to *result --
// confirm in the implementation file.
82 void move_to(Diagonal<next_precision<ValueType>>* result)
override;
// ConvertibleTo<Csr<ValueType, int32>> overrides.
// result: destination Csr matrix with the same value type and int32 indices.
// convert_to is const, so this object is left untouched.
84 void convert_to(Csr<ValueType, int32>* result)
const override;
// move_to is the non-const counterpart and may modify this object.
86 void move_to(Csr<ValueType, int32>* result)
override;
// ConvertibleTo<Csr<ValueType, int64>> overrides.
// result: destination Csr matrix with the same value type and int64 indices.
// convert_to is const, so this object is left untouched.
88 void convert_to(Csr<ValueType, int64>* result)
const override;
// move_to is the non-const counterpart and may modify this object.
90 void move_to(Csr<ValueType, int64>* result)
override;
// Statements of a public apply-style wrapper whose signature is not visible
// in this chunk. Three dimension assertions are evaluated on the pairs
// (this, b), (b, x) and (this, x); afterwards the call is forwarded to
// rapply_impl with the raw pointers obtained from b.get() and x.get().
// NOTE(review): the GKO_ASSERT_* macros are defined elsewhere; going by their
// names they check reverse conformity of this with b, equal row counts of b
// and x, and equal column counts of this and x -- confirm against the macro
// definitions.
124 GKO_ASSERT_REVERSE_CONFORMANT(
this, b);
125 GKO_ASSERT_EQUAL_ROWS(b, x);
126 GKO_ASSERT_EQUAL_COLS(
this, x);
128 this->rapply_impl(b.
get(), x.
get());
// Statements of a public apply-style wrapper whose signature is not visible
// in this chunk. The operands are validated and the call is then forwarded
// to inverse_apply_impl with the raw pointers from b.get() and x.get().
//
// Dimension checks, one per operand pair:
//   (this, b): this and b must be conformant,
//   (b, x):    b and x must have the same number of columns,
//   (this, x): this and x must have the same number of rows.
// The (b, x) check compares columns on purpose: rows of b and rows of x are
// each already tied to a dimension of this operator by the other two
// assertions, whereas a row comparison here would leave the column counts of
// b and x completely unvalidated before the implementation runs.
142 GKO_ASSERT_CONFORMANT(
this, b);
143 GKO_ASSERT_EQUAL_COLS(b, x);
144 GKO_ASSERT_EQUAL_ROWS(
this, x);
146 this->inverse_apply_impl(b.
get(), x.
get());
// read() overrides taking matrix_data by const reference, one per index
// width: mat_data (int64 indices) and mat_data32 (int32 indices). These
// correspond to the ReadableFromMatrixData<ValueType, int64/int32> bases.
// NOTE(review): presumably they fill this matrix from the entries in data --
// the implementation is not part of this chunk.
149 void read(
const mat_data& data)
override;
151 void read(
const mat_data32& data)
override;
// read() overrides taking device_matrix_data by const reference, one per
// index width: device_mat_data (int64) and device_mat_data32 (int32). The
// argument is const, so it is not modified by the call.
153 void read(
const device_mat_data& data)
override;
155 void read(
const device_mat_data32& data)
override;
// read() overrides taking device_matrix_data by rvalue reference, one per
// index width (int64 and int32). Callers have to pass a temporary or a
// std::move'd object.
// NOTE(review): presumably this lets the implementation consume the
// argument's storage instead of copying it -- confirm in the implementation.
157 void read(device_mat_data&& data)
override;
159 void read(device_mat_data32&& data)
override;
// write() overrides for both index widths: mat_data (int64) and mat_data32
// (int32), matching the WritableToMatrixData<ValueType, int64/int32> bases.
// data is an output parameter (non-const reference); the functions are const,
// so this matrix is not modified.
161 void write(mat_data& data)
const override;
163 void write(mat_data32& data)
const override;
// Static factory returning an owning std::unique_ptr<Diagonal>.
// exec: shared pointer to the (const) Executor passed to the new object.
// size: single size_type value, defaults to 0 when omitted.
// NOTE(review): size presumably is the number of rows (= columns) of the
// diagonal matrix and 0 yields an empty matrix -- confirm in the
// implementation.
173 static std::unique_ptr<Diagonal>
create(
174 std::shared_ptr<const Executor> exec,
size_type size = 0);
190 static std::unique_ptr<Diagonal>
create(
191 std::shared_ptr<const Executor> exec,
const size_type size,
198 template <
typename InputValueType>
200 "explicitly construct the gko::array argument instead of passing an"
204 std::initializer_list<InputValueType> values)
220 std::shared_ptr<const Executor> exec,
size_type size,
221 gko::detail::const_array_view<ValueType>&& values);
// Override of the two-argument apply_impl hook inherited from a base class.
// b: read-only input operator, x: operator that receives the result.
// Const member function; defined outside this header section.
229 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
232 LinOp* x)
const override;
// Implementation hook invoked via this->rapply_impl(b.get(), x.get()) after
// the dimension assertions earlier in this header. Not an override (no
// override specifier), const member function.
// b: read-only input operator, x: operator that receives the result.
234 void rapply_impl(
const LinOp* b,
LinOp* x)
const;
// Implementation hook invoked via this->inverse_apply_impl(b.get(), x.get())
// after the dimension assertions earlier in this header. Not an override (no
// override specifier), const member function.
// b: read-only input operator, x: operator that receives the result.
236 void inverse_apply_impl(
const LinOp* b,
LinOp* x)
const;
247 #endif // GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_