[Mlir-commits] [mlir] 53ffafb - [mlir][sparse] support sparse constant to BSR conversion. (#71114)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 14:45:43 PDT 2023


Author: Peiming Liu
Date: 2023-11-02T14:45:39-07:00
New Revision: 53ffafb24d81c25909a1ae06584fc65245b86b7f

URL: https://github.com/llvm/llvm-project/commit/53ffafb24d81c25909a1ae06584fc65245b86b7f
DIFF: https://github.com/llvm/llvm-project/commit/53ffafb24d81c25909a1ae06584fc65245b86b7f.diff

LOG: [mlir][sparse] support sparse constant to BSR conversion. (#71114)

Support direct conversion from a constant tensor defined by a
SparseElementsAttr to BSR.

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c727b8d05c26d7d..b3a155686824555 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1666,11 +1666,10 @@ LogicalResult ForeachOp::verify() {
   const Dimension dimRank = t.getDimRank();
   const auto args = getBody()->getArguments();
 
-  if (getOrder().has_value() &&
-      (t.getEncoding() || !getOrder()->isPermutation()))
-    return emitError("Only support permuted order on non encoded dense tensor");
+  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
+    return emitError("Level traverse order does not match tensor's level rank");
 
-  if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
+  if (dimRank + 1 + getInitArgs().size() != args.size())
     return emitError("Unmatched number of arguments in the block");
 
   if (getNumResults() != getInitArgs().size())

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index f6fb59fa2c3b84b..db969436a30712d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -421,8 +421,11 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 void sparse_tensor::foreachInSparseConstant(
     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
     function_ref<void(ArrayRef<Value>, Value)> callback) {
-  const Dimension dimRank =
-      SparseTensorType(getRankedTensorType(attr)).getDimRank();
+  if (!order)
+    order = builder.getMultiDimIdentityMap(attr.getType().getRank());
+
+  auto stt = SparseTensorType(getRankedTensorType(attr));
+  const Dimension dimRank = stt.getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();
 
@@ -446,20 +449,23 @@ void sparse_tensor::foreachInSparseConstant(
 
   // Sorts the sparse element attribute based on coordinates.
   std::sort(elems.begin(), elems.end(),
-            [order, dimRank](const ElementAttr &lhs, const ElementAttr &rhs) {
-              const auto &lhsCoords = lhs.first;
-              const auto &rhsCoords = rhs.first;
-              for (Dimension d = 0; d < dimRank; d++) {
-                // FIXME: This only makes sense for permutations.
-                // And since we don't check that `order` is a permutation,
-                // it can also cause OOB errors when we use `l`.
-                const Level l = order ? order.getDimPosition(d) : d;
-                if (lhsCoords[l].getInt() == rhsCoords[l].getInt())
-                  continue;
-                return lhsCoords[l].getInt() < rhsCoords[l].getInt();
-              }
+            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
               if (std::addressof(lhs) == std::addressof(rhs))
                 return false;
+
+              auto lhsCoords = llvm::map_to_vector(
+                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
+              auto rhsCoords = llvm::map_to_vector(
+                  rhs.first, [](IntegerAttr i) { return i.getInt(); });
+
+              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
+              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
+              // Sort the element based on the lvl coordinates.
+              for (Level l = 0; l < order.getNumResults(); l++) {
+                if (lhsLvlCrds[l] == rhsLvlCrds[l])
+                  continue;
+                return lhsLvlCrds[l] < rhsLvlCrds[l];
+              }
               llvm_unreachable("no equal coordinate in sparse element attr");
             });
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 13388dce6bbb5ec..7770bd857e88093 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1129,14 +1129,11 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 
     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
-      // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
-      const Dimension dimRank = stt.getDimRank();
-      SmallVector<Value> dcvs = lcvs; // keep a copy
-      for (Dimension d = 0; d < dimRank; d++) {
-        auto l = op.getOrder()->getDimPosition(d);
-        lcvs[l] = dcvs[d];
-      }
+      // TODO: Support it so that we can do direct conversion from CSR->BSR.
+      llvm_unreachable(
+          "Level order not yet implemented on non-constant input tensors.");
     }
+
     Value vals = loopEmitter.getValBuffer()[0];
     Value pos = loopEmitter.getPosits()[0].back();
     // Loads the value from sparse tensor using position-index;

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
index ec14492c5b44999..ccd61aea0ba1684 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -74,26 +74,34 @@ module {
     //
     // Initialize a 2-dim dense tensor.
     //
-    %t = arith.constant dense<[
-       [  1.0,  2.0,  3.0,  4.0 ],
-       [  5.0,  6.0,  7.0,  8.0 ]
-    ]> : tensor<2x4xf64>
+    %t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
+                                [1, 0], [1, 1], [1, 2], [1, 3]],
+                                [ 1.0, 2.0, 3.0, 4.0,
+                                  5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
 
+    %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
+                                [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<2x4xf64>
 
-    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
-    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
-    %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+    // constant -> BSR (from either a SparseElementsAttr or a DenseElementsAttr)
+    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %2 = sparse_tensor.convert %td : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+    %4 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
 
-    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
     %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
-    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v4 = sparse_tensor.values %4 : tensor<2x4xf64, #CSC> to memref<?xf64>
 
-    // CHECK:      ( 1, 2, 3, 4, 5, 6, 7, 8 )
+
+    // CHECK:      ( 1, 2, 5, 6, 3, 4, 7, 8 )
     // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
     // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
     call @dumpf64(%v1) : (memref<?xf64>) -> ()
     call @dumpf64(%v2) : (memref<?xf64>) -> ()
     call @dumpf64(%v3) : (memref<?xf64>) -> ()
+    call @dumpf64(%v4) : (memref<?xf64>) -> ()
 
     return
   }


        


More information about the Mlir-commits mailing list