5 #ifndef GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
6 #define GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
9 #include <ginkgo/core/base/array.hpp>
10 #include <ginkgo/core/base/lin_op.hpp>
17 template <
typename ValueType,
typename IndexType>
20 template <
typename ValueType>
39 template <
typename ValueType = default_precision>
41 :
public EnableLinOp<Diagonal<ValueType>>,
42 public ConvertibleTo<Csr<ValueType, int32>>,
43 public ConvertibleTo<Csr<ValueType, int64>>,
44 public ConvertibleTo<Diagonal<next_precision<ValueType>>>,
45 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
46 public ConvertibleTo<Diagonal<next_precision<ValueType, 2>>>,
48 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
49 public ConvertibleTo<Diagonal<next_precision<ValueType, 3>>>,
52 public WritableToMatrixData<ValueType, int32>,
53 public WritableToMatrixData<ValueType, int64>,
54 public ReadableFromMatrixData<ValueType, int32>,
55 public ReadableFromMatrixData<ValueType, int64>,
56 public EnableAbsoluteComputation<remove_complex<Diagonal<ValueType>>> {
57 friend class EnablePolymorphicObject<Diagonal,
LinOp>;
58 friend class Csr<ValueType,
int32>;
59 friend class Csr<ValueType,
int64>;
65 using ConvertibleTo<Csr<ValueType, int32>>::convert_to;
66 using ConvertibleTo<Csr<ValueType, int32>>::move_to;
67 using ConvertibleTo<Csr<ValueType, int64>>::convert_to;
68 using ConvertibleTo<Csr<ValueType, int64>>::move_to;
69 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::convert_to;
70 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::move_to;
72 using value_type = ValueType;
73 using index_type =
int64;
74 using mat_data = matrix_data<ValueType, int64>;
75 using mat_data32 = matrix_data<ValueType, int32>;
76 using device_mat_data = device_matrix_data<ValueType, int64>;
77 using device_mat_data32 = device_matrix_data<ValueType, int32>;
78 using absolute_type = remove_complex<Diagonal>;
82 std::unique_ptr<LinOp>
transpose()
const override;
86 void convert_to(Diagonal<next_precision<ValueType>>* result)
const override;
88 void move_to(Diagonal<next_precision<ValueType>>* result)
override;
90 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
92 using ConvertibleTo<Diagonal<next_precision<ValueType, 2>>>::convert_to;
93 using ConvertibleTo<Diagonal<next_precision<ValueType, 2>>>::move_to;
96 Diagonal<next_precision<ValueType, 2>>* result)
const override;
98 void move_to(Diagonal<next_precision<ValueType, 2>>* result)
override;
101 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
103 using ConvertibleTo<Diagonal<next_precision<ValueType, 3>>>::convert_to;
104 using ConvertibleTo<Diagonal<next_precision<ValueType, 3>>>::move_to;
107 Diagonal<next_precision<ValueType, 3>>* result)
const override;
109 void move_to(Diagonal<next_precision<ValueType, 3>>* result)
override;
112 void convert_to(Csr<ValueType, int32>* result)
const override;
114 void move_to(Csr<ValueType, int32>* result)
override;
116 void convert_to(Csr<ValueType, int64>* result)
const override;
118 void move_to(Csr<ValueType, int64>* result)
override;
152 GKO_ASSERT_REVERSE_CONFORMANT(
this, b);
153 GKO_ASSERT_EQUAL_ROWS(b, x);
154 GKO_ASSERT_EQUAL_COLS(
this, x);
156 this->rapply_impl(b.
get(), x.
get());
170 GKO_ASSERT_CONFORMANT(
this, b);
171 GKO_ASSERT_EQUAL_ROWS(b, x);
172 GKO_ASSERT_EQUAL_ROWS(
this, x);
174 this->inverse_apply_impl(b.
get(), x.
get());
177 void read(
const mat_data& data)
override;
179 void read(
const mat_data32& data)
override;
181 void read(
const device_mat_data& data)
override;
183 void read(
const device_mat_data32& data)
override;
185 void read(device_mat_data&& data)
override;
187 void read(device_mat_data32&& data)
override;
189 void write(mat_data& data)
const override;
191 void write(mat_data32& data)
const override;
201 static std::unique_ptr<Diagonal>
create(
202 std::shared_ptr<const Executor> exec,
size_type size = 0);
218 static std::unique_ptr<Diagonal>
create(
219 std::shared_ptr<const Executor> exec,
const size_type size,
226 template <
typename InputValueType>
228 "explicitly construct the gko::array argument instead of passing an"
232 std::initializer_list<InputValueType> values)
248 std::shared_ptr<const Executor> exec,
size_type size,
249 gko::detail::const_array_view<ValueType>&& values);
257 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
260 LinOp* x)
const override;
262 void rapply_impl(
const LinOp* b,
LinOp* x)
const;
264 void inverse_apply_impl(
const LinOp* b,
LinOp* x)
const;
275 #endif // GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_