[Mlir-commits] [mlir] [mlir][linalg] Fix UnPackOp::getTiledOuterDims (PR #152960)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 11 00:04:08 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

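Fixes `UnPackOp::getTiledOuterDims` so that the `outer_dims_perm`
attribute of `linalg.unpack` is taken into account. Previously, the tiled
outer dims were read directly from the (permuted) source shape, which can
select the wrong dims when `outer_dims_perm` is not the identity.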

Fixes #152037


---
Full diff: https://github.com/llvm/llvm-project/pull/152960.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+37-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1) 
- (modified) mlir/test/Dialect/Linalg/decompose-unpack.mlir (+15) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..e951c63d24b18 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5765,13 +5765,22 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
   return getSourceType().getShape().take_front(destRank);
 }
 
 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
-  auto packedShape = getSourceType().getShape();
+  SmallVector<int64_t> outerDims(getAllOuterDims());
   SmallVector<int64_t> res;
 
+  // The source outer dims are permuted by `outer_dims_perm`; apply the inverse
+  // permutation to restore the dest (unpacked) order so that they can be
+  // indexed with `inner_dims_pos`. NB: `invertPermutationVector` returns the
+  // inverse rather than updating its argument in place.
+  ArrayRef<int64_t> outerDimsPerm = getOuterDimsPerm();
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(outerDims, invertPermutationVector(outerDimsPerm));
+
+  // Collect the outer dims corresponding to the tiled inner dims.
   for (auto index : innerDimsPos)
-    res.push_back(packedShape[index]);
+    res.push_back(outerDims[index]);
 
   return res;
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d4b1127a08be..c8c3a1a9eefe5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
                                sourceTensorType.getEncoding());
 }
 
+// NOTE: `offsets` and `strides` don't affect the result type, only `sizes`.
 RankedTensorType ExtractSliceOp::inferResultType(
     RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
index e173d557c770d..61998abd68368 100644
--- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
@@ -203,3 +203,40 @@ func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1
 // CHECK-SAME:                      outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
 // CHECK:        %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
 // CHECK:        return %[[INSERT]]
+
+// -----
+
+/// The tiled outer dim ("126") is a non-unit dim. This is not supported by the
+/// decomposition, so `linalg.unpack` is expected to be left intact.
+
+func.func @negative_non_unit_tiled_outer_dim(%src: tensor<1x126x1x1x8xf32>, %dest: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> {
+  %unpack = linalg.unpack %src
+    outer_dims_perm = [0, 3, 2, 1]
+    inner_dims_pos = [3]
+    inner_tiles = [8]
+    into %dest : tensor<1x126x1x1x8xf32>
+    -> tensor<1x1x1x1001xf32>
+
+  return %unpack : tensor<1x1x1x1001xf32>
+}
+// CHECK-LABEL: func.func @negative_non_unit_tiled_outer_dim(
+// CHECK:         linalg.unpack
+
+// -----
+
+/// Same as above, but `outer_dims_perm` is not its own inverse (unlike
+/// [0, 3, 2, 1] above), so applying `outer_dims_perm` instead of its inverse
+/// would pick the wrong (unit) outer dim as the tiled one.
+
+func.func @negative_non_unit_tiled_outer_dim_non_self_inverse_perm(%src: tensor<1x1x126x1x8xf32>, %dest: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> {
+  %unpack = linalg.unpack %src
+    outer_dims_perm = [1, 2, 3, 0]
+    inner_dims_pos = [3]
+    inner_tiles = [8]
+    into %dest : tensor<1x1x126x1x8xf32>
+    -> tensor<1x1x1x1001xf32>
+
+  return %unpack : tensor<1x1x1x1001xf32>
+}
+// CHECK-LABEL: func.func @negative_non_unit_tiled_outer_dim_non_self_inverse_perm(
+// CHECK:         linalg.unpack

``````````

</details>


https://github.com/llvm/llvm-project/pull/152960


More information about the Mlir-commits mailing list