5 #ifndef GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
6 #define GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_
9 #include <ginkgo/core/base/array.hpp>
10 #include <ginkgo/core/base/lin_op.hpp>
17 template <
typename ValueType,
typename IndexType>
20 template <
typename ValueType>
39 template <
typename ValueType = default_precision>
41 :
public EnableLinOp<Diagonal<ValueType>>,
42 public ConvertibleTo<Csr<ValueType, int32>>,
43 public ConvertibleTo<Csr<ValueType, int64>>,
44 public ConvertibleTo<Diagonal<next_precision<ValueType>>>,
45 #if GINKGO_ENABLE_HALF
46 public ConvertibleTo<Diagonal<next_precision<next_precision<ValueType>>>>,
49 public WritableToMatrixData<ValueType, int32>,
50 public WritableToMatrixData<ValueType, int64>,
51 public ReadableFromMatrixData<ValueType, int32>,
52 public ReadableFromMatrixData<ValueType, int64>,
53 public EnableAbsoluteComputation<remove_complex<Diagonal<ValueType>>> {
54 friend class EnablePolymorphicObject<Diagonal,
LinOp>;
55 friend class Csr<ValueType,
int32>;
56 friend class Csr<ValueType,
int64>;
// Re-export the convert_to/move_to overloads inherited from each
// ConvertibleTo<...> base. Diagonal derives from several ConvertibleTo
// instantiations and also declares its own overrides of these names, so
// without these using-declarations member-name lookup on Diagonal would not
// see every target-type overload together.
62 using ConvertibleTo<Csr<ValueType, int32>>::convert_to;
63 using ConvertibleTo<Csr<ValueType, int32>>::move_to;
64 using ConvertibleTo<Csr<ValueType, int64>>::convert_to;
65 using ConvertibleTo<Csr<ValueType, int64>>::move_to;
66 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::convert_to;
67 using ConvertibleTo<Diagonal<next_precision<ValueType>>>::move_to;
// Public type aliases.
//  - value_type: element type of the matrix (the ValueType template argument).
//  - index_type: fixed to int64 for this class.
//  - mat_data / mat_data32: host matrix_data containers with 64-bit / 32-bit
//    indices, as taken by the read()/write() overloads.
//  - device_mat_data / device_mat_data32: device_matrix_data counterparts,
//    as taken by the device read() overloads.
//  - absolute_type: remove_complex<Diagonal>, matching the
//    EnableAbsoluteComputation base of this class.
69 using value_type = ValueType;
70 using index_type =
int64;
71 using mat_data = matrix_data<ValueType, int64>;
72 using mat_data32 = matrix_data<ValueType, int32>;
73 using device_mat_data = device_matrix_data<ValueType, int64>;
74 using device_mat_data32 = device_matrix_data<ValueType, int32>;
75 using absolute_type = remove_complex<Diagonal>;
// Grant access to the Diagonal instantiation whose next_precision is this
// ValueType. NOTE(review): presumably so that its precision-conversion
// convert_to/move_to can write into this class's internals — confirm
// against the implementation.
77 friend class Diagonal<previous_precision<ValueType>>;
// Returns a newly allocated LinOp (owned by the returned unique_ptr); const,
// so *this is not modified. Overrides a virtual transpose() from a base
// class that is not visible in this excerpt. Presumably yields the
// transposed operator (inferred from the name; implementation not shown).
79 std::unique_ptr<LinOp>
transpose()
const override;
// Overrides of ConvertibleTo<Diagonal<next_precision<ValueType>>>.
// convert_to is const and leaves *this unchanged; move_to is non-const and
// may therefore consume the contents of *this. Both write into *result.
83 void convert_to(Diagonal<next_precision<ValueType>>* result)
const override;
85 void move_to(Diagonal<next_precision<ValueType>>* result)
override;
87 #if GINKGO_ENABLE_HALF
88 friend class Diagonal<previous_precision<previous_precision<ValueType>>>;
90 Diagonal<next_precision<next_precision<ValueType>>>>::convert_to;
92 Diagonal<next_precision<next_precision<ValueType>>>>::move_to;
94 void convert_to(Diagonal<
next_precision<next_precision<ValueType>>>* result)
98 Diagonal<
next_precision<next_precision<ValueType>>>* result)
override;
// Overrides of ConvertibleTo<Csr<ValueType, int32>> and
// ConvertibleTo<Csr<ValueType, int64>>: export this matrix into a Csr matrix
// with 32-bit or 64-bit indices. The const convert_to overloads leave *this
// unchanged; the non-const move_to overloads may consume *this.
101 void convert_to(Csr<ValueType, int32>* result)
const override;
103 void move_to(Csr<ValueType, int32>* result)
override;
105 void convert_to(Csr<ValueType, int64>* result)
const override;
107 void move_to(Csr<ValueType, int64>* result)
override;
141 GKO_ASSERT_REVERSE_CONFORMANT(
this, b);
142 GKO_ASSERT_EQUAL_ROWS(b, x);
143 GKO_ASSERT_EQUAL_COLS(
this, x);
145 this->rapply_impl(b.
get(), x.
get());
159 GKO_ASSERT_CONFORMANT(
this, b);
160 GKO_ASSERT_EQUAL_ROWS(b, x);
161 GKO_ASSERT_EQUAL_ROWS(
this, x);
163 this->inverse_apply_impl(b.
get(), x.
get());
// Overrides of ReadableFromMatrixData<ValueType, int64> and
// ReadableFromMatrixData<ValueType, int32>: fill this matrix from
//  - host matrix_data (mat_data / mat_data32), by const reference;
//  - device_matrix_data (device_mat_data / device_mat_data32), by const
//    reference;
//  - device_matrix_data by rvalue reference. NOTE(review): the && overloads
//    presumably may take over the argument's storage — confirm in the
//    implementation.
166 void read(
const mat_data& data)
override;
168 void read(
const mat_data32& data)
override;
170 void read(
const device_mat_data& data)
override;
172 void read(
const device_mat_data32& data)
override;
174 void read(device_mat_data&& data)
override;
176 void read(device_mat_data32&& data)
override;
// Overrides of WritableToMatrixData<ValueType, int64> and
// WritableToMatrixData<ValueType, int32>: store this matrix into the given
// host matrix_data (64-bit or 32-bit indices). const: *this is not modified.
178 void write(mat_data& data)
const override;
180 void write(mat_data32& data)
const override;
/**
 * Creates a Diagonal matrix associated with the given executor.
 *
 * @param exec  executor the new matrix is associated with
 * @param size  size of the matrix (defaults to 0)
 *
 * @return a unique_ptr owning the newly created Diagonal
 */
190 static std::unique_ptr<Diagonal>
create(
191 std::shared_ptr<const Executor> exec,
size_type size = 0);
207 static std::unique_ptr<Diagonal>
create(
208 std::shared_ptr<const Executor> exec,
const size_type size,
215 template <
typename InputValueType>
217 "explicitly construct the gko::array argument instead of passing an"
221 std::initializer_list<InputValueType> values)
237 std::shared_ptr<const Executor> exec,
size_type size,
238 gko::detail::const_array_view<ValueType>&& values);
// Override of the two-operand LinOp apply_impl(b, x) hook: writes the
// result of applying this operator to b into x. const: *this is not
// modified. The implementation is not part of this header.
246 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
249 LinOp* x)
const override;
// Non-virtual worker behind rapply(). rapply() calls it with the raw
// pointers of b and x after checking reverse conformance of this and b,
// equal rows of b and x, and equal columns of this and x. const: *this is
// not modified.
251 void rapply_impl(
const LinOp* b,
LinOp* x)
const;
// Non-virtual worker behind inverse_apply(). inverse_apply() calls it with
// the raw pointers of b and x after checking conformance of this and b,
// equal rows of b and x, and equal rows of this and x. const: *this is not
// modified.
253 void inverse_apply_impl(
const LinOp* b,
LinOp* x)
const;
264 #endif // GKO_PUBLIC_CORE_MATRIX_DIAGONAL_HPP_