[Mlir-commits] [mlir] [Linalg] Fix bug in control function logic of push down extract pattern (PR #158348)

Nirvedh Meshram llvmlistbot at llvm.org
Fri Sep 12 12:08:15 PDT 2025


https://github.com/nirvedhmeshram created https://github.com/llvm/llvm-project/pull/158348

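The current logic bails out if the first extract_slice producer fails the control function, even when a later extract_slice operand would pass it. This PR fixes that by trying each extract_slice operand in turn.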

>From 430cbb491da435c2c505fc1387395f5c83117f79 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Fri, 12 Sep 2025 12:05:53 -0700
Subject: [PATCH] [Linalg] Fix bug in control function logic of push down
 extract pattern

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Transforms/DataLayoutPropagation.cpp      | 36 ++++++++++++-------
 .../Linalg/data-layout-propagation.mlir       | 30 ++++++++++++++++
 .../Linalg/TestDataLayoutPropagation.cpp      |  9 +++--
 3 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
   OpFoldResult outputSize;
 };
 
-/// Return the first input extract slice operand, if present, for the current
+/// Return all extract slice operands, if present, for the current
 /// generic op.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
-  OpOperand *sliceOperand = nullptr;
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+  SmallVector<OpOperand *> sliceOperands; // In DPS input operand order.
   for (auto operand : genericOp.getDpsInputOperands()) {
     auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
     if (!extractOp)
       continue;
-    sliceOperand = operand;
-    break;
+    sliceOperands.push_back(operand); // Collect all, not just the first.
   }
-  if (!sliceOperand) {
+  if (sliceOperands.empty()) {
     return failure();
   }
-  return sliceOperand;
+  return sliceOperands; // Guaranteed non-empty at this point.
 }
 
 // Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
         genericOp,
         "propagation through generic with gather semantics is unsupported.");
   // Collect the sliced operand, if present.
-  auto maybeSliceOperand = getSliceOperand(genericOp);
-  if (failed(maybeSliceOperand))
+  auto maybeSliceOperands = getSliceOperands(genericOp);
+  if (failed(maybeSliceOperands))
     return failure();
-  OpOperand *sliceOperand = *maybeSliceOperand;
-  unsigned OperandIndex = sliceOperand->getOperandNumber();
-
-  if (!controlFn(sliceOperand))
+  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
+  OpOperand *sliceOperand;
+
+  bool foundValidOperand = false;
+  for (auto currSliceOperand : sliceOperands) {
+    if (controlFn(currSliceOperand)) {
+      sliceOperand = currSliceOperand;
+      foundValidOperand = true;
+      break;
+    }
+  }
+  if (!foundValidOperand) {
     return failure();
+  }
+  unsigned OperandIndex = sliceOperand->getOperandNumber();
 
   tensor::ExtractSliceOp producerSliceOp =
       sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]         
 // CHECK:         return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesn't pass the control function (which in this
+// test only allows extracts from the same block as the consumer), an extract
+// from a later operand can still be pushed down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32> // Different block from the generic: rejected.
+  %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32> // Same block as the generic: accepted.
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>,  tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+    ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x?xbf16>
+    scf.yield %0 : tensor<?x?xbf16>
+  }
+ return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK:         %[[FOR:.+]] = scf.for
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        ins(%[[PAD]], %[[ARG0]]
+// CHECK:           %[[EXTRACT2:.+]] =  tensor.extract_slice %[[GENERIC]]
+// CHECK:           scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
         patterns, [](OpOperand *opOperand) { return true; });
-    linalg::populateExtractSliceSinkingPatterns(
-        patterns, [](OpOperand *opOperand) { return true; });
+    linalg::ControlPropagationFn controlExtract =
+        [](OpOperand *opOperand) -> bool {
+      Operation *producer = opOperand->get().getDefiningOp();
+      Operation *consumer = opOperand->getOwner();
+      return consumer->getBlock() == producer->getBlock();
+    };
+    linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }



More information about the Mlir-commits mailing list