[Mlir-commits] [mlir] dcd32bd - [mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (#92234)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 16 01:05:17 PDT 2024
Author: Adam Siemieniuk
Date: 2024-05-16T10:05:12+02:00
New Revision: dcd32bd65f16e80db2485e6e02b62d6a91c3cddf
URL: https://github.com/llvm/llvm-project/commit/dcd32bd65f16e80db2485e6e02b62d6a91c3cddf
DIFF: https://github.com/llvm/llvm-project/commit/dcd32bd65f16e80db2485e6e02b62d6a91c3cddf.diff
LOG: [mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (#92234)
Extends the pack/unpack permutation attribute checker to account for cases
where the optional outer_dims_perm attribute is missing on one operation
while the other carries an explicit identity permutation. This enables the
canonicalizer to fold more unpack(pack(x)) and pack(unpack(x)) variants.
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d3b1754cbe1cd..8a6df82abb312 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
return false;
- return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+ if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
+ return true;
+ // Outer dims permutation is optional.
+ // To compare unbalanced pack-unpack pair, treat no permutation as equal to
+ // identity permutation.
+ return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
+ isIdentityPermutation(unPackOp.getOuterDimsPerm());
}
// Return true if pack and unpack have the same tiles.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..b5a82eb3e9035 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
// -----
+// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
// CHECK: func.func @invalid_empty_negative_size
// CHECK: %[[IDX:.*]] = index.constant
// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
More information about the Mlir-commits
mailing list