[Mlir-commits] [mlir] [mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (PR #92234)

llvmlistbot at llvm.org
Wed May 15 02:52:52 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-tensor

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Extends the pack/unpack permutation attribute checker to handle the case where the optional `outer_dims_perm` attribute is missing on one operation while the other carries an explicit identity permutation. This lets the canonicalizer fold more unpack(pack(x)) variants.
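
As a rough illustration of the comparison rule described above, here is a minimal standalone sketch that does not depend on MLIR. The names `Perm`, `isIdentityPerm`, and `outerPermsMatch` are hypothetical and chosen for this example only. It assumes that a missing `outer_dims_perm` is observed as an empty array, which the identity check accepts; the actual change lives in `hasSameInnerOuterAttribute` in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the optional outer_dims_perm attribute.
// Assumption: an empty vector models "attribute not present".
using Perm = std::vector<int64_t>;

// Returns true when perm[i] == i for every index i.
// An empty permutation trivially satisfies this, so it counts as identity.
bool isIdentityPerm(const Perm &perm) {
  for (std::size_t i = 0; i < perm.size(); ++i)
    if (perm[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Sketch of the relaxed comparison: exact equality first, then fall back to
// treating "no permutation" and "explicit identity" as equivalent.
bool outerPermsMatch(const Perm &packPerm, const Perm &unpackPerm) {
  if (packPerm == unpackPerm)
    return true;
  return isIdentityPerm(packPerm) && isIdentityPerm(unpackPerm);
}

// Illustrative checks covering the balanced and unbalanced cases.
void checkExamples() {
  assert(outerPermsMatch({}, {}));          // both absent
  assert(outerPermsMatch({0, 1}, {}));      // explicit identity vs. absent
  assert(outerPermsMatch({}, {0, 1}));      // absent vs. explicit identity
  assert(outerPermsMatch({1, 0}, {1, 0}));  // same non-identity permutation
  assert(!outerPermsMatch({1, 0}, {}));     // real permutation vs. absent
  assert(!outerPermsMatch({1, 0}, {0, 1})); // real permutation vs. identity
}
```

The sketch compares permutations only. In the real ops both permutations refer to the same outer rank, so the identity fallback never has to reconcile arrays of different non-zero lengths.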

---
Full diff: https://github.com/llvm/llvm-project/pull/92234.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-1) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..8ef447cf53a37 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
   if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
     return false;
-  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
+    return true;
+  // Outer dims permutation is optional.
+  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
+  // identity permutation.
+  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
+         isIdentityPermutation(unPackOp.getOuterDimsPerm());
 }
 
 // Return true if pack and unpack have the same tiles.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..b5a82eb3e9035 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
 
 // -----
 
+// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
 // CHECK: func.func @invalid_empty_negative_size
 // CHECK: %[[IDX:.*]] = index.constant
 // CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/92234

