[Mlir-commits] [mlir] [MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (PR #160246)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Sep 29 07:36:03 PDT 2025
    
    
  
https://github.com/maxbartel updated https://github.com/llvm/llvm-project/pull/160246
>From 9f031bccdd56b02726055d7744de6be649fbf3fc Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Tue, 23 Sep 2025 08:57:22 +0200
Subject: [PATCH 1/2] (linalg.pack): fix empty tensor assumptions
---
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 17 +++++++----------
 mlir/test/Dialect/Linalg/decompose-pack.mlir  | 19 +++++++++++++++++++
 2 files changed, 26 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b253eea35..69cbc7048f646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1151,11 +1151,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
       packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
   int64_t destRank = packOp.getDestRank();
-  int64_t numTiles = destRank - srcRank;
 
-  // 1. Extract the inner tile sizes.
-  // Where possible, values are replaced with constant attributes (to match the
-  // behaviour of `getPackOpSourceOrPaddedSource`).
+  // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
+  // before transposing. Where possible, values are replaced with constant
+  // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
+  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
@@ -1165,6 +1165,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
       auto [_, tileSize] =
           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
       tileSizes.push_back(tileSize);
+      transShapeForEmptyOp[i] = tileSize;
     }
   }
 
@@ -1194,18 +1195,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LDBG() << "Pack permutation: " << packOp;
   LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
 
-  // 2.1 Create tensor.empty (init value for TransposeOp)
-  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
-                                                 oneIdxAttr);
-  transShapeForEmptyOp.append(tileSizes);
-
+  // 2.2 Transpose the tensor.empty shapes.
   applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                          srcPermForTranspose);
   Value empty =
       tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
                               packOp.getSourceType().getElementType());
 
-  // 2.2 Create linalg.transpose
+  // 2.3 Create linalg.transpose
   auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                   srcPermForTranspose);
 
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 17e6c29754f9d..15521d415b8a7 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+  return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<8x1x1x1xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0, 3]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
+// CHECK:         return %[[INSERT]]
\ No newline at end of file
>From eae9f9946e30104dbb4e7a86b96ed3a735700929 Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Mon, 29 Sep 2025 16:35:46 +0200
Subject: [PATCH 2/2] (linalg.pack): simplify outer dims patterns after review
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  3 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 59 +++++++++----------
 2 files changed, 30 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f36b41ccf6745..5006d815a798a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -57,7 +57,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// tile factors.
     DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
 
-    /// Return the tile sizes as OpFoldResult.
+    /// Return the tile sizes as OpFoldResult. If a tile size is defined by a
+    /// constant op, its result Value is returned, not the constant Attribute.
     SmallVector<OpFoldResult> getMixedTiles();
 
     /// Return the tile sizes as `int64_t`. If a tile size is dynamic
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 69cbc7048f646..60219335d6a1c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1146,38 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
 
-  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
-  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
-      packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
   int64_t destRank = packOp.getDestRank();
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  int64_t numberOfTiles = innerDimsPos.size();
 
-  // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
-  // before transposing. Where possible, values are replaced with constant
-  // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
-  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> tileSizes;
-  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
-    if (dimAndTileMapping.count(i)) {
-      // Rather than taking the tile size as is, extact the actual constant
-      // value Attribute where possible, e.g.:
-      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
-      auto [_, tileSize] =
-          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
-      tileSizes.push_back(tileSize);
-      transShapeForEmptyOp[i] = tileSize;
-    }
-  }
+  // 1. Get the input that is going to be packed. If the input requires padding,
+  // add a padding operation and return that as the input.
+  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
 
   // 2. Transpose the input to match the inner tile order:
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
   // Assumptions made:
-  //  1. All outer dims are 1 - the corresponding transposition order doesn't
+  //  - All outer dims are 1 - the corresponding transposition order doesn't
   //     matter, but requires all dim indices to be present.
+
+  // 2.1 Get the permutation for linalg.transpose
   SmallVector<int64_t> srcPermForTranspose;
-  ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
   for (int64_t i = 0; i < srcRank; i++) {
     // We assume the `k` dimensions of the inner dim position, where `k` is the
     // rank of the inner tiling, correspond to the last `k` indices of the
@@ -1186,21 +1173,32 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     // rank of the source tensor. For example if we have a source tensor with
     // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
     // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
-    if (llvm::is_contained(innerDimPos, i))
+    if (llvm::is_contained(innerDimsPos, i))
       continue;
     srcPermForTranspose.push_back(i);
   }
-  srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
+  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
+
+  // 2.2 Create the init tensor for linalg.transpose with the correct shape
+  SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
+                                            oneIdxAttr);
+  shapeForEmptyOp.append(packOp.getMixedTiles());
+
+  // getMixedTiles() may return Values defined by constant ops rather than
+  // constant Attributes. Fold such Values to their Attributes where possible.
+  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
+                  [&](OpFoldResult ofr) {
+                    if (auto val = llvm::dyn_cast<Value>(ofr))
+                      return getAsOpFoldResult(val);
+                    return ofr;
+                  });
 
   LDBG() << "Pack permutation: " << packOp;
   LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
+  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
 
-  // 2.2 Transpose the tensor.empty shapes.
-  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
-                                         srcPermForTranspose);
-  Value empty =
-      tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
-                              packOp.getSourceType().getElementType());
+  Value empty = tensor::EmptyOp::create(
+      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
 
   // 2.3 Create linalg.transpose
   auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
@@ -1211,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
   // Outer dims are all 1s!
-  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
-                                       oneIdxAttr);
+  SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
   SmallVector<int64_t> writeShape;
 
   for (auto tileSize : packOp.getMixedTiles()) {
    
    
More information about the Mlir-commits
mailing list