[Mlir-commits] [mlir] [mlir][bufferization] Enable moving dependent values in eliminate-empty-tensors (PR #169718)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 26 11:51:11 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-bufferization
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
Currently, empty tensor elimination, which constructs a SubsetExtractionOp to match a SubsetInsertionOp at the end of a DPS chain, fails if any operand required by the insertion op does not dominate the insertion point of the extraction op.
This change improves the transformation by attempting to move all pure producers of the required operands before the insertion point of the extraction op. As a result, empty tensors are now eliminated in several existing empty tensor elimination tests that previously required an allocation and a copy.
---
Full diff: https://github.com/llvm/llvm-project/pull/169718.diff
4 Files Affected:
- (modified) mlir/include/mlir/Transforms/RegionUtils.h (+6-2)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+10-2)
- (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+17-4)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+80-19)
``````````diff
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 2ed96afbace81..6a0c94b06c6b2 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -85,11 +85,15 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
/// only for movement of definitions within the same basic block. Note that this
/// is an all-or-nothing approach. Either definitions of all values are moved
/// before insertion point, or none of them are.
+/// If `ignoreSideEffects` is false, fail without moving anything when a
+/// producer that must be moved is not pure; otherwise move all producers.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint,
- DominanceInfo &dominance);
+ DominanceInfo &dominance,
+ bool ignoreSideEffects = true);
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
- Operation *insertionPoint);
+ Operation *insertionPoint,
+ bool ignoreSideEffects = true);
/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1784964cf9b95..0843b4398b24f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace bufferization {
@@ -105,8 +106,15 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- return {};
+ if (!insertionPoint) {
+ // If no existing point is dominated by all needed values, try to move the
+ // (pure) producers of those values before the user instead.
+ if (failed(moveValueDefinitions(rewriter, neededValues, user,
+ /*ignoreSideEffects=*/false))) {
+ return {};
+ }
+ insertionPoint = user;
+ }
rewriter.setInsertionPoint(insertionPoint);
Value replacement =
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31ae1d1895b81..390fc76cc6533 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1145,7 +1145,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
ValueRange values,
Operation *insertionPoint,
- DominanceInfo &dominance) {
+ DominanceInfo &dominance,
+ bool ignoreSideEffects) {
// Remove the values that already dominate the insertion point.
SmallVector<Value> prunedValues;
for (auto value : values) {
@@ -1178,8 +1179,14 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
options.omitBlockArguments = true;
+ bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
- return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+ bool mustMove =
+ !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+ if (mustMove && !isPure(sliceBoundaryOp)) {
+ dependsOnSideEffectingOp = true;
+ }
+ return mustMove;
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
@@ -1188,6 +1195,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
(void)result;
}
+ // Bail out if side effects must be respected and an op to move is impure.
+ if (!ignoreSideEffects && dependsOnSideEffectingOp)
+ return failure();
+
// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
@@ -1206,7 +1217,9 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
ValueRange values,
- Operation *insertionPoint) {
+ Operation *insertionPoint,
+ bool ignoreSideEffects) {
DominanceInfo dominance(insertionPoint);
- return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
+ return moveValueDefinitions(rewriter, values, insertionPoint, dominance,
+ ignoreSideEffects);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 8249d59b2374e..3929f5be3b4ef 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
// -----
-// `EmptyTensorElimination` fails to find a valid insertion
-// point for the new injected `SubsetExtraction`.
-// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
-func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_all_empty_tensors
+func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
- // CHECK: memref.alloc
- // CHECK: memref.alloc
- // CHECK: memref.alloc
+ // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+ // CHECK-NOT: memref.alloc
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
- // CHECK: memref.copy
+ // CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
// -----
-// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
-func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
+func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
- // CHECK: memref.alloc
// CHECK-NOT: memref.alloc
- %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+ %concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
- // CHECK: memref.copy
- %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+ // CHECK-NOT: memref.copy
+ %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
// CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
+// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+// CHECK-NOT: memref.alloc
+// CHECK-NOT: memref.copy
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
+// CHECK-ELIM: linalg.fill
func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
- // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
- // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
- // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
- // CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
- // CHECK-NOT: memref.copy
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1 : tensor<5x6x128xf32>
}
+
+// -----
+
+// Test that dependent pure operations are moved before the
+// insertion point to enable empty tensor elimination.
+
+// CHECK-LABEL: func.func @move_dependent_arith_op(
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-NOT: memref.alloc
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK: linalg.fill {{.*}} outs(%[[SV]]
+// CHECK: return %[[ARG0]]
+// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
+// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
+// CHECK-ELIM-SAME: %[[ARG1:.*]]: index
+// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
+// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
+func.func @move_dependent_arith_op(
+ %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+ %arg1: index, %f: f32) -> tensor<10xf32>
+{
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %c5 = arith.constant 5 : index
+ %offset = arith.addi %arg1, %c5 : index
+ %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
+
+// -----
+
+// Test that side-effecting operations are not moved, preventing empty
+// tensor elimination.
+
+// CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: memref.load
+// CHECK: memref.subview
+// CHECK: memref.copy
+// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK-ELIM: tensor.empty
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: memref.load
+// CHECK-ELIM: tensor.insert_slice
+func.func @side_effecting_op_blocks_movement(
+ %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+ %mem: memref<index>, %f: f32) -> tensor<10xf32>
+{
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %offset = memref.load %mem[] : memref<index>
+ %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/169718
More information about the Mlir-commits
mailing list