[Mlir-commits] [mlir] a43d79a - [mlir][sparse] add canonicalization patterns for IterateOp. (#95569)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 14 10:31:32 PDT 2024


Author: Peiming Liu
Date: 2024-06-14T10:31:29-07:00
New Revision: a43d79af782e2730b170c04db49f4c3040399d97

URL: https://github.com/llvm/llvm-project/commit/a43d79af782e2730b170c04db49f4c3040399d97
DIFF: https://github.com/llvm/llvm-project/commit/a43d79af782e2730b170c04db49f4c3040399d97.diff

LOG: [mlir][sparse] add canonicalization patterns for IterateOp. (#95569)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5ae6f9f3443f8..b2089924291cd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate",
     BlockArgument getIterator() {
       return getRegion().getArguments().front();
     }
+    std::optional<BlockArgument> getLvlCrd(Level lvl) {
+      if (!getCrdUsedLvls()[lvl])
+        return std::nullopt;
+      // The coordinate's slot is the number of used levels below `lvl`.
+      uint64_t lowerLvls = (static_cast<uint64_t>(0x01u) << lvl) - 1;
+      return getCrds()[llvm::popcount(lowerLvls & getCrdUsedLvls())];
+    }
     Block::BlockArgListType getCrds() {
       // The first block argument is iterator, the remaining arguments are
       // referenced coordinates.
@@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 232d25d718c65..ac711769ed2ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/Bitset.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+/// Drops the coordinates of an IterateOp whose block arguments have no users.
+struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IterateOp iterateOp,
+                                PatternRewriter &rewriter) const override {
+    llvm::BitVector deadArgs(iterateOp.getBody()->getNumArguments());
+    LevelSet liveLvls(0);
+    for (unsigned lvl = 0, dim = iterateOp.getSpaceDim(); lvl < dim; ++lvl) {
+      std::optional<BlockArgument> crd = iterateOp.getLvlCrd(lvl);
+      if (!crd)
+        continue; // No coordinate is loaded for this level.
+      if (crd->getUsers().empty())
+        deadArgs.set(crd->getArgNumber());
+      else
+        liveLvls.set(lvl);
+    }
+    if (deadArgs.none())
+      return failure(); // Every loaded coordinate still has a user.
+
+    rewriter.startOpModification(iterateOp);
+    iterateOp.setCrdUsedLvls(liveLvls);
+    iterateOp.getBody()->eraseArguments(deadArgs);
+    rewriter.finalizeOpModification(iterateOp);
+    return success();
+  }
+};
+
+void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                            mlir::MLIRContext *context) {
+  results.add<RemoveUnusedLvlCrds>(context); // Only pattern registered here.
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;

diff  --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
index b1d3d7916c142..ceb82cab516ed 100644
--- a/mlir/test/Dialect/SparseTensor/canonicalize.mlir
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : i
   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
   return %0 : tensor<?x?x?xf32, #BCOO>
 }
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (i : dense, j : compressed)
+}>
+
+// Make sure that the unused first coordinate (%coord0) is removed.
+// CHECK-LABEL: @sparse_iterate_canonicalize
+// CHECK:         sparse_tensor.iterate {{.*}} at(_, %{{.*}})
+func.func @sparse_iterate_canonicalize(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
+      : tensor<?x?xf64, #CSR> -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+  sparse_tensor.iterate %it1 in %l1 at (%coord0, %coord1) : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> {
+    "test.op"(%coord1) : (index) -> () // Only %coord1 is used; %coord0 has no users.
+  }
+  return
+}


        


More information about the Mlir-commits mailing list