[Mlir-commits] [mlir] [mlir][tensor] Fold pack and unpack of empty input tensor (PR #92247)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed May 15 04:30:32 PDT 2024
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/92247
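Adds canonicalization patterns to `tensor.pack` and `tensor.unpack` that fold the operation away when its source is produced by a `tensor.empty`, replacing the operation's result with its destination operand.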
From cc2e28ba7b01882c21ef1529a651a94d84b52ca3 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 15 May 2024 13:19:57 +0200
Subject: [PATCH] [mlir][tensor] Fold pack and unpack of empty input tensor
Adds canonicalization to pack and unpack to fold away operations
when their source is a `tensor.empty`.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 13 +++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 21 +++++++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..428bf61e2fe5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,6 +4200,12 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
+ // Fold away packing an empty source tensor.
+ if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+ rewriter.replaceOp(packOp, packOp.getDest());
+ return success();
+ }
+
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4435,6 +4441,13 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return success();
}
+ // Fold away unpacking an empty source tensor.
+ if (auto emptyTensor =
+ unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+ rewriter.replaceOp(unPackOp, unPackOp.getDest());
+ return success();
+ }
+
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(unPackOp, srcShape, destShape)) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..4922251363950 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2486,3 +2486,24 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
return %16 : vector<7xi32>
}
+// -----
+
+// CHECK: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+ %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+ %packed = tensor.pack %empty_unpacked inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+ return %packed : tensor<8x8x32x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK: return %[[T]] : tensor<256x256xf32>
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+ %unpacked = tensor.unpack %empty_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+ return %unpacked : tensor<256x256xf32>
+}
More information about the Mlir-commits
mailing list