[Mlir-commits] [mlir] [mlir][tensor] Fold pack and unpack of empty input tensor (PR #92247)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 15 04:31:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Adds canonicalization patterns to `tensor.pack` and `tensor.unpack` that fold the operation away, replacing it with its destination operand, when its source is defined by a `tensor.empty`.

---
Full diff: https://github.com/llvm/llvm-project/pull/92247.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+13) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+21) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..428bf61e2fe5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,6 +4200,12 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }
 
+  // Fold away packing an empty source tensor.
+  if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+    rewriter.replaceOp(packOp, packOp.getDest());
+    return success();
+  }
+
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4435,6 +4441,13 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
     return success();
   }
 
+  // Fold away unpacking an empty source tensor.
+  if (auto emptyTensor =
+          unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+    rewriter.replaceOp(unPackOp, unPackOp.getDest());
+    return success();
+  }
+
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(unPackOp, srcShape, destShape)) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..4922251363950 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2486,3 +2486,24 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
     return %16 : vector<7xi32>
 }
 
+// -----
+
+// CHECK: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+  %packed = tensor.pack %empty_unpacked inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+  return %packed : tensor<8x8x32x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK: return %[[T]] : tensor<256x256xf32>
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+  %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+  %unpacked = tensor.unpack %empty_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+  return %unpacked : tensor<256x256xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/92247


More information about the Mlir-commits mailing list