[Mlir-commits] [mlir] c66b72f - [mlir][tensor] remove tensor.insert constant folding out of canonicalization (#142671)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 5 14:53:36 PDT 2025
Author: asraa
Date: 2025-06-05T14:53:33-07:00
New Revision: c66b72f8ce4d12b6fa12f9b359b114fae5d2dcff
URL: https://github.com/llvm/llvm-project/commit/c66b72f8ce4d12b6fa12f9b359b114fae5d2dcff
DIFF: https://github.com/llvm/llvm-project/commit/c66b72f8ce4d12b6fa12f9b359b114fae5d2dcff.diff
LOG: [mlir][tensor] remove tensor.insert constant folding out of canonicalization (#142671)
Follow-up to https://github.com/llvm/llvm-project/pull/142458/
In particular, this addresses concerns that indiscriminately folding tensor
constants can bloat the IR, since these constants can be arbitrarily large.
Signed-off-by: Asra Ali <asraa at google.com>
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..6e67377ddb6e8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
-namespace {
-
-/// Pattern to fold an insert op of a constant destination and scalar to a new
-/// constant.
-///
-/// Example:
-/// ```
-/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// %c0 = arith.constant 0 : index
-/// %c4_f32 = arith.constant 4.0 : f32
-/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
-/// ```
-/// is rewritten into:
-/// ```
-/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// ```
-class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
-public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertOp insertOp,
- PatternRewriter &rewriter) const override {
- // Requires a ranked tensor type.
- auto destType =
- llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
- if (!destType)
- return failure();
-
- // Pattern requires constant indices
- SmallVector<uint64_t, 8> indices;
- for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
- auto indiceAttr = dyn_cast<Attribute>(indice);
- if (!indiceAttr)
- return failure();
- indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
- }
-
- // Requires a constant scalar to insert
- OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
- Attribute scalarAttr = dyn_cast<Attribute>(scalar);
- if (!scalarAttr)
- return failure();
-
- if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
- insertOp.getDest().getDefiningOp())) {
- if (auto sourceAttr =
- llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
- // Update the attribute at the inserted index.
- auto sourceValues = sourceAttr.getValues<Attribute>();
- auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
- std::vector<Attribute> updatedValues;
- updatedValues.reserve(sourceAttr.getNumElements());
- for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) {
- updatedValues.push_back(i == flattenedIndex ? scalarAttr
- : sourceValues[i]);
- }
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- insertOp, sourceAttr.getType(),
- DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
- updatedValues));
- return success();
- }
- }
-
- return failure();
- }
-};
-
-} // namespace
-
void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<InsertOpConstantFold>(context);
-}
-
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0abec7e01d184..65c5b3e8602eb 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
- // Fold an insert into a splat.
- // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
- // CHECK-LITERAL:
- // CHECK-NEXT: return %[[C4]]
- %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4_i32 = arith.constant 4 : i32
- %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
- return %inserted : tensor<2x2xi32>
-}
-
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
More information about the Mlir-commits mailing list