[Mlir-commits] [mlir] [mlir][tensor] move tensor insert canonicalization to pattern (PR #142671)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 14:50:00 PDT 2025
https://github.com/asraa updated https://github.com/llvm/llvm-project/pull/142671
>From a82d7237d91ce007eb6bb1958502663275ea3113 Mon Sep 17 00:00:00 2001
From: Asra Ali <asraa at google.com>
Date: Tue, 3 Jun 2025 20:54:49 +0000
Subject: [PATCH] [mlir][tensor] move tensor insert canonicalization to pattern
Signed-off-by: Asra Ali <asraa at google.com>
fix
Signed-off-by: Asra Ali <asraa at google.com>
---
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 3 +++
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 1 -
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 7 +++----
mlir/test/Dialect/Tensor/canonicalize.mlir | 16 ----------------
.../Dialect/Tensor/insert-after-constant.mlir | 14 ++++++++++++++
.../lib/Dialect/Tensor/TestTensorTransforms.cpp | 13 +++++++++++++
6 files changed, 33 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Dialect/Tensor/insert-after-constant.mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index e8e1342ef36fd..447f5b906cad1 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,9 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
+/// Populates `patterns` with a pattern that folds a `tensor.insert` (at
+/// constant indices) into a constant tensor, producing a new constant. This
+/// fold is opt-in: it is not part of `tensor.insert` canonicalization.
+void populateFoldInsertAfterConstant(RewritePatternSet &patterns);
+
/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
/// source tensor.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..b30f84e6be724 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1653,7 +1653,7 @@ class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
return failure();
// Pattern requires constant indices
- SmallVector<uint64_t, 8> indices;
+ SmallVector<uint64_t> indices;
for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
auto indiceAttr = dyn_cast<Attribute>(indice);
if (!indiceAttr)
@@ -1717,9 +1717,8 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<InsertOpConstantFold>(context);
+void tensor::populateFoldInsertAfterConstant(RewritePatternSet &patterns) {
+  // Opt-in entry point: InsertOpConstantFold is no longer registered through
+  // tensor.insert's canonicalization hook, so clients must request it here.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<InsertOpConstantFold>(ctx);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..f033a43c0dc24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
- // Fold an insert into a splat.
- // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
- // CHECK-LITERAL:
- // CHECK-NEXT: return %[[C4]]
- %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4_i32 = arith.constant 4 : i32
- %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
- return %inserted : tensor<2x2xi32>
-}
-
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/insert-after-constant.mlir b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
new file mode 100644
index 0000000000000..73f49ac6eba78
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-insert-after-constant %s | FileCheck %s
+
+// CHECK-LABEL: func @canonicalize_insert_after_constant
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+  // Fold an insert at constant indices into a (non-splat) constant tensor:
+  // only the updated constant and the return should remain.
+  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+  // CHECK-NEXT: return %[[C4]]
+  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4_i32 = arith.constant 4 : i32
+  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+  return %inserted : tensor<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 0e191c32f009e..f2750c5e9a0de 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -82,6 +82,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of extract from collapse_shape"),
llvm::cl::init(false)};
+  // Enables the opt-in fold of tensor.insert into a constant tensor.
+  Option<bool> testFoldInsertAfterConstant{
+      *this, "test-fold-insert-after-constant",
+      llvm::cl::desc("Test folding of insert into a constant"),
+      llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -143,6 +148,12 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
+/// Greedily applies the insert-after-constant fold pattern to `rootOp`. This
+/// is a test-only helper, so the convergence result is deliberately ignored.
+static void applyFoldInsertAfterConstantPattern(Operation *rootOp) {
+  MLIRContext *ctx = rootOp->getContext();
+  RewritePatternSet foldPatterns(ctx);
+  tensor::populateFoldInsertAfterConstant(foldPatterns);
+  (void)applyPatternsGreedily(rootOp, std::move(foldPatterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -393,6 +404,8 @@ void TestTensorTransforms::runOnOperation() {
}
if (testFoldExtractFromCollapseShape)
applyFoldExtractFromCollapseShapePatterns(rootOp);
+ if (testFoldInsertAfterConstant)
+ applyFoldInsertAfterConstantPattern(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();
More information about the Mlir-commits
mailing list