[Mlir-commits] [mlir] [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern (PR #141613)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 27 07:40:54 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Christopher McGirr (chrsmcgrr)

<details>
<summary>Changes</summary>

Given the following example:
```
module {
  func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
    %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
    return %pack : tensor<1x1x1x4x1xf32>
  }
}
```

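Previously, we would generate an invalid transpose operation for this input because the calculated permutation would be `[0, 2, 0]`, which is semantically incorrect: the permutation must contain unique integers corresponding to the source tensor dimensions. The old code filled the leading `srcRank - numTiles` positions with `0, 1, ...` and then appended `inner_dims_pos`, so index `0` appeared twice.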

The following change modifies how we calculate the permutation array and ensures that the dimension indices in the permutation array are unique.

The above example then translates to a transpose with a permutation of `[1, 2, 0]`, following the rule that `inner_dims_pos` is appended to the permutation array and the preceding indices are filled with the remaining dimensions.
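
As a rough illustration of the rule described above, the following standalone sketch compares the old and new permutation calculations. It uses `std::vector` instead of `llvm::SmallVector`, and the function names `buildTransposePermOld`, `buildTransposePermNew`, and `isPermutation` are hypothetical helpers introduced only for this example; they do not exist in the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: old behaviour, which fills the leading positions
// with 0 .. (srcRank - numTiles - 1) and then appends innerDimsPos.
std::vector<int64_t>
buildTransposePermOld(int64_t srcRank,
                      const std::vector<int64_t> &innerDimsPos) {
  const int64_t numTiles = static_cast<int64_t>(innerDimsPos.size());
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < srcRank - numTiles; ++i)
    perm.push_back(i);
  perm.insert(perm.end(), innerDimsPos.begin(), innerDimsPos.end());
  return perm;
}

// Hypothetical helper: new behaviour, which first collects every source
// dimension that is not an inner dim position and then appends innerDimsPos.
std::vector<int64_t>
buildTransposePermNew(int64_t srcRank,
                      const std::vector<int64_t> &innerDimsPos) {
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < srcRank; ++i) {
    if (std::find(innerDimsPos.begin(), innerDimsPos.end(), i) !=
        innerDimsPos.end())
      continue; // Already covered by the trailing inner dim positions.
    perm.push_back(i);
  }
  perm.insert(perm.end(), innerDimsPos.begin(), innerDimsPos.end());
  return perm;
}

// Hypothetical helper: returns true if perm contains each index in
// [0, perm.size()) exactly once.
bool isPermutation(const std::vector<int64_t> &perm) {
  const int64_t rank = static_cast<int64_t>(perm.size());
  std::vector<bool> seen(perm.size(), false);
  for (int64_t d : perm) {
    if (d < 0 || d >= rank || seen[d])
      return false;
    seen[d] = true;
  }
  return true;
}
```

For the example above (`srcRank = 3`, `inner_dims_pos = [2, 0]`), the old calculation yields `[0, 2, 0]`, which is not a permutation, while the new one yields `[1, 2, 0]`.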

---
Full diff: https://github.com/llvm/llvm-project/pull/141613.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+15-8) 
- (modified) mlir/test/Dialect/Linalg/decompose-pack.mlir (+19) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8718c57b9e86c..7b6c8243d1040 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1205,16 +1205,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
-  // Two assumptions are made:
-  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
-  //  2. Inner dims position correspond to the trailing `numTiles` dims.
-  SmallVector<int64_t> tilesPermNormalized =
-      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+  // Assumptions made:
+  //  1. Inner dims position correspond to the trailing `numTiles` dims.
   SmallVector<int64_t> srcPermForTranspose;
-  for (int64_t i = 0; i < (srcRank - numTiles); i++)
+  ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+  for (int64_t i = 0; i < srcRank; i++) {
+    // As we assume the trailing dimensions of the inner dim position correspond
+    // to the trailing indices of the transpose permutation, we need to
+    // calculate the remaining indicies of the transpose permutation. This is
+    // done by adding the indices not contained in the inner dimension position.
+    //   For example if we have a source tensor of dimensions [0, 1, 2, 3]
+    //   and inner dim position of [3, 0], the remaining indices are [1, 2].
+    //   and the transpose will be [1, 2, 3, 0].
+    if (llvm::is_contained(innerDimPos, i))
+      continue;
     srcPermForTranspose.push_back(i);
-
-  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
+  }
+  srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
                     << "perm: " << llvm::interleaved(srcPermForTranspose)
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 911b453f919c3..6d091406a639c 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+  return %pack : tensor<1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK:         return %[[INSERT]]
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/141613

