5 #ifndef GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
6 #define GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
9 #include <ginkgo/core/base/array.hpp>
10 #include <ginkgo/core/base/lin_op.hpp>
17 template <
typename ValueType,
typename IndexType>
20 template <
typename ValueType>
39 template <
typename ValueType = default_precision>
41 :
public EnableLinOp<Diagonal<ValueType>>,
42 public ConvertibleTo<Csr<ValueType, int32>>,
43 public ConvertibleTo<Csr<ValueType, int64>>,
44 public ConvertibleTo<Diagonal<next_precision<ValueType>>>,
45 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
46 public ConvertibleTo<Diagonal<next_precision<ValueType, 2>>>,
48 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
49 public ConvertibleTo<Diagonal<next_precision<ValueType, 3>>>,
52 public WritableToMatrixData<ValueType, int32>,
53 public WritableToMatrixData<ValueType, int64>,
54 public ReadableFromMatrixData<ValueType, int32>,
55 public ReadableFromMatrixData<ValueType, int64>,
56 public EnableAbsoluteComputation<remove_complex<Diagonal<ValueType>>> {
57 friend class EnablePolymorphicObject<Diagonal,
LinOp>;
58 friend class Csr<ValueType,
int32>;
59 friend class Csr<ValueType,
int64>;
61 GKO_ASSERT_SUPPORTED_VALUE_TYPE;
66 using ConvertibleTo<Csr<ValueType, int32>>::convert_to;
67 using ConvertibleTo<Csr<ValueType, int32>>::move_to;
68 using ConvertibleTo<Csr<ValueType, int64>>::convert_to;
69 using ConvertibleTo<Csr<ValueType, int64>>::move_to;
70 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::convert_to;
71 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::move_to;
73 using value_type = ValueType;
74 using index_type =
int64;
75 using mat_data = matrix_data<ValueType, int64>;
76 using mat_data32 = matrix_data<ValueType, int32>;
77 using device_mat_data = device_matrix_data<ValueType, int64>;
78 using device_mat_data32 = device_matrix_data<ValueType, int32>;
79 using absolute_type = remove_complex<Diagonal>;
// Const member: it cannot modify this object. It hands the caller ownership of
// the result through std::unique_ptr<LinOp>.
// NOTE(review): the base class that declares the overridden transpose() is not
// in the visible base list. Presumably it is Transposable on a line missing
// from this view; confirm.
83 std::unique_ptr<LinOp>
transpose()
const override;
// Overrides of ConvertibleTo<Diagonal<next_precision<ValueType>>>, which is
// listed among the class's bases.
// convert_to is const and leaves this object unmodified.
// move_to is non-const, so it may modify this object. Presumably it moves this
// object's data into `result` instead of copying it; confirm.
87 void convert_to(Diagonal<next_precision<ValueType>>* result)
const override;
89 void move_to(Diagonal<next_precision<ValueType>>* result)
override;
91 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
93 using ConvertibleTo<Diagonal<next_precision<ValueType, 2>>>::convert_to;
94 using ConvertibleTo<Diagonal<next_precision<ValueType, 2>>>::move_to;
97 Diagonal<next_precision<ValueType, 2>>* result)
const override;
99 void move_to(Diagonal<next_precision<ValueType, 2>>* result)
override;
102 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
104 using ConvertibleTo<Diagonal<next_precision<ValueType, 3>>>::convert_to;
105 using ConvertibleTo<Diagonal<next_precision<ValueType, 3>>>::move_to;
108 Diagonal<next_precision<ValueType, 3>>* result)
const override;
110 void move_to(Diagonal<next_precision<ValueType, 3>>* result)
override;
// Overrides of ConvertibleTo<Csr<ValueType, int32>> and
// ConvertibleTo<Csr<ValueType, int64>>, which are listed among the class's bases.
// The output Csr matrix keeps ValueType and uses 32- or 64-bit indices
// respectively.
// convert_to is const. move_to is non-const and may modify this object
// (presumably by moving its data into `result`; confirm).
113 void convert_to(Csr<ValueType, int32>* result)
const override;
115 void move_to(Csr<ValueType, int32>* result)
override;
117 void convert_to(Csr<ValueType, int64>* result)
const override;
119 void move_to(Csr<ValueType, int64>* result)
override;
153 GKO_ASSERT_REVERSE_CONFORMANT(
this, b);
154 GKO_ASSERT_EQUAL_ROWS(b, x);
155 GKO_ASSERT_EQUAL_COLS(
this, x);
157 this->rapply_impl(b.
get(), x.
get());
171 GKO_ASSERT_CONFORMANT(
this, b);
172 GKO_ASSERT_EQUAL_ROWS(b, x);
173 GKO_ASSERT_EQUAL_ROWS(
this, x);
175 this->inverse_apply_impl(b.
get(), x.
get());
// read() overrides, one per supported input container:
//   mat_data / mat_data32: host matrix_data with int64 / int32 indices.
//   device_mat_data / device_mat_data32: device_matrix_data with int64 / int32
//   indices, passed either by const reference or by rvalue reference.
// The rvalue overloads may take over the argument's storage; presumably they
// exist to avoid a copy. Confirm.
178 void read(
const mat_data& data)
override;
180 void read(
const mat_data32& data)
override;
182 void read(
const device_mat_data& data)
override;
184 void read(
const device_mat_data32& data)
override;
186 void read(device_mat_data&& data)
override;
188 void read(device_mat_data32&& data)
override;
// write() overrides for matrix_data with int64 (mat_data) and int32
// (mat_data32) indices. They are const and fill the caller-provided `data`.
190 void write(mat_data& data)
const override;
192 void write(mat_data32& data)
const override;
// Static factory that returns a std::unique_ptr<Diagonal> associated with
// executor `exec`. `size` defaults to 0, so an empty object is created when it
// is omitted.
// NOTE(review): presumably `size` is the number of rows/columns (diagonal
// entries) of a square matrix; confirm.
202 static std::unique_ptr<Diagonal>
create(
203 std::shared_ptr<const Executor> exec,
size_type size = 0);
219 static std::unique_ptr<Diagonal>
create(
220 std::shared_ptr<const Executor> exec,
const size_type size,
227 template <
typename InputValueType>
229 "explicitly construct the gko::array argument instead of passing an"
233 std::initializer_list<InputValueType> values)
249 std::shared_ptr<const Executor> exec,
size_type size,
250 gko::detail::const_array_view<ValueType>&& values);
// Implementation hook for applying this operator to `b` and storing the result
// in `x`. It overrides the LinOp virtual and is const, so this matrix is not
// modified.
// NOTE(review): presumably computes x = D * b; confirm.
258 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
261 LinOp* x)
const override;
// Non-virtual helper behind the public rapply wrapper. Before forwarding
// b.get() and x.get() here, that wrapper asserts:
//   GKO_ASSERT_REVERSE_CONFORMANT(this, b),
//   b and x have equal rows,
//   this and x have equal columns.
// NOTE(review): presumably computes the right application x = b * D; confirm.
263 void rapply_impl(
const LinOp* b,
LinOp* x)
const;
// Non-virtual helper behind the public inverse_apply wrapper. Before forwarding
// b.get() and x.get() here, that wrapper asserts:
//   GKO_ASSERT_CONFORMANT(this, b),
//   b and x have equal rows,
//   this and x have equal rows.
// NOTE(review): presumably computes x = D^{-1} * b; behaviour for zero
// diagonal entries is not visible here. Confirm.
265 void inverse_apply_impl(
const LinOp* b,
LinOp* x)
const;
276 #endif // GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_