[Mlir-commits] [mlir] b586149 - [mlir][tensor] Fold pack and unpack of empty input tensor (#92247)

llvmlistbot at llvm.org
Wed May 22 09:01:19 PDT 2024


Author: Adam Siemieniuk
Date: 2024-05-22T18:01:14+02:00
New Revision: b586149475d71440e4ef444f5df34630101bc669

URL: https://github.com/llvm/llvm-project/commit/b586149475d71440e4ef444f5df34630101bc669
DIFF: https://github.com/llvm/llvm-project/commit/b586149475d71440e4ef444f5df34630101bc669.diff

LOG: [mlir][tensor] Fold pack and unpack of empty input tensor (#92247)

Extends the `tensor.empty` folding patterns with `tensor.pack` and
`tensor.unpack` consumers so that these ops are folded away when their
source is an empty tensor.
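
For downstream users, nothing new needs to be wired up: the folds are
exposed through the existing `tensor::populateFoldTensorEmptyPatterns`
entry point. As a rough usage sketch (not part of this commit; the
helper function below is hypothetical), the patterns can be applied
with the greedy rewrite driver:

    #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Hypothetical helper: collects the tensor.empty folding patterns
    // (now including the pack/unpack consumer folds added here) and
    // applies them greedily to `op`.
    static LogicalResult foldTensorEmptyConsumers(Operation *op) {
      RewritePatternSet patterns(op->getContext());
      // Pass `true` instead to restrict folding to single-use
      // tensor.empty ops.
      tensor::populateFoldTensorEmptyPatterns(patterns,
                                              /*foldSingleUseOnly=*/false);
      return applyPatternsAndFoldGreedily(op, std::move(patterns));
    }

Any pass that already calls `populateFoldTensorEmptyPatterns` picks up
the new pack/unpack folds automatically.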

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/fold-empty-op.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043..dd6b0e8682564 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -59,8 +59,8 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 
-/// Populates `patterns` with patterns that fold tensor.empty with
-/// tensor.[extract_slice|expand_shape|collapse_shape].
+/// Populates `patterns` with patterns that fold tensor.empty with its
+/// consumers.
 ///
 /// If `singleUseOnly` is set to "true", only tensor.empty ops with a single
 /// use are folded.

diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index 7a707e749e69b..43ad0acaf7420 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -93,6 +93,49 @@ struct FoldEmptyTensorWithExtractSliceOp
   bool foldSingleUseOnly = false;
 };
 
+/// tensor.empty does not define any tensor contents, so an unpadded pack
+/// can be folded away.
+struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    // Check for tensor.empty source.
+    auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
+    if (!emptyOp)
+      return failure();
+
+    // Check for padding.
+    // Packing with padding cannot be simply removed.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    // Replace the pack directly with its destination.
+    rewriter.replaceOp(packOp, packOp.getDest());
+
+    return success();
+  }
+};
+
+/// tensor.empty does not define any tensor contents, so an unpack
+/// can be folded away.
+struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    // Check for tensor.empty source.
+    auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
+    if (!emptyOp)
+      return failure();
+
+    // Replace the unpack directly with its destination.
+    rewriter.replaceOp(unPackOp, unPackOp.getDest());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -101,4 +144,6 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
       patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
+  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
+      patterns.getContext(), /*benefit=*/1);
 }

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 914e5e8b8c4b8..f7fbd3834288b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2523,4 +2523,3 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
     %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
     return %16 : vector<7xi32>
 }
-

diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index e200a4f892613..e94f6ec7ec56e 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -64,6 +64,79 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
   return %r: tensor<2xf32>
 }
 
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+  %packed = tensor.pack %empty_unpacked
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+  return %packed : tensor<8x8x32x32xf32>
+}
+
+// CHECK-LABEL: func.func @pack_empty(
+// CHECK-SAME:   %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK-NOT:    tensor.pack
+// CHECK:        return %[[T]] : tensor<8x8x32x32xf32>
+
+func.func @pack_empty_dynamic(%arg0: tensor<?x?x?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?x?x?xf32> {
+  %empty_unpacked = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %packed = tensor.pack %empty_unpacked
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pack_empty_dynamic(
+// CHECK-SAME:   %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:   %[[DIM0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:   %[[DIM1:[a-zA-Z0-9_]+]]: index
+// CHECK-NOT:    tensor.pack
+// CHECK:        return %[[T]] : tensor<?x?x?x?xf32>
+
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+  %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+  %unpacked = tensor.unpack %empty_packed
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+  return %unpacked : tensor<256x256xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_empty(
+// CHECK-SAME:   %[[T:.+]]: tensor<256x256xf32>
+// CHECK-NOT:    tensor.unpack
+// CHECK:        return %[[T]] : tensor<256x256xf32>
+
+func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3: index) -> tensor<?x?xf32> {
+  %empty_packed = tensor.empty(%dim0, %dim1, %dim2, %dim3) : tensor<?x?x?x?xf32>
+  %unpacked = tensor.unpack %empty_packed
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %unpacked : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_empty_dynamic(
+// CHECK-SAME:   %[[T:.+]]: tensor<?x?xf32>,
+// CHECK-SAME:   %[[DIM0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:   %[[DIM1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:   %[[DIM2:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:   %[[DIM3:[a-zA-Z0-9_]+]]: index
+// CHECK-NOT:    tensor.unpack
+// CHECK:        return %[[T]] : tensor<?x?xf32>
+
+func.func @pack_padded_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+  %pad = arith.constant 1.0 : f32
+  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+  %packed = tensor.pack %empty_unpacked
+    padding_value(%pad : f32)
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+  return %packed : tensor<8x8x32x32xf32>
+}
+
+// CHECK-LABEL: func.func @pack_padded_empty(
+// CHECK-SAME:   %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK:        %[[PACK:.+]] = tensor.pack
+// CHECK:        return %[[PACK]] : tensor<8x8x32x32xf32>
+
 // -----
 
 module attributes {transform.with_named_sequence} {
