[Mlir-commits] [mlir] [mlir][tensor] Fold pack and unpack of empty input tensor (PR #92247)

Adam Siemieniuk llvmlistbot at llvm.org
Wed May 15 06:17:00 PDT 2024


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/92247

>From cc2e28ba7b01882c21ef1529a651a94d84b52ca3 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 15 May 2024 13:19:57 +0200
Subject: [PATCH 1/2] [mlir][tensor] Fold pack and unpack of empty input tensor

Adds canonicalization to `tensor.pack` and `tensor.unpack` that folds the
operation away, replacing it with its destination operand, when its source
is defined by a `tensor.empty`.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 13 +++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 21 +++++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..428bf61e2fe5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,6 +4200,12 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }
 
+  // Fold away packing an empty source tensor.
+  if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+    rewriter.replaceOp(packOp, packOp.getDest());
+    return success();
+  }
+
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4435,6 +4441,13 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
     return success();
   }
 
+  // Fold away unpacking an empty source tensor.
+  if (auto emptyTensor =
+          unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+    rewriter.replaceOp(unPackOp, unPackOp.getDest());
+    return success();
+  }
+
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(unPackOp, srcShape, destShape)) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..4922251363950 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2486,3 +2486,24 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
     return %16 : vector<7xi32>
 }
 
+// -----
+
+// CHECK: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+  %packed = tensor.pack %empty_unpacked inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+  return %packed : tensor<8x8x32x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK: return %[[T]] : tensor<256x256xf32>
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+  %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+  %unpacked = tensor.unpack %empty_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+  return %unpacked : tensor<256x256xf32>
+}

>From e6dc9fa546971307357dfe9388858d914e3bb68c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 15 May 2024 15:16:44 +0200
Subject: [PATCH 2/2] Move logic to folder

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 428bf61e2fe5a..7e723c55cb21e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,12 +4200,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }
 
-  // Fold away packing an empty source tensor.
-  if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
-    rewriter.replaceOp(packOp, packOp.getDest());
-    return success();
-  }
-
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4280,6 +4274,8 @@ OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getDestType(), paddingValue))
     return reshapedSource;
+  if (getSource().getDefiningOp<tensor::EmptyOp>())
+    return getDest();
   return {};
 }
 
@@ -4441,13 +4437,6 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
     return success();
   }
 
-  // Fold away unpacking an empty source tensor.
-  if (auto emptyTensor =
-          unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
-    rewriter.replaceOp(unPackOp, unPackOp.getDest());
-    return success();
-  }
-
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(unPackOp, srcShape, destShape)) {
@@ -4485,6 +4474,8 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getResult().getType()))
     return reshapedSource;
+  if (getSource().getDefiningOp<tensor::EmptyOp>())
+    return getDest();
   return {};
 }
 



More information about the Mlir-commits mailing list