[Mlir-commits] [mlir] [mlir] Fix DataLayoutPropagation foldings invalidating IR (PR #140103)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 15 10:03:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (Max191)

<details>
<summary>Changes</summary>

Fixes a bug in DataLayoutPropagation that was replacing generic op destinations with tensor.empty() ops, even when the destination operand was being used.

Addresses post-merge comment: https://github.com/llvm/llvm-project/pull/138332/files/a9c1dccc3f73793bdd9e1f51ab3a6e15403a8338#r2091193712

---
Full diff: https://github.com/llvm/llvm-project/pull/140103.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+16-7) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+21-10) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 26904f1f40d12..dd8ef9608a821 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -312,10 +312,17 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   SmallVector<Value> inputOperands;
   SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
+    return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
+           packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
+           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
+  };
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
-    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
+    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
+    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
       inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
     } else {
       inputOperandsFromUnpackedSource.push_back(packedOperand);
@@ -324,14 +331,16 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     indexingMaps.push_back(packedIndexingMap);
   }
 
-  // If the pack and unpack op can be folded:
-  // 1) use unpack op source op for operand to fold unpack -> pack sequence.
-  // 2) init tensor of the generic op can be replaced by the destination of the
-  // pack op.
+  // If the unpack->pack sequences can be folded, use the sources of the
+  // unpack ops in any unpack->pack chains on the generic op operands.
   if (isFoldableUnpackPack) {
     inputOperands = inputOperandsFromUnpackedSource;
-    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
-      dest = destPack.getDest();
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
+      auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
+      if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
+        dest = destUnPack.getSource();
+      }
+    }
   }
 
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 63f068d3f8681..31c9e9ed3c501 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -455,10 +455,9 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
 // CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]]]
-// CHECK-SAME:      outs(%[[EMPTY]]
+// CHECK-SAME:      outs(%[[ARG0]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[UNPACKED_ARG0]]
@@ -482,11 +481,14 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-LABEL: func.func @unpack_on_input
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0]]
-// CHECK-SAME:      outs(%[[EMPTY]]
+// CHECK-SAME:      outs(%[[ARG1_PACK]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -510,11 +512,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0]]
-// CHECK-SAME:      outs(%[[EMPTY]]
+// CHECK-SAME:      outs(%[[ARG1_PACK]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -1397,10 +1402,13 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty
+// CHECK:         %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
+// CHECK:         %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME:      into %[[ARG2_PACK_EMPTY]]
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME:    outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG2]]
 // CHECK:         return %[[UNPACK]] : tensor<?x64xbf16>
@@ -1419,10 +1427,13 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty
+// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty
+// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
+// CHECK-SAME:    outs(%[[ARG1_PACK]] : tensor<?x8x4x8xf32>)
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/140103


More information about the Mlir-commits mailing list