[Mlir-commits] [mlir] f53b624 - [MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (#160246)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 8 05:39:27 PDT 2025


Author: maxbartel
Date: 2025-10-08T13:39:23+01:00
New Revision: f53b6249c24005d1a6208cd9e355595eb6519dc0

URL: https://github.com/llvm/llvm-project/commit/f53b6249c24005d1a6208cd9e355595eb6519dc0
DIFF: https://github.com/llvm/llvm-project/commit/f53b6249c24005d1a6208cd9e355595eb6519dc0.diff

LOG: [MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (#160246)

The original code seemed to assume that the tiled dimensions of the
tensor.empty op, before the transposition was applied, were always the last
dimensions. However, pack allows you to choose any dimension to tile.

The easiest way I found to solve this is to prefill the SmallVector with
(srcRank - numberOfTiles) 1s and then append the tile sizes.

This way I could also get rid of the first loop in the code.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/decompose-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 3390f380c7eb8..6504ca8664d49 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -57,7 +57,11 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// tile factors.
     DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
 
-    /// Return the tile sizes as OpFoldResult.
+    // TODO: Return the folded result.
+    /// Return the tile sizes as OpFoldResult. Will return the Value
+    /// of the constant Op, not the constant Attribute.
+    /// E.g., for %size = arith.constant 1 : i32 will return %size,
+    /// not 1.
     SmallVector<OpFoldResult> getMixedTiles();
 
     /// Return the tile sizes as `int64_t`. If a tile size is dynamic

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7863c2120fe18..0dac688e1c26d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
 
-  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
-  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
-      packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
   int64_t destRank = packOp.getDestRank();
-  int64_t numTiles = destRank - srcRank;
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  int64_t numberOfTiles = innerDimsPos.size();
 
-  // 1. Extract the inner tile sizes.
-  // Where possible, values are replaced with constant attributes (to match the
-  // behaviour of `getPackOpSourceOrPaddedSource`).
-  SmallVector<OpFoldResult> tileSizes;
-  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
-    if (dimAndTileMapping.count(i)) {
-      // Rather than taking the tile size as is, extact the actual constant
-      // value Attribute where possible, e.g.:
-      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
-      auto [_, tileSize] =
-          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
-      tileSizes.push_back(tileSize);
-    }
-  }
+  // 1. Get the input that is going to be packed. If the input requires padding,
+  // add a padding operation and return that as the input.
+  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
 
   // 2. Transpose the input to match the inner tile order:
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
   // Assumptions made:
-  //  1. All outer dims are 1 - the corresponding transposition order doesn't
+  //  - All outer dims are 1 - the corresponding transposition order doesn't
   //     matter, but requires all dim indices to be present.
+
+  // 2.1 Get the permutation for linalg.transpose
   SmallVector<int64_t> srcPermForTranspose;
-  ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
   for (int64_t i = 0; i < srcRank; i++) {
     // We assume the `k` dimensions of the inner dim position, where `k` is the
     // rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     // rank of the source tensor. For example if we have a source tensor with
     // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
     // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
-    if (llvm::is_contained(innerDimPos, i))
+    if (llvm::is_contained(innerDimsPos, i))
       continue;
     srcPermForTranspose.push_back(i);
   }
-  srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
+  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
+
+  // 2.2 Create the init tensor for linalg.transpose with the correct shape
+  SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
+                                            oneIdxAttr);
+  shapeForEmptyOp.append(packOp.getMixedTiles());
+
+  // getMixedTiles() may contain Values pointing to constant ops, not the
+  // constant attributes. Replace them with a true OpFoldResult.
+  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
+                  [&](OpFoldResult ofr) {
+                    if (auto val = llvm::dyn_cast<Value>(ofr))
+                      return getAsOpFoldResult(val);
+                    return ofr;
+                  });
 
   LDBG() << "Pack permutation: " << packOp;
   LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
+  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
 
-  // 2.1 Create tensor.empty (init value for TransposeOp)
-  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
-                                                 oneIdxAttr);
-  transShapeForEmptyOp.append(tileSizes);
-
-  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
-                                         srcPermForTranspose);
-  Value empty =
-      tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
-                              packOp.getSourceType().getElementType());
+  Value empty = tensor::EmptyOp::create(
+      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
 
-  // 2.2 Create linalg.transpose
+  // 2.3 Create linalg.transpose
   auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                   srcPermForTranspose);
 
@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
   // Outer dims are all 1s!
-  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
-                                       oneIdxAttr);
+  SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
   SmallVector<int64_t> writeShape;
 
   for (auto tileSize : packOp.getMixedTiles()) {

diff  --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 17e6c29754f9d..18a09f4c669bb 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -274,3 +274,24 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+// The following example shows a pack operation where the inner dims
+// positions are non-adjacent and non-permuted.
+func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+  return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_non_adjacent_and_non_permuted_inner_dims
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<8x1x1x1xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0, 3]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
+// CHECK:         return %[[INSERT]]


        


More information about the Mlir-commits mailing list