[Mlir-commits] [mlir] [mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (PR #92234)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 15 02:52:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Extends the pack/unpack permutation-attribute checker to account for cases where the optional outer_dims_perm attribute is missing on one operation while the other has an explicit identity permutation. This enables the canonicalizer to fold more unpack(pack(x)) and pack(unpack(x)) variants.
---
Full diff: https://github.com/llvm/llvm-project/pull/92234.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-1)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+26)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..8ef447cf53a37 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
return false;
- return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+ if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
+ return true;
+ // outer_dims_perm is optional; an absent attribute means "no permutation".
+ // To match an unbalanced pack-unpack pair, treat a missing permutation as
+ // equal to an explicit identity (assumes isIdentityPermutation({}) is true).
+ return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
+ isIdentityPermutation(unPackOp.getOuterDimsPerm());
}
// Return true if pack and unpack have the same tiles.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..b5a82eb3e9035 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
// -----
+// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32> // unpack: no outer_dims_perm
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32> // pack: explicit identity outer_dims_perm
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32> // unpack: explicit identity outer_dims_perm
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32> // pack: no outer_dims_perm
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
// CHECK: func.func @invalid_empty_negative_size
// CHECK: %[[IDX:.*]] = index.constant
// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/92234
More information about the Mlir-commits
mailing list