[Mlir-commits] [mlir] [mlir][tensor] move tensor.insert constant folding out of canonicalization (PR #142671)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 5 08:28:04 PDT 2025


https://github.com/asraa updated https://github.com/llvm/llvm-project/pull/142671

>From 872002679288e4eadb52042059159d7a3c1854f2 Mon Sep 17 00:00:00 2001
From: Asra Ali <asraa at google.com>
Date: Tue, 3 Jun 2025 20:54:49 +0000
Subject: [PATCH] [mlir][tensor] remove tensor.insert constant folding from canonicalization

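Remove the InsertOpConstantFold pattern from tensor.insert canonicalization.
The pattern replaced a tensor.insert of a constant scalar at constant indices
into an arith.constant destination with a new arith.constant. That new
constant's DenseElementsAttr copies every element of the original constant.

This patch:
- drops `hasCanonicalizer` from Tensor_InsertOp;
- deletes the pattern together with InsertOp::getCanonicalizationPatterns;
- removes the @canonicalize_insert_after_constant test from canonicalize.mlir.

Signed-off-by: Asra Ali <asraa at google.com>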
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  1 -
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 75 -------------------
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 16 ----
 3 files changed, 92 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..6e67377ddb6e8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-/// Pattern to fold an insert op of a constant destination and scalar to a new
-/// constant.
-///
-/// Example:
-/// ```
-///   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-///   %c0 = arith.constant 0 : index
-///   %c4_f32 = arith.constant 4.0 : f32
-///   %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
-/// ```
-/// is rewritten into:
-/// ```
-///   %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// ```
-class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    // Requires a ranked tensor type.
-    auto destType =
-        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
-    if (!destType)
-      return failure();
-
-    // Pattern requires constant indices
-    SmallVector<uint64_t, 8> indices;
-    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
-      auto indiceAttr = dyn_cast<Attribute>(indice);
-      if (!indiceAttr)
-        return failure();
-      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
-    }
-
-    // Requires a constant scalar to insert
-    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
-    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
-    if (!scalarAttr)
-      return failure();
-
-    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
-            insertOp.getDest().getDefiningOp())) {
-      if (auto sourceAttr =
-              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
-        // Update the attribute at the inserted index.
-        auto sourceValues = sourceAttr.getValues<Attribute>();
-        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
-        std::vector<Attribute> updatedValues;
-        updatedValues.reserve(sourceAttr.getNumElements());
-        for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) {
-          updatedValues.push_back(i == flattenedIndex ? scalarAttr
-                                                      : sourceValues[i]);
-        }
-        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-            insertOp, sourceAttr.getType(),
-            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
-                                   updatedValues));
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-} // namespace
-
 void InsertOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "inserted");
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<InsertOpConstantFold>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..f033a43c0dc24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
-  // Fold an insert into a splat.
-  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
-  // CHECK-LITERAL:
-  // CHECK-NEXT: return %[[C4]]
-  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4_i32 = arith.constant 4 : i32
-  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
-  return %inserted : tensor<2x2xi32>
-}
-
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.cast



More information about the Mlir-commits mailing list