[Mlir-commits] [mlir] [mlir][sparse] assemble SoA COO correctly. (PR #82449)

Peiming Liu llvmlistbot at llvm.org
Tue Feb 20 17:32:36 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/82449

>From fe4613a34afa295381deddfaa1daf6f42493e57f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 21 Feb 2024 01:05:51 +0000
Subject: [PATCH 1/2] [mlir][sparse] assemble SoA COO correctly.

---
 .../Transforms/SparseAssembler.cpp            | 77 +++++++++----------
 mlir/test/Dialect/SparseTensor/external.mlir  | 24 ++++++
 2 files changed, 61 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 9414d81e6bf5c6..a107cd71959abd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -22,16 +22,13 @@ using namespace sparse_tensor;
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: reuse StorageLayout::foreachField?
-
-// TODO: we need COO AoS and SoA
-
 // Convert type range to new types range, with sparse tensors externalized.
-void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-               SmallVectorImpl<Type> *extraTypes = nullptr) {
+static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+                      SmallVectorImpl<Type> *extraTypes = nullptr) {
   for (auto type : types) {
+    auto enc = getSparseTensorEncoding(type);
     // All "dense" data passes through unmodified.
-    if (!getSparseTensorEncoding(type)) {
+    if (!enc) {
       convTypes.push_back(type);
       continue;
     }
@@ -42,29 +39,30 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     convTypes.push_back(vtp);
     if (extraTypes)
       extraTypes->push_back(vtp);
-    // Convert the external representations of the pos/crd arrays.
-    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
-      const auto lt = stt.getLvlType(lvl);
-      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
-        auto ptp = RankedTensorType::get(shape, stt.getPosType());
-        auto ctp = RankedTensorType::get(shape, stt.getCrdType());
-        convTypes.push_back(ptp);
-        convTypes.push_back(ctp);
-        if (extraTypes) {
-          extraTypes->push_back(ptp);
-          extraTypes->push_back(ctp);
-        }
-      } else {
-        assert(isDenseLT(lt)); // TODO: handle other cases
+
+    // Convert the external representation of the position/coordinate array.
+    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
+                                               Type t, FieldIndex,
+                                               SparseTensorFieldKind kind,
+                                               Level, LevelType) {
+      if (kind == SparseTensorFieldKind::CrdMemRef ||
+          kind == SparseTensorFieldKind::PosMemRef) {
+        ShapedType st = t.cast<ShapedType>();
+        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+        convTypes.push_back(rtp);
+        if (extraTypes)
+          extraTypes->push_back(rtp);
       }
-    }
+      return true;
+    });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
-void convVals(OpBuilder &builder, Location loc, TypeRange types,
-              ValueRange fromVals, ValueRange extraVals,
-              SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
+static void convVals(OpBuilder &builder, Location loc, TypeRange types,
+                     ValueRange fromVals, ValueRange extraVals,
+                     SmallVectorImpl<Value> &toVals, unsigned extra,
+                     bool isIn) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -85,29 +83,28 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn) {
       inputs.push_back(extraVals[extra++]);
       retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType());
+      cntTypes.push_back(builder.getIndexType()); // nnz
     }
+
     // Collect the external representations of the pos/crd arrays.
-    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
-      const auto lt = stt.getLvlType(lvl);
-      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
+    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level, LevelType) {
+      if (kind == SparseTensorFieldKind::CrdMemRef ||
+          kind == SparseTensorFieldKind::PosMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
-          inputs.push_back(fromVals[idx++]);
         } else {
-          Type pTp = stt.getPosType();
-          Type cTp = stt.getCrdType();
-          inputs.push_back(extraVals[extra++]);
+          ShapedType st = t.cast<ShapedType>();
+          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
-          retTypes.push_back(RankedTensorType::get(shape, pTp));
-          retTypes.push_back(RankedTensorType::get(shape, cTp));
-          cntTypes.push_back(pTp);
-          cntTypes.push_back(cTp);
+          retTypes.push_back(rtp);
+          cntTypes.push_back(rtp.getElementType());
         }
-      } else {
-        assert(isDenseLT(lt)); // TODO: handle other cases
       }
-    }
+      return true;
+    });
+
     if (isIn) {
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index c17ba13e86c926..b5701ad2024264 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -100,3 +100,27 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
   return %arg0 : tensor<64x64xf32, #sparse>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_inout_coo_soa(
+// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*3]]: tensor<?xindex>,
+// CHECK-SAME:    %[[E:.*4]]: tensor<?xf32>,
+// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>,
+// CHECK-SAME:    %[[G:.*6]]: tensor<?xindex>,
+// CHECK-SAME:    %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK:         %[[R:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
+// CHECK:         sparse_tensor.disassemble %[[R]]
+// CHECK:         return
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_inout_coo_soa
+#sparse = #sparse_tensor.encoding<{
+   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
+}>  // SoA COO: externalized as values, one positions tensor, and one coordinates tensor per level.
+func.func @sparse_inout_coo_soa(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
+  return %arg0 : tensor<64x64xf32, #sparse>  // identity body; the generated wrapper assembles the input and disassembles the result.
+}

>From 5a7ae901f002fea2d6778b104d7f95c6fde71364 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 21 Feb 2024 01:32:19 +0000
Subject: [PATCH 2/2] address review comments: inline getSparseTensorEncoding() check in convTypes

---
 mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a107cd71959abd..cd6b9b49893731 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -26,9 +26,8 @@ using namespace sparse_tensor;
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                       SmallVectorImpl<Type> *extraTypes = nullptr) {
   for (auto type : types) {
-    auto enc = getSparseTensorEncoding(type);
     // All "dense" data passes through unmodified.
-    if (!enc) {
+    if (!getSparseTensorEncoding(type)) {
       convTypes.push_back(type);
       continue;
     }



More information about the Mlir-commits mailing list