[Mlir-commits] [mlir] [mlir][sparse] support sparse constant to BSR conversion. (PR #71114)

Peiming Liu llvmlistbot at llvm.org
Thu Nov 2 14:31:12 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/71114

>From 1e190f542db2a0c331516aed1f0695664ec1a47d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 2 Nov 2023 21:13:06 +0000
Subject: [PATCH 1/2] [mlir][sparse] support sparse constant to BSR conversion.

---
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  7 ++--
 .../SparseTensor/Transforms/CodegenUtils.cpp  | 34 +++++++++++--------
 .../Transforms/SparseTensorRewriting.cpp      | 11 +++---
 .../CPU/sparse_conversion_block.mlir          | 23 ++++++-------
 4 files changed, 38 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 99214fadf4ba3db..76ecd4171d81aa8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1665,11 +1665,10 @@ LogicalResult ForeachOp::verify() {
   const Dimension dimRank = t.getDimRank();
   const auto args = getBody()->getArguments();
 
-  if (getOrder().has_value() &&
-      (t.getEncoding() || !getOrder()->isPermutation()))
-    return emitError("Only support permuted order on non encoded dense tensor");
+  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
+    return emitError("Level traverse order does not match tensor's level rank");
 
-  if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
+  if (dimRank + 1 + getInitArgs().size() != args.size())
     return emitError("Unmatched number of arguments in the block");
 
   if (getNumResults() != getInitArgs().size())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index f6fb59fa2c3b84b..db969436a30712d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -421,8 +421,11 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 void sparse_tensor::foreachInSparseConstant(
     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
     function_ref<void(ArrayRef<Value>, Value)> callback) {
-  const Dimension dimRank =
-      SparseTensorType(getRankedTensorType(attr)).getDimRank();
+  if (!order)
+    order = builder.getMultiDimIdentityMap(attr.getType().getRank());
+
+  auto stt = SparseTensorType(getRankedTensorType(attr));
+  const Dimension dimRank = stt.getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();
 
@@ -446,20 +449,23 @@ void sparse_tensor::foreachInSparseConstant(
 
   // Sorts the sparse element attribute based on coordinates.
   std::sort(elems.begin(), elems.end(),
-            [order, dimRank](const ElementAttr &lhs, const ElementAttr &rhs) {
-              const auto &lhsCoords = lhs.first;
-              const auto &rhsCoords = rhs.first;
-              for (Dimension d = 0; d < dimRank; d++) {
-                // FIXME: This only makes sense for permutations.
-                // And since we don't check that `order` is a permutation,
-                // it can also cause OOB errors when we use `l`.
-                const Level l = order ? order.getDimPosition(d) : d;
-                if (lhsCoords[l].getInt() == rhsCoords[l].getInt())
-                  continue;
-                return lhsCoords[l].getInt() < rhsCoords[l].getInt();
-              }
+            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
               if (std::addressof(lhs) == std::addressof(rhs))
                 return false;
+
+              auto lhsCoords = llvm::map_to_vector(
+                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
+              auto rhsCoords = llvm::map_to_vector(
+                  rhs.first, [](IntegerAttr i) { return i.getInt(); });
+
+              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
+              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
+              // Sort the elements based on their lvl coordinates.
+              for (Level l = 0; l < order.getNumResults(); l++) {
+                if (lhsLvlCrds[l] == rhsLvlCrds[l])
+                  continue;
+                return lhsLvlCrds[l] < rhsLvlCrds[l];
+              }
               llvm_unreachable("no equal coordinate in sparse element attr");
             });
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 13388dce6bbb5ec..7770bd857e88093 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1129,14 +1129,11 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 
     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
-      // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
-      const Dimension dimRank = stt.getDimRank();
-      SmallVector<Value> dcvs = lcvs; // keep a copy
-      for (Dimension d = 0; d < dimRank; d++) {
-        auto l = op.getOrder()->getDimPosition(d);
-        lcvs[l] = dcvs[d];
-      }
+      // TODO: Support it so that we can do direct conversion from CSR->BSR.
+      llvm_unreachable(
+          "Level order not yet implemented on non-constant input tensors.");
     }
+
     Value vals = loopEmitter.getValBuffer()[0];
     Value pos = loopEmitter.getPosits()[0].back();
     // Loads the value from sparse tensor using position-index;
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
index ec14492c5b44999..eb9e23f9b7dc846 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -74,22 +74,21 @@ module {
     //
     // Initialize a 2-dim dense tensor.
     //
-    %t = arith.constant dense<[
-       [  1.0,  2.0,  3.0,  4.0 ],
-       [  5.0,  6.0,  7.0,  8.0 ]
-    ]> : tensor<2x4xf64>
+    %t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
+                                [1, 0], [1, 1], [1, 2], [1, 3]],
+                                [ 1.0, 2.0, 3.0, 4.0,
+                                  5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
 
+    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+    %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
 
-    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
-    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
-    %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
-
-    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
-    %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
+    %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #CSR> to memref<?xf64>
     %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
 
-    // CHECK:      ( 1, 2, 3, 4, 5, 6, 7, 8 )
-    // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK:      ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
     // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
     call @dumpf64(%v1) : (memref<?xf64>) -> ()
     call @dumpf64(%v2) : (memref<?xf64>) -> ()

>From 4780f9badabf553648370608f412fe59166cf6d4 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 2 Nov 2023 21:30:43 +0000
Subject: [PATCH 2/2] add dense test

---
 .../CPU/sparse_conversion_block.mlir            | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
index eb9e23f9b7dc846..ccd61aea0ba1684 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -79,20 +79,29 @@ module {
                                 [ 1.0, 2.0, 3.0, 4.0,
                                   5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
 
+    %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
+                                [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<2x4xf64>
+
+    // constant -> BSR (either from SparseElementAttribute or DenseElementAttribute)
     %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
-    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
-    %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+    %2 = sparse_tensor.convert %td : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+    %4 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
 
     %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
-    %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #CSR> to memref<?xf64>
-    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+    %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v4 = sparse_tensor.values %4 : tensor<2x4xf64, #CSC> to memref<?xf64>
+
 
     // CHECK:      ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
     // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
     call @dumpf64(%v1) : (memref<?xf64>) -> ()
     call @dumpf64(%v2) : (memref<?xf64>) -> ()
     call @dumpf64(%v3) : (memref<?xf64>) -> ()
+    call @dumpf64(%v4) : (memref<?xf64>) -> ()
 
     return
   }



More information about the Mlir-commits mailing list