[Mlir-commits] [mlir] [mlir][sparse] add canonicalization patterns for IterateOp. (PR #95569)

Peiming Liu llvmlistbot at llvm.org
Fri Jun 14 09:58:03 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/95569

None

>From 7d5444cbfca0e66d61ec738557bfbbb0698adfd3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 14 Jun 2024 16:56:47 +0000
Subject: [PATCH] [mlir][sparse] add canonicalization patterns for IterateOp.

---
 .../SparseTensor/IR/SparseTensorOps.td        |  8 +++++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 34 +++++++++++++++++++
 .../Dialect/SparseTensor/canonicalize.mlir    | 20 ++++++++++-
 3 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5ae6f9f3443f8..a20de92d2d3ed 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate",
     BlockArgument getIterator() {
       return getRegion().getArguments().front();
     }
+    /// Returns the coordinate block argument of `lvl`, or nullopt if unused.
+    std::optional<BlockArgument> getLvlCrd(Level lvl) {
+      if (!getCrdUsedLvls()[lvl]) return std::nullopt;
+      // Shift in 64 bits; `(1 << lvl) - 1` on a 32-bit int breaks at lvl >= 31.
+      uint64_t mask = (static_cast<uint64_t>(1) << lvl) - 1;
+      return getCrds()[llvm::popcount(mask & getCrdUsedLvls())];
+    }
     Block::BlockArgListType getCrds() {
       // The first block argument is iterator, the remaining arguments are
       // referenced coordinates.
@@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 232d25d718c65..ac711769ed2ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/Bitset.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+/// Canonicalization that drops coordinate block arguments without any use.
+struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IterateOp iterateOp,
+                                PatternRewriter &rewriter) const override {
+    llvm::BitVector deadArgs(iterateOp.getBody()->getNumArguments());
+    LevelSet liveLvls(0);
+    for (unsigned lvl = 0, end = iterateOp.getSpaceDim(); lvl != end; ++lvl) {
+      std::optional<BlockArgument> crd = iterateOp.getLvlCrd(lvl);
+      if (!crd)
+        continue; // Level is not in the used-coordinate set.
+      if (crd->use_empty())
+        deadArgs.set(crd->getArgNumber());
+      else
+        liveLvls.set(lvl);
+    }
+    if (deadArgs.none())
+      return failure(); // Nothing to prune; the op is left untouched.
+
+    rewriter.startOpModification(iterateOp);
+    iterateOp.setCrdUsedLvls(liveLvls);
+    iterateOp.getBody()->eraseArguments(deadArgs);
+    rewriter.finalizeOpModification(iterateOp);
+    return success();
+  }
+};
+
+void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                            mlir::MLIRContext *context) {
+  results.add<RemoveUnusedLvlCrds>(context); // Prunes unused coordinates.
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
index b1d3d7916c142..37b6c89f43be1 100644
--- a/mlir/test/Dialect/SparseTensor/canonicalize.mlir
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
 
 #BCOO = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
@@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : i
   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
   return %0 : tensor<?x?x?xf32, #BCOO>
 }
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (i : dense, j : compressed)
+}>
+
+// Make sure the unused coordinate %coord0 is removed and %coord1 is kept.
+// CHECK-LABEL: @sparse_iterate_canonicalize
+// CHECK:         sparse_tensor.iterate {{.*}} at(_, %{{.*}})
+func.func @sparse_iterate_canonicalize(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
+      : tensor<?x?xf64, #CSR> -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+  sparse_tensor.iterate %it1 in %l1 at (%coord0, %coord1) : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> {
+    "test.op"(%coord1) : (index) -> ()
+  }
+  return
+}



More information about the Mlir-commits mailing list