[Mlir-commits] [mlir] 8d4f317 - [mlir][linalg] Fix UnPackOp::getTiledOuterDims (#152960)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 14 03:39:54 PDT 2025


Author: Andrzej Warzyński
Date: 2025-08-14T11:39:50+01:00
New Revision: 8d4f3171fa97a8c74c57fdbac2f609344b9a3f2d

URL: https://github.com/llvm/llvm-project/commit/8d4f3171fa97a8c74c57fdbac2f609344b9a3f2d
DIFF: https://github.com/llvm/llvm-project/commit/8d4f3171fa97a8c74c57fdbac2f609344b9a3f2d.diff

LOG: [mlir][linalg] Fix UnPackOp::getTiledOuterDims (#152960)

Fixes `getTiledOuterDims` by making sure that the `outer_dims_perm`
attribute from `linalg.unpack` is taken into account.

Fixes #152037

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/decompose-unpack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1e5b5d46de55f..8d5306dca43e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1690,7 +1690,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
 /// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
 ///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
 ///
-/// Requires that all the outer dims of the input linalg::PackOp are 1.
+/// Requires that all the tiled outer dims of the input linalg::UnPackOp are 1.
 ///
 /// Before:
 /// ```

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..db4edfe36784f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5767,11 +5767,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
 
 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
-  auto packedShape = getSourceType().getShape();
+  SmallVector<int64_t> outerDims(getAllOuterDims());
   SmallVector<int64_t> res;
 
+  // Recover the original order of the outer dims.
+  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+  invertPermutationVector(outerDimPermInv);
+  if (!outerDimPermInv.empty())
+    applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
   for (auto index : innerDimsPos)
-    res.push_back(packedShape[index]);
+    res.push_back(outerDims[index]);
 
   return res;
 }

diff  --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
index e173d557c770d..a53dde8d3dd05 100644
--- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
@@ -203,3 +203,20 @@ func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1
 // CHECK-SAME:                      outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
 // CHECK:        %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
 // CHECK:        return %[[INSERT]]
+
+// -----
+
+/// Note "126", which is a non-unit tile-outer-dim. This is not supported.
+
+func.func @negative_non_unit_tiled_outer_dim(%src: tensor<1x126x1x1x8xf32>, %dest: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> {
+  %unpack = linalg.unpack %src
+    outer_dims_perm = [0, 3, 2, 1]
+    inner_dims_pos = [3]
+    inner_tiles = [8]
+    into %dest : tensor<1x126x1x1x8xf32>
+    -> tensor<1x1x1x1001xf32>
+
+  return %unpack : tensor<1x1x1x1001xf32>
+}
+// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
+// CHECK: linalg.unpack


        


More information about the Mlir-commits mailing list