[Mlir-commits] [mlir] [mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (PR #92234)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed May 15 02:52:16 PDT 2024
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/92234
Extends the pack/unpack permutation-attribute checker to account for cases where the optional outer_dims_perm attribute is missing on one operation while the other one has an explicit identity permutation. This enables the canonicalizer to fold more pack/unpack round-trip variants, e.g. a tensor.unpack followed by a matching tensor.pack, which now folds back to the original source.
>From 7ca0ed5a181020f31a93409c8187eb778d86f589 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 15 May 2024 11:37:59 +0200
Subject: [PATCH] [mlir][tensor] Fold pack-unpack with unbalanced
outer_dims_perm attr
Extend the pack/unpack permutation-attribute checker to account for cases
where the optional outer_dims_perm attribute is missing on one operation
while the other one has an explicit identity permutation.
This enables the canonicalizer to fold more pack/unpack round-trip variants.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 8 ++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 26 ++++++++++++++++++++++
2 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..8ef447cf53a37 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
return false;
- return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+ if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
+ return true;
+ // outer_dims_perm is optional; an absent (empty) one is treated as identity.
+ // Hence an unbalanced pair, where one op omits it and the other spells out
+ // the identity (e.g. [0, 1]), still matches: both sides must be identity.
+ return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
+ isIdentityPermutation(unPackOp.getOuterDimsPerm());
}
// Return true if pack and unpack have the same tiles.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..b5a82eb3e9035 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
// -----
+// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {  // Expected to fold to %t.
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>  // unpack: no outer_dims_perm
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>  // pack: explicit identity outer_dims_perm
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {  // Expected to fold to %t.
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>  // unpack: explicit identity outer_dims_perm
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>  // pack: no outer_dims_perm
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
// CHECK: func.func @invalid_empty_negative_size
// CHECK: %[[IDX:.*]] = index.constant
// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
More information about the Mlir-commits
mailing list