[Mlir-commits] [mlir] [mlir][linalg] Use ub.poison in data layout propagation if a packed operand requires padding. (PR #159467)

Nirvedh Meshram llvmlistbot at llvm.org
Fri Sep 19 13:58:35 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/159467

From 511fa2362c0f4eb4597fa7a190dd337ac69c974c Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 17 Sep 2025 15:18:21 -0700
Subject: [PATCH 1/2] [mlir][linalg] Use ub.poison in data layout propagation
 if a packed operand requires padding.

In the past, it was hard to set a padding value in the propagation
because we did not have ub.poison, and padding with zeros is not always
correct. Now we can use `ub.poison` in this case. This revision adds
support for setting the padding value to `ub.poison` whenever the
propagation requires padding; without a padding value, the propagation
would otherwise create an invalid pack op.
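
For example (a sketch with illustrative value names, mirroring the new
test below): packing a tensor<8x64xf32> operand into 16x16 tiles needs
padding along dim 0, so the propagated pack is now created as

  %poison = ub.poison : f32
  %pack = linalg.pack %src padding_value(%poison : f32)
      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %dest : tensor<8x64xf32> -> tensor<1x4x16x16xf32>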

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Transforms/DataLayoutPropagation.cpp      |  9 +++--
 .../Linalg/data-layout-propagation.mlir       | 35 ++++++++++++++++---
 2 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6c17c3c2d0cab..2d075d92017f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -289,9 +290,11 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
 
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  auto packedOperand = linalg::PackOp::create(
-      b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
-      /*padding=*/std::nullopt, outerDimsPerm);
+  auto poison = ub::PoisonOp::create(
+      b, loc, getElementTypeOrSelf(opOperand->get().getType()));
+  auto packedOperand =
+      linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
+                             innerTileSizes, poison, outerDimsPerm);
   return std::make_tuple(packedOperand, indexingMap);
 }
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index a5f8d63a3e912..7a16bc0a4faee 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1450,6 +1450,33 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @push_unpack_in_padded_domain_multiple_inputs(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<8x64xf32>, %arg2: tensor<8x64xf32>) -> tensor<8x64xf32> {
+  %0 = tensor.empty() : tensor<8x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1x4x16x16xf32> -> tensor<8x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg1, %unpack : tensor<8x64xf32>, tensor<8x64xf32>) outs(%arg2 : tensor<8x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<8x64xf32>
+  return %1 : tensor<8x64xf32>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_multiple_inputs
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PACK:.+]] = linalg.pack %[[ARG1]] padding_value(%[[POISON]] : f32)
+// CHECK-SAME:       inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK:           ins(%[[PACK]], %[[ARG0]]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ELEM]]
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK-SAME:      into %[[ARG2]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 module {
   func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
     %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
@@ -1473,7 +1500,7 @@ module {
 // CHECK:         } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]   
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]
 // CHECK-SAME:    outs(%[[EMPTY]]
 // CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
 // CHECK:         return %[[EXTRACT]]
@@ -1492,7 +1519,7 @@ func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32
 
 // CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK:         return %[[GENERIC]]          
+// CHECK:         return %[[GENERIC]]
 
 // -----
 
@@ -1508,7 +1535,7 @@ func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32
 
 // CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK:         return %[[GENERIC]]   
+// CHECK:         return %[[GENERIC]]
 
 // -----
 
@@ -1575,7 +1602,7 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
 
 // CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]         
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
 // CHECK:         return %[[EXTRACT]]
 
 // -----

From d9f526eba1e90eacbc2e2189393d1ea7c7ff0abf Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Thu, 18 Sep 2025 12:11:57 -0700
Subject: [PATCH 2/2] Add an option to control poison padding
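
The new `poisonPaddingOk` flag (default: false) on
populateDataLayoutPropagationPatterns gates whether the propagation
patterns may create pack ops that need a padding value. When the flag
is off, the new requirePaddingValueStrict helper conservatively treats
dynamic dimensions as requiring padding, and the generic-op patterns
bail out instead of materializing ub.poison. The test pass opts in
with /*poisonPaddingOk=*/true.

A sketch of what is gated (illustrative value names; dim 0 of the
source is dynamic, so the pack may need padding and is only created
when poison padding is allowed):

  %poison = ub.poison : f32
  %pack = linalg.pack %src padding_value(%poison : f32)
      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %dest : tensor<?x64xf32> -> tensor<?x4x16x16xf32>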

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |   8 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  29 +++++
 .../Transforms/DataLayoutPropagation.cpp      | 110 ++++++++++++------
 .../Linalg/TestDataLayoutPropagation.cpp      |   3 +-
 5 files changed, 119 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f36b41ccf6745..ff9eccacf6278 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
                                     ArrayRef<int64_t> outerDimsPerm,
                                     ArrayRef<OpFoldResult> innerTiles);
 
+    // Same as the above function, but dynamic dimensions are
+    // conservatively assumed to require padding.
+    static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
+                                          ArrayRef<int64_t> innerDimsPos,
+                                          ArrayRef<int64_t> outputShape,
+                                          ArrayRef<int64_t> outerDimsPerm,
+                                          ArrayRef<OpFoldResult> innerTiles);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 64d3a2448b409..41670249936e6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
 using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
 
 /// Patterns to bubble up or down data layout ops across other operations.
+/// The caller can optionally allow the patterns to propagate through packs
+/// that require padding by materializing a `ub.poison` padding value.
 void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
-    const ControlPropagationFn &controlPackUnPackPropagation);
+    const ControlPropagationFn &controlPackUnPackPropagation,
+    bool poisonPaddingOk = false);
 
 /// Patterns to sink extract slice across other operations.
 void populateExtractSliceSinkingPatterns(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 578931e1351c6..0932bfe45916a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5310,6 +5310,35 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
   return false;
 }
 
+bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
+                                       ArrayRef<int64_t> innerDimsPos,
+                                       ArrayRef<int64_t> outputShape,
+                                       ArrayRef<int64_t> outerDimsPerm,
+                                       ArrayRef<OpFoldResult> innerTiles) {
+  SmallVector<int64_t> outputTileSizes(
+      outputShape.take_front(inputShape.size()));
+  if (!outerDimsPerm.empty()) {
+    assert(outerDimsPerm.size() == outputTileSizes.size() &&
+           "expected output and outer_dims_perm to have same size");
+    applyPermutationToVector(outputTileSizes,
+                             invertPermutationVector(outerDimsPerm));
+  }
+  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
+    if (ShapedType::isDynamic(inputShape[pos]))
+      return true;
+    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
+
+    if (!constantTile) {
+      if (ShapedType::isStatic(outputTileSizes[pos]) &&
+          (inputShape[pos] % outputTileSizes[pos] != 0))
+        return true;
+    } else if (inputShape[pos] % (*constantTile) != 0) {
+      return true;
+    }
+  }
+  return false;
+}
+
 LogicalResult PackOp::verify() {
   if (failed(commonVerifierPackAndUnPackOp(*this)))
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2d075d92017f2..e0926d9a566a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -221,9 +221,10 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
 ///    inner_dims_pos = [0]
 ///    inner_tiles = [8]
 ///    into %init : tensor<?xf32> -> tensor<?x8xf32>
-static std::tuple<Value, AffineMap>
+static FailureOr<std::tuple<Value, AffineMap>>
 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
-                               GenericOp genericOp, OpOperand *opOperand) {
+                               GenericOp genericOp, OpOperand *opOperand,
+                               bool poisonPaddingOk) {
   int64_t numOrigLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -287,12 +288,24 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   // The operand does not have dimensions that relates to pack op.
   if (innerDimsPos.empty() && outerDimsPerm.empty())
     return std::make_tuple(opOperand->get(), indexingMap);
-
+  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
+  auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
+  if (!maybeIntInnerTileSizes.has_value()) {
+    return failure();
+  }
+  if (!poisonPaddingOk &&
+      linalg::PackOp::requirePaddingValueStrict(
+          inputType.getShape(), innerDimsPos,
+          linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
+                                          innerDimsPos, outerDimsPerm)
+              .getShape(),
+          outerDimsPerm, innerTileSizes))
+    return failure();
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto poison = ub::PoisonOp::create(
       b, loc, getElementTypeOrSelf(opOperand->get().getType()));
-  auto packedOperand =
+  Value packedOperand =
       linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                              innerTileSizes, poison, outerDimsPerm);
   return std::make_tuple(packedOperand, indexingMap);
@@ -304,10 +317,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
 /// around it. Implicitly this will only work when a packInfo can be obtained.
 /// This make sure that we are only using this function on parallel permuted
 /// dimensions.
-static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                               Value dest, AffineMap packedOutIndexingMap,
-                               const PackInfo &packInfo,
-                               bool isFoldableUnpackPack) {
+static FailureOr<GenericOp>
+packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
+              AffineMap packedOutIndexingMap, const PackInfo &packInfo,
+              bool isFoldableUnpackPack, bool poisonPaddingOk) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
   SmallVector<Value> inputOperandsFromUnpackedSource;
@@ -318,8 +331,13 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
            llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
   };
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand);
+    auto maybePackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
+    if (failed(maybePackedOperandAndIndexing)) {
+      return failure();
+    }
+    auto packedOperand = std::get<0>(*maybePackedOperandAndIndexing);
+    auto packedIndexingMap = std::get<1>(*maybePackedOperandAndIndexing);
     auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
     auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
     if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -410,7 +428,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
 ///     } -> tensor<?x?x8x2xf32>
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                               const ControlPropagationFn &controlFn) {
+                               const ControlPropagationFn &controlFn,
+                               bool poisonPaddingOk) {
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
     return failure();
@@ -473,9 +492,14 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   }
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto [packedOutOperand, packedOutIndexingMap] =
+  auto maybePackedOperandAndIndexing =
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, opOperand);
+                                     genericOp, opOperand, poisonPaddingOk);
+  if (failed(maybePackedOperandAndIndexing)) {
+    return failure();
+  }
+  auto packedOutOperand = std::get<0>(*maybePackedOperandAndIndexing);
+  auto packedOutIndexingMap = std::get<1>(*maybePackedOperandAndIndexing);
 
   // Forward the new tensor.empty as a destination if it is one of the following
   // situations:
@@ -491,7 +515,8 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   // pack(unpack) isn't naively foldable because the unpack op can be from
   // an arbitrary domain so we need to keep both.
   return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
-                       *packInfo, /*isFoldableUnpackPack=*/false);
+                       *packInfo, /*isFoldableUnpackPack=*/false,
+                       poisonPaddingOk);
 }
 
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -499,13 +524,15 @@ struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<linalg::PackOp> {
 public:
   BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
-                                        ControlPropagationFn fun)
-      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
+                                        ControlPropagationFn fun,
+                                        bool poisonPaddingOk)
+      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
+        poisonPaddingOk(poisonPaddingOk) {}
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto genericOp =
-        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
+    auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
+                                                    poisonPaddingOk);
     if (failed(genericOp))
       return failure();
     rewriter.replaceOp(packOp, genericOp->getResults());
@@ -514,6 +541,7 @@ struct BubbleUpPackOpThroughGenericOpPattern
 
 private:
   ControlPropagationFn controlFn;
+  bool poisonPaddingOk;
 };
 
 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
@@ -1083,7 +1111,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
 ///
 static FailureOr<std::tuple<GenericOp, Value>>
 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                                 ControlPropagationFn controlFn) {
+                                 ControlPropagationFn controlFn,
+                                 bool poisonPaddingOk) {
   if (genericOp.getNumResults() != 1)
     return failure();
 
@@ -1110,9 +1139,14 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     return failure();
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto [packedOutOperand, packedOutIndexingMap] =
-      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, genericOp.getDpsInitOperand(0));
+  auto maybePackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
+      rewriter, genericOp.getLoc(), *packInfo, genericOp,
+      genericOp.getDpsInitOperand(0), poisonPaddingOk);
+  if (failed(maybePackedOperandAndIndexing)) {
+    return failure();
+  }
+  auto packedOutOperand = std::get<0>(*maybePackedOperandAndIndexing);
+  auto packedOutIndexingMap = std::get<1>(*maybePackedOperandAndIndexing);
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
   // Forward the new tensor.empty as a destination if it is one of the following
@@ -1132,9 +1166,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   // pack(unpack) is foldable in this case. This is because in pushing down the
   // unpack, by default we will populate an additional pack op after the unpack.
   // This guarantees them to be foldable.
-  GenericOp newGenericOp =
+  auto maybeGenericOp =
       packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
-                    /*isFoldableUnpackPack=*/true);
+                    /*isFoldableUnpackPack=*/true, poisonPaddingOk);
+  if (failed(maybeGenericOp))
+    return failure();
+  GenericOp newGenericOp = *maybeGenericOp;
   Value newResult =
       newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
 
@@ -1160,13 +1197,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 public:
   PushDownUnPackOpThroughGenericOp(MLIRContext *context,
-                                   ControlPropagationFn fun)
-      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+                                   ControlPropagationFn fun,
+                                   bool poisonPaddingOk)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
+        poisonPaddingOk(poisonPaddingOk) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    auto genericAndRepl =
-        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
+    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
+        rewriter, genericOp, controlFn, poisonPaddingOk);
     if (failed(genericAndRepl))
       return failure();
     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1175,6 +1214,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 
 private:
   ControlPropagationFn controlFn;
+  bool poisonPaddingOk;
 };
 
 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
@@ -1525,12 +1565,14 @@ class PushDownExtractSliceOpThroughGenericOp final
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
-    const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns
-      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
-              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
-          patterns.getContext(), controlPackUnPackPropagation);
+    const ControlPropagationFn &controlPackUnPackPropagation,
+    bool poisonPaddingOk) {
+  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
+                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+                  PushDownUnPackOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation, poisonPaddingOk);
 }
 
 void mlir::linalg::populateExtractSliceSinkingPatterns(
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d332270468ea8..d45aaf788f9c2 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](OpOperand *opOperand) { return true; });
+        patterns, [](OpOperand *opOperand) { return true; },
+        /*poisonPaddingOk=*/true);
     linalg::ControlPropagationFn controlExtract =
         [](OpOperand *opOperand) -> bool {
       Operation *producer = opOperand->get().getDefiningOp();


