[Mlir-commits] [mlir] [mlir][tensor] move tensor.insert constant folding out of canonicalization (PR #142671)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 5 08:28:04 PDT 2025
https://github.com/asraa updated https://github.com/llvm/llvm-project/pull/142671
From 872002679288e4eadb52042059159d7a3c1854f2 Mon Sep 17 00:00:00 2001
From: Asra Ali <asraa at google.com>
Date: Tue, 3 Jun 2025 20:54:49 +0000
Subject: [PATCH] [mlir][tensor] remove tensor.insert constant-folding canonicalization pattern
Signed-off-by: Asra Ali <asraa at google.com>
fix
Signed-off-by: Asra Ali <asraa at google.com>
fix loop
Signed-off-by: Asra Ali <asraa at google.com>
remove pattern
Signed-off-by: Asra Ali <asraa at google.com>
---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 1 -
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 75 -------------------
mlir/test/Dialect/Tensor/canonicalize.mlir | 16 ----
3 files changed, 92 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..6e67377ddb6e8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
-namespace {
-
-/// Pattern to fold an insert op of a constant destination and scalar to a new
-/// constant.
-///
-/// Example:
-/// ```
-/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// %c0 = arith.constant 0 : index
-/// %c4_f32 = arith.constant 4.0 : f32
-/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
-/// ```
-/// is rewritten into:
-/// ```
-/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// ```
-class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
-public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertOp insertOp,
- PatternRewriter &rewriter) const override {
- // Requires a ranked tensor type.
- auto destType =
- llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
- if (!destType)
- return failure();
-
- // Pattern requires constant indices
- SmallVector<uint64_t, 8> indices;
- for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
- auto indiceAttr = dyn_cast<Attribute>(indice);
- if (!indiceAttr)
- return failure();
- indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
- }
-
- // Requires a constant scalar to insert
- OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
- Attribute scalarAttr = dyn_cast<Attribute>(scalar);
- if (!scalarAttr)
- return failure();
-
- if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
- insertOp.getDest().getDefiningOp())) {
- if (auto sourceAttr =
- llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
- // Update the attribute at the inserted index.
- auto sourceValues = sourceAttr.getValues<Attribute>();
- auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
- std::vector<Attribute> updatedValues;
- updatedValues.reserve(sourceAttr.getNumElements());
- for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) {
- updatedValues.push_back(i == flattenedIndex ? scalarAttr
- : sourceValues[i]);
- }
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- insertOp, sourceAttr.getType(),
- DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
- updatedValues));
- return success();
- }
- }
-
- return failure();
- }
-};
-
-} // namespace
-
void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<InsertOpConstantFold>(context);
-}
-
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..f033a43c0dc24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
- // Fold an insert into a splat.
- // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
- // CHECK-LITERAL:
- // CHECK-NEXT: return %[[C4]]
- %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4_i32 = arith.constant 4 : i32
- %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
- return %inserted : tensor<2x2xi32>
-}
-
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
More information about the Mlir-commits
mailing list