[Mlir-commits] [mlir] [Linalg] Fix bug in control function logic of push down extract pattern (PR #158348)
Nirvedh Meshram
llvmlistbot at llvm.org
Fri Sep 12 12:08:15 PDT 2025
https://github.com/nirvedhmeshram created https://github.com/llvm/llvm-project/pull/158348
The current logic bails out as soon as the first extract_slice operand fails the control function; this PR makes the pattern try the remaining extract_slice operands instead.
>From 430cbb491da435c2c505fc1387395f5c83117f79 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Fri, 12 Sep 2025 12:05:53 -0700
Subject: [PATCH] [Linalg] Fix bug in control function logic of push down
extract pattern
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
.../Transforms/DataLayoutPropagation.cpp | 36 ++++++++++++-------
.../Linalg/data-layout-propagation.mlir | 30 ++++++++++++++++
.../Linalg/TestDataLayoutPropagation.cpp | 9 +++--
3 files changed, 60 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
OpFoldResult outputSize;
};
-/// Return the first input extract slice operand, if present, for the current
-/// generic op.
+/// Return every DPS input operand of `genericOp` that is produced by a
+/// tensor.extract_slice op, in operand order. Fails if there is none.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
- OpOperand *sliceOperand = nullptr;
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+ SmallVector<OpOperand *> sliceOperands;
for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
- sliceOperand = operand;
- break;
+ sliceOperands.push_back(operand);
}
- if (!sliceOperand) {
+ if (sliceOperands.empty()) {
return failure();
}
- return sliceOperand;
+ return sliceOperands;
}
// Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,18 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
genericOp,
"propagation through generic with gather semantics is unsupported.");
- // Collect the sliced operand, if present.
+ // Collect the sliced operands, if present.
- auto maybeSliceOperand = getSliceOperand(genericOp);
- if (failed(maybeSliceOperand))
+ auto maybeSliceOperands = getSliceOperands(genericOp);
+ if (failed(maybeSliceOperands))
return failure();
- OpOperand *sliceOperand = *maybeSliceOperand;
- unsigned OperandIndex = sliceOperand->getOperandNumber();
-
- if (!controlFn(sliceOperand))
+ // Use the first slice operand accepted by the control function instead of
+ // bailing out as soon as the first candidate is rejected.
+ auto sliceOperandIt =
+ llvm::find_if(*maybeSliceOperands,
+ [&](OpOperand *operand) { return controlFn(operand); });
+ if (sliceOperandIt == maybeSliceOperands->end())
return failure();
+ OpOperand *sliceOperand = *sliceOperandIt;
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
// CHECK: return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesn't pass the control function (which, in this
+// test, only allows extracts from the same block), an extract on a later
+// operand can still be pushed down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+ ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x?xbf16>
+ scf.yield %0 : tensor<?x?xbf16>
+ }
+ return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[FOR:.+]] = scf.for
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PAD]], %[[ARG0]]
+// CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]]
+// CHECK: scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,14 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
- linalg::populateExtractSliceSinkingPatterns(
- patterns, [](OpOperand *opOperand) { return true; });
+ // Only sink extract_slice ops defined in the same block as their consumer.
+ // Value::getParentBlock() is also valid for block arguments (no defining op).
+ linalg::ControlPropagationFn controlExtract =
+ [](OpOperand *opOperand) -> bool {
+ return opOperand->get().getParentBlock() ==
+ opOperand->getOwner()->getBlock();
+ };
+ linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
More information about the Mlir-commits
mailing list