[Mlir-commits] [mlir] [mlir][sparse] allow for direct-out passing of sparse tensor buffers (PR #88327)
Peiming Liu
llvmlistbot at llvm.org
Wed Apr 10 16:26:27 PDT 2024
================
@@ -24,39 +25,46 @@ using namespace sparse_tensor;
// Convert type range to new types range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
- SmallVectorImpl<Type> *extraTypes = nullptr) {
+ SmallVectorImpl<Type> *extraTypes, bool directOut) {
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
convTypes.push_back(type);
continue;
}
- // Convert the external representation of the position/coordinate array
+ // Convert the external representations of the pos/crd/val arrays.
const SparseTensorType stt(cast<RankedTensorType>(type));
- foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
- Type t, FieldIndex,
- SparseTensorFieldKind kind,
- Level, LevelType) {
- if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef ||
- kind == SparseTensorFieldKind::ValMemRef) {
- ShapedType st = t.cast<ShapedType>();
- auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
- convTypes.push_back(rtp);
- if (extraTypes)
- extraTypes->push_back(rtp);
- }
- return true;
- });
+ foreachFieldAndTypeInSparseTensor(
+ stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
----------------
PeimingLiu wrote:
`Type t` here should be the same as the `MemRefType` created at line 50.
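The hunk is cut off at the lambda, so the following is only a minimal, self-contained sketch of the conversion the comment is about. It does not use the MLIR API. `Ty`, `TyKind`, `Field`, `FieldKind`, and `Arg` are hypothetical stand-ins for MLIR types and sparse tensor fields. The direct-out behaviour is an assumption inferred from the PR title: with `directOut`, each pos/crd/val field keeps the memref type it already has (the reviewer's point that `t` already is that memref type), and nothing is added to `extraTypes`. Without it, each field is wrapped in a ranked tensor of the same shape and element type, as in the removed lines.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR types; this is NOT the MLIR API.
enum class TyKind { Other, MemRef, Tensor };

struct Ty {
  TyKind kind;
  std::vector<int64_t> shape;  // -1 marks a dynamic dimension
  std::string elem;            // element type name, e.g. "f64"
  bool operator==(const Ty &o) const {
    return kind == o.kind && shape == o.shape && elem == o.elem;
  }
};

// Mirrors the field kinds filtered in the hunk, plus the storage specifier.
enum class FieldKind { PosMemRef, CrdMemRef, ValMemRef, StorageSpec };

struct Field {
  FieldKind kind;
  Ty type;
};

// A function argument: dense (passed through) or sparse (expanded to fields).
struct Arg {
  bool sparse;
  Ty dense;                   // used only when !sparse
  std::vector<Field> fields;  // used only when sparse
};

// Sketch of convTypes: dense types pass through; each pos/crd/val buffer of a
// sparse tensor becomes one external type. ASSUMPTION: with directOut the
// buffer keeps its memref type (no tensor wrapper, nothing added to extra).
void convTypes(const std::vector<Arg> &args, std::vector<Ty> &out,
               std::vector<Ty> *extra, bool directOut) {
  for (const Arg &a : args) {
    if (!a.sparse) {
      out.push_back(a.dense);
      continue;
    }
    for (const Field &f : a.fields) {
      if (f.kind == FieldKind::StorageSpec)
        continue;  // only pos/crd/val arrays are externalized
      if (directOut) {
        out.push_back(f.type);  // reuse the memref type as-is
        continue;
      }
      // Same shape and element type, but as a tensor instead of a memref.
      Ty rtp{TyKind::Tensor, f.type.shape, f.type.elem};
      out.push_back(rtp);
      if (extra)
        extra->push_back(rtp);
    }
  }
}
```

In this sketch, calling `convTypes` with `directOut == false` turns a sparse argument with pos and val buffers into two tensor types, which are also recorded in `extra`. With `directOut == true`, the same two memref types come back unchanged. That is why rebuilding a memref type from `t` would be redundant.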
https://github.com/llvm/llvm-project/pull/88327