[Mlir-commits] [mlir] [mlir][sparse] add canonicalization patterns for IterateOp. (PR #95569)
Peiming Liu
llvmlistbot at llvm.org
Fri Jun 14 10:03:01 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/95569
>From f5e29af25cb7facf60029a5f030c7b3266993acb Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 14 Jun 2024 16:56:47 +0000
Subject: [PATCH] [mlir][sparse] add canonicalization patterns for IterateOp.
---
.../SparseTensor/IR/SparseTensorOps.td | 8 +++++
.../SparseTensor/IR/SparseTensorDialect.cpp | 34 +++++++++++++++++++
.../Dialect/SparseTensor/canonicalize.mlir | 18 ++++++++++
3 files changed, 60 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5ae6f9f3443f8..a20de92d2d3ed 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate",
BlockArgument getIterator() {
return getRegion().getArguments().front();
}
+ std::optional<BlockArgument> getLvlCrd(Level lvl) {
+ if (getCrdUsedLvls()[lvl]) {
+ uint64_t mask = (1 << lvl) - 1;
+ return getCrds()[llvm::popcount(mask & getCrdUsedLvls())];
+ }
+ return std::nullopt;
+ }
Block::BlockArgListType getCrds() {
// The first block argument is iterator, the remaining arguments are
// referenced coordinates.
@@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate",
let hasVerifier = 1;
let hasRegionVerifier = 1;
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 232d25d718c65..ac711769ed2ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Bitset.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+namespace {
+/// Canonicalization pattern for `sparse_tensor.iterate`: erases coordinate
+/// block arguments that have no uses and clears the corresponding levels from
+/// the op's used-level set, so the set and the region arguments stay in sync.
+/// Fails to match when every referenced coordinate is used.
+struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IterateOp iterateOp,
+                                PatternRewriter &rewriter) const override {
+    // Levels whose coordinates are still referenced after the rewrite.
+    LevelSet newUsedLvls(0);
+    // Indexed by block-argument number; only coordinate arguments get set.
+    llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
+    for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; ++i) {
+      if (auto crd = iterateOp.getLvlCrd(i)) {
+        if (crd->use_empty())
+          toRemove.set(crd->getArgNumber());
+        else
+          newUsedLvls.set(i);
+      }
+    }
+
+    // All coordinates are used; nothing to canonicalize.
+    if (toRemove.none())
+      return failure();
+
+    // Update the used-level set and the region signature together so the
+    // rewriter is notified of a single in-place modification.
+    rewriter.modifyOpInPlace(iterateOp, [&]() {
+      iterateOp.setCrdUsedLvls(newUsedLvls);
+      iterateOp.getBody()->eraseArguments(toRemove);
+    });
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns of `sparse_tensor.iterate`; at
+/// present only the pattern that drops unused level coordinates.
+void IterateOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *ctx) {
+  patterns.add<RemoveUnusedLvlCrds>(ctx);
+}
+
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
index b1d3d7916c142..ceb82cab516ed 100644
--- a/mlir/test/Dialect/SparseTensor/canonicalize.mlir
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : i
%0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
return %0 : tensor<?x?x?xf32, #BCOO>
}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (i : dense, j : compressed)
+}>
+
+// The level-0 coordinate is never referenced in the loop body, so the
+// canonicalizer should drop it (printed as `_`) while keeping the level-1
+// coordinate consumed by "test.op".
+// CHECK-LABEL: @sparse_iterate_canonicalize
+// CHECK: sparse_tensor.iterate {{.*}} at(_, %{{.*}})
+func.func @sparse_iterate_canonicalize(%sp : tensor<?x?xf64, #CSR>) {
+  %space = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
+      : tensor<?x?xf64, #CSR> -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+  sparse_tensor.iterate %iter in %space at (%crd0, %crd1)
+      : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> {
+    "test.op"(%crd1) : (index) -> ()
+  }
+  return
+}
More information about the Mlir-commits
mailing list