[Mlir-commits] [mlir] [mlir][sparse] allow for direct-out passing of sparse tensor buffers (PR #88327)

Peiming Liu llvmlistbot at llvm.org
Wed Apr 10 16:26:27 PDT 2024


================
@@ -24,39 +25,46 @@ using namespace sparse_tensor;
 
 // Convert type range to new types range, with sparse tensors externalized.
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-                      SmallVectorImpl<Type> *extraTypes = nullptr) {
+                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
       convTypes.push_back(type);
       continue;
     }
 
-    // Convert the external representation of the position/coordinate array
+    // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
-                                               Type t, FieldIndex,
-                                               SparseTensorFieldKind kind,
-                                               Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
-          kind == SparseTensorFieldKind::ValMemRef) {
-        ShapedType st = t.cast<ShapedType>();
-        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-        convTypes.push_back(rtp);
-        if (extraTypes)
-          extraTypes->push_back(rtp);
-      }
-      return true;
-    });
+    foreachFieldAndTypeInSparseTensor(
+        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
----------------
PeimingLiu wrote:

`Type t` here should be the same as `MemRefType` created at Line 50.

https://github.com/llvm/llvm-project/pull/88327


More information about the Mlir-commits mailing list