[Mlir-commits] [mlir] [mlir][tensor] move tensor insert canonicalization to pattern (PR #142671)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 3 14:50:00 PDT 2025


https://github.com/asraa updated https://github.com/llvm/llvm-project/pull/142671

>From a82d7237d91ce007eb6bb1958502663275ea3113 Mon Sep 17 00:00:00 2001
From: Asra Ali <asraa at google.com>
Date: Tue, 3 Jun 2025 20:54:49 +0000
Subject: [PATCH] [mlir][tensor] move tensor insert canonicalization to pattern

Signed-off-by: Asra Ali <asraa at google.com>

fix

Signed-off-by: Asra Ali <asraa at google.com>
---
 mlir/include/mlir/Dialect/Tensor/IR/Tensor.h     |  3 +++
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td |  1 -
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp         |  7 +++----
 mlir/test/Dialect/Tensor/canonicalize.mlir       | 16 ----------------
 .../Dialect/Tensor/insert-after-constant.mlir    | 14 ++++++++++++++
 .../lib/Dialect/Tensor/TestTensorTransforms.cpp  | 13 +++++++++++++
 6 files changed, 33 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Dialect/Tensor/insert-after-constant.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index e8e1342ef36fd..447f5b906cad1 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,9 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
+/// Adds a pattern rewriting tensor.insert into a constant as a new constant.
+void populateFoldInsertAfterConstant(RewritePatternSet &patterns);
+
 /// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
 /// source tensor.
 void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..b30f84e6be724 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1653,7 +1653,7 @@ class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
       return failure();
 
     // Pattern requires constant indices
-    SmallVector<uint64_t, 8> indices;
+    SmallVector<uint64_t> indices;
     for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
       auto indiceAttr = dyn_cast<Attribute>(indice);
       if (!indiceAttr)
@@ -1717,9 +1717,8 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<InsertOpConstantFold>(context);
+void tensor::populateFoldInsertAfterConstant(RewritePatternSet &patterns) {
+  patterns.add<InsertOpConstantFold>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..f033a43c0dc24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
-  // Fold an insert into a splat.
-  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
-  // CHECK-LITERAL:
-  // CHECK-NEXT: return %[[C4]]
-  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4_i32 = arith.constant 4 : i32
-  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
-  return %inserted : tensor<2x2xi32>
-}
-
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/insert-after-constant.mlir b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
new file mode 100644
index 0000000000000..73f49ac6eba78
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-insert-after-constant %s | FileCheck %s
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+  // Fold an insert into a non-splat dense constant.
+  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+  // CHECK-LITERAL:
+  // CHECK-NEXT: return %[[C4]]
+  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4_i32 = arith.constant 4 : i32
+  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+  return %inserted : tensor<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 0e191c32f009e..f2750c5e9a0de 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -82,6 +82,11 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding of extract from collapse_shape"),
       llvm::cl::init(false)};
 
+  Option<bool> testFoldInsertAfterConstant{
+      *this, "test-fold-insert-after-constant",
+      llvm::cl::desc("Test folding of insert of a constant"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -143,6 +148,12 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
   (void)applyPatternsGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldInsertAfterConstantPattern(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateFoldInsertAfterConstant(patterns);
+  (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -393,6 +404,8 @@ void TestTensorTransforms::runOnOperation() {
   }
   if (testFoldExtractFromCollapseShape)
     applyFoldExtractFromCollapseShapePatterns(rootOp);
+  if (testFoldInsertAfterConstant)
+    applyFoldInsertAfterConstantPattern(rootOp);
   if (testTrackingListener)
     if (failed(testTrackingListenerReplacements(rootOp)))
       return signalPassFailure();



More information about the Mlir-commits mailing list