[Mlir-commits] [mlir] [mlir][sparse] add canonicalization patterns for IterateOp. (PR #95569)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 14 09:58:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/95569.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+8)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+34)
- (modified) mlir/test/Dialect/SparseTensor/canonicalize.mlir (+19-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5ae6f9f3443f8..a20de92d2d3ed 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate",
BlockArgument getIterator() {
return getRegion().getArguments().front();
}
+    std::optional<BlockArgument> getLvlCrd(Level lvl) {
+      if (!getCrdUsedLvls()[lvl])
+        return std::nullopt; // The coordinate of `lvl` is not requested.
+      // Index = # used levels below `lvl`; shift in 64 bits (`1 << 31` is UB).
+      uint64_t mask = (static_cast<uint64_t>(1) << lvl) - 1;
+      return getCrds()[llvm::popcount(mask & getCrdUsedLvls())];
+    }
Block::BlockArgListType getCrds() {
// The first block argument is iterator, the remaining arguments are
// referenced coordinates.
@@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate",
let hasVerifier = 1;
let hasRegionVerifier = 1;
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 232d25d718c65..ac711769ed2ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
@@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+/// Erases unused coordinate block args of an IterateOp (updating crdUsedLvls).
+struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(IterateOp iterateOp,
+                                PatternRewriter &rewriter) const override {
+    LevelSet keptLvls(0); // Levels whose coordinates still have uses.
+    llvm::BitVector deadArgs(iterateOp.getBody()->getNumArguments());
+    for (unsigned lvl = 0, end = iterateOp.getSpaceDim(); lvl < end; ++lvl) {
+      auto crd = iterateOp.getLvlCrd(lvl);
+      if (!crd)
+        continue; // This level's coordinate is not requested at all.
+      if (crd->use_empty())
+        deadArgs.set(crd->getArgNumber());
+      else
+        keptLvls.set(lvl);
+    }
+
+    // Nothing to rewrite when every requested coordinate has a use.
+    if (deadArgs.none())
+      return failure();
+    rewriter.modifyOpInPlace(iterateOp, [&]() {
+      iterateOp.setCrdUsedLvls(keptLvls);
+      iterateOp.getBody()->eraseArguments(deadArgs);
+    });
+    return success();
+  }
+};
+
+void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                            mlir::MLIRContext *context) {
+  results.add<RemoveUnusedLvlCrds>(context);
+}
+
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
index b1d3d7916c142..37b6c89f43be1 100644
--- a/mlir/test/Dialect/SparseTensor/canonicalize.mlir
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
#BCOO = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
@@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : i
%0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
return %0 : tensor<?x?x?xf32, #BCOO>
}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (i : dense, j : compressed)
+}>
+
+// Make sure that the unused level-0 coordinate (%coord0) is removed.
+// CHECK-LABEL: @sparse_iterate_canonicalize
+// CHECK: sparse_tensor.iterate {{.*}} at(_, %{{.*}})
+func.func @sparse_iterate_canonicalize(%sp : tensor<?x?xf64, #CSR>) {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
+ : tensor<?x?xf64, #CSR> -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+ sparse_tensor.iterate %it1 in %l1 at (%coord0, %coord1) : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> {
+ "test.op"(%coord1) : (index) -> ()
+ }
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/95569
More information about the Mlir-commits
mailing list