[Mlir-commits] [mlir] [mlir][tensor] Fold pack and unpack of empty input tensor (PR #92247)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed May 22 01:41:08 PDT 2024
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/92247
>From cc2e28ba7b01882c21ef1529a651a94d84b52ca3 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 15 May 2024 13:19:57 +0200
Subject: [PATCH 1/6] [mlir][tensor] Fold pack and unpack of empty input tensor
Adds canonicalization patterns to `tensor.pack` and `tensor.unpack` that fold
the operation away when its source is a `tensor.empty`.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 13 +++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 21 +++++++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..428bf61e2fe5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,6 +4200,12 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
+ // Fold away packing an empty source tensor.
+ if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+ rewriter.replaceOp(packOp, packOp.getDest());
+ return success();
+ }
+
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4435,6 +4441,13 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return success();
}
+ // Fold away unpacking an empty source tensor.
+ if (auto emptyTensor =
+ unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+ rewriter.replaceOp(unPackOp, unPackOp.getDest());
+ return success();
+ }
+
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(unPackOp, srcShape, destShape)) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..4922251363950 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2486,3 +2486,24 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
return %16 : vector<7xi32>
}
+// -----
+
+// CHECK: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+ %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+ %packed = tensor.pack %empty_unpacked inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+ return %packed : tensor<8x8x32x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK: return %[[T]] : tensor<256x256xf32>
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+ %unpacked = tensor.unpack %empty_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+ return %unpacked : tensor<256x256xf32>
+}
>From e6dc9fa546971307357dfe9388858d914e3bb68c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 15 May 2024 15:16:44 +0200
Subject: [PATCH 2/6] Move logic to folder
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 17 ++++-------------
1 file changed, 4 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 428bf61e2fe5a..7e723c55cb21e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,12 +4200,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
- // Fold away packing an empty source tensor.
- if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
- rewriter.replaceOp(packOp, packOp.getDest());
- return success();
- }
-
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4280,6 +4274,8 @@ OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getDestType(), paddingValue))
return reshapedSource;
+ if (getSource().getDefiningOp<tensor::EmptyOp>())
+ return getDest();
return {};
}
@@ -4441,13 +4437,6 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return success();
}
- // Fold away unpacking an empty source tensor.
- if (auto emptyTensor =
- unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
- rewriter.replaceOp(unPackOp, unPackOp.getDest());
- return success();
- }
-
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(unPackOp, srcShape, destShape)) {
@@ -4485,6 +4474,8 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getResult().getType()))
return reshapedSource;
+ if (getSource().getDefiningOp<tensor::EmptyOp>())
+ return getDest();
return {};
}
>From e7e5fe769f1aa82bf2d6d10a45b6e88adeb2add1 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 16 May 2024 10:53:59 +0200
Subject: [PATCH 3/6] Move to populateFoldTensorEmptyPatterns
---
.../Dialect/Tensor/Transforms/Transforms.h | 4 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 --
.../Tensor/Transforms/EmptyOpPatterns.cpp | 66 ++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 22 -------
mlir/test/Dialect/Tensor/fold-empty-op.mlir | 41 ++++++++++++
5 files changed, 108 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043..dd6b0e8682564 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -59,8 +59,8 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
/// `tensor.collapse_shape` into other ops.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
-/// Populates `patterns` with patterns that fold tensor.empty with
-/// tensor.[extract_slice|expand_shape|collapse_shape].
+/// Populates `patterns` with patterns that fold tensor.empty with its
+/// consumers.
///
/// If `singleUseOnly` is set to "true", only tensor.empty ops with a single
/// use are folded.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7e723c55cb21e..414bd7459af8f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4274,8 +4274,6 @@ OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getDestType(), paddingValue))
return reshapedSource;
- if (getSource().getDefiningOp<tensor::EmptyOp>())
- return getDest();
return {};
}
@@ -4474,8 +4472,6 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getResult().getType()))
return reshapedSource;
- if (getSource().getDefiningOp<tensor::EmptyOp>())
- return getDest();
return {};
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index 7a707e749e69b..da1af0a85c34c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -93,12 +93,76 @@ struct FoldEmptyTensorWithExtractSliceOp
bool foldSingleUseOnly = false;
};
+/// tensor.empty does not define any tensor contents, so an unpadded pack
+/// can be folded away.
+struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
+ FoldEmptyTensorWithPackOp(MLIRContext *ctx, PatternBenefit benefit = 1,
+ bool foldSingleUseOnly = false)
+ : OpRewritePattern<PackOp>(ctx, benefit),
+ foldSingleUseOnly(foldSingleUseOnly) {}
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ // Check for tensor.empty source.
+ auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
+ if (!emptyOp)
+ return failure();
+
+ // Check for single use.
+ if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
+ return failure();
+
+ // Check for padding.
+ // Packing with padding cannot be simply removed.
+ if (packOp.getPaddingValue())
+ return failure();
+
+ // Replace the pack directly with its destination.
+ rewriter.replaceOp(packOp, packOp.getDest());
+
+ return success();
+ }
+
+private:
+ bool foldSingleUseOnly = false;
+};
+
+/// tensor.empty does not define any tensor contents, so an unpack
+/// can be folded away.
+struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
+ FoldEmptyTensorWithUnPackOp(MLIRContext *ctx, PatternBenefit benefit = 1,
+ bool foldSingleUseOnly = false)
+ : OpRewritePattern<UnPackOp>(ctx, benefit),
+ foldSingleUseOnly(foldSingleUseOnly) {}
+
+ LogicalResult matchAndRewrite(UnPackOp unPackOp,
+ PatternRewriter &rewriter) const override {
+ // Check for tensor.empty source.
+ auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
+ if (!emptyOp)
+ return failure();
+
+ // Check for single use.
+ if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
+ return failure();
+
+ // Replace the unpack directly with its destination.
+ rewriter.replaceOp(unPackOp, unPackOp.getDest());
+
+ return success();
+ }
+
+private:
+ bool foldSingleUseOnly = false;
+};
+
} // namespace
void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
bool foldSingleUseOnly) {
patterns.add<FoldEmptyTensorWithExtractSliceOp,
FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
- FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
+ FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>,
+ FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4922251363950..78f27b4a8530f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2485,25 +2485,3 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
%16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
return %16 : vector<7xi32>
}
-
-// -----
-
-// CHECK: func.func @pack_empty(
-// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
-// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
-func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
- %empty_unpacked = tensor.empty() : tensor<256x256xf32>
- %packed = tensor.pack %empty_unpacked inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
- return %packed : tensor<8x8x32x32xf32>
-}
-
-// -----
-
-// CHECK: func.func @unpack_empty(
-// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
-// CHECK: return %[[T]] : tensor<256x256xf32>
-func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
- %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
- %unpacked = tensor.unpack %empty_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
- return %unpacked : tensor<256x256xf32>
-}
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index e200a4f892613..c4c35de2e6340 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -64,6 +64,47 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
return %r: tensor<2xf32>
}
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+ %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+ %packed = tensor.pack %empty_unpacked
+ inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+ into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+ return %packed : tensor<8x8x32x32xf32>
+}
+
+// CHECK-LABEL: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+ %unpacked = tensor.unpack %empty_packed
+ inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+ into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+ return %unpacked : tensor<256x256xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK-NOT: tensor.unpack
+// CHECK: return %[[T]] : tensor<256x256xf32>
+
+func.func @pack_padded_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+ %pad = arith.constant 1.0 : f32
+ %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+ %packed = tensor.pack %empty_unpacked
+ padding_value(%pad : f32)
+ inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+ into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+ return %packed : tensor<8x8x32x32xf32>
+}
+
+// CHECK-LABEL: func.func @pack_padded_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<8x8x32x32xf32>
+
// -----
module attributes {transform.with_named_sequence} {
>From 05236a1947bee4d65171f96a33ad6b1ac5780ef9 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 21 May 2024 11:18:48 +0200
Subject: [PATCH 4/6] Dynamic shape test cases
---
mlir/test/Dialect/Tensor/fold-empty-op.mlir | 26 +++++++++++++++++++++
1 file changed, 26 insertions(+)
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index c4c35de2e6340..67e745e98ca52 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -77,6 +77,19 @@ func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
// CHECK-NOT: tensor.pack
// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+func.func @pack_empty_dynamic(%arg0: tensor<?x?x?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?x?x?xf32> {
+ %empty_unpacked = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %packed = tensor.pack %empty_unpacked
+ inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+ into %arg0 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ return %packed : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pack_empty_dynamic(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
+
func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
%empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
%unpacked = tensor.unpack %empty_packed
@@ -90,6 +103,19 @@ func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
// CHECK-NOT: tensor.unpack
// CHECK: return %[[T]] : tensor<256x256xf32>
+func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3: index) -> tensor<?x?xf32> {
+ %empty_packed = tensor.empty(%dim0, %dim1, %dim2, %dim3) : tensor<?x?x?x?xf32>
+ %unpacked = tensor.unpack %empty_packed
+ inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+ into %arg0 : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ return %unpacked : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_empty_dynamic(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?xf32>
+// CHECK-NOT: tensor.unpack
+// CHECK: return %[[T]] : tensor<?x?xf32>
+
func.func @pack_padded_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
%pad = arith.constant 1.0 : f32
%empty_unpacked = tensor.empty() : tensor<256x256xf32>
>From d9bf82870d526aa3ad9c4635f0b3c9c37a878b3d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 21 May 2024 11:48:52 +0200
Subject: [PATCH 5/6] Relax single use restriction for pack and unpack
---
.../Tensor/Transforms/EmptyOpPatterns.cpp | 31 ++++---------------
1 file changed, 6 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index da1af0a85c34c..43ad0acaf7420 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -96,10 +96,7 @@ struct FoldEmptyTensorWithExtractSliceOp
/// tensor.empty does not define any tensor contents, so an unpadded pack
/// can be folded away.
struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
- FoldEmptyTensorWithPackOp(MLIRContext *ctx, PatternBenefit benefit = 1,
- bool foldSingleUseOnly = false)
- : OpRewritePattern<PackOp>(ctx, benefit),
- foldSingleUseOnly(foldSingleUseOnly) {}
+ using OpRewritePattern<PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
@@ -108,32 +105,22 @@ struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
if (!emptyOp)
return failure();
- // Check for single use.
- if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
- return failure();
-
// Check for padding.
// Packing with padding cannot be simply removed.
if (packOp.getPaddingValue())
- return failure();
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
// Replace the pack directly with its destination.
rewriter.replaceOp(packOp, packOp.getDest());
return success();
}
-
-private:
- bool foldSingleUseOnly = false;
};
/// tensor.empty does not define any tensor contents, so an unpack
/// can be folded away.
struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
- FoldEmptyTensorWithUnPackOp(MLIRContext *ctx, PatternBenefit benefit = 1,
- bool foldSingleUseOnly = false)
- : OpRewritePattern<UnPackOp>(ctx, benefit),
- foldSingleUseOnly(foldSingleUseOnly) {}
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
@@ -142,18 +129,11 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
if (!emptyOp)
return failure();
- // Check for single use.
- if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
- return failure();
-
// Replace the unpack directly with its destination.
rewriter.replaceOp(unPackOp, unPackOp.getDest());
return success();
}
-
-private:
- bool foldSingleUseOnly = false;
};
} // namespace
@@ -162,7 +142,8 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
bool foldSingleUseOnly) {
patterns.add<FoldEmptyTensorWithExtractSliceOp,
FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
- FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>,
- FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
+ FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
+ patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
+ patterns.getContext(), /*benefit=*/1);
}
>From 02d2e8a7a40c18166f60249ece4f5f547f8fbbfd Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 22 May 2024 10:39:30 +0200
Subject: [PATCH 6/6] Add CHECK-SAME checks for dynamic dimension arguments
---
mlir/test/Dialect/Tensor/fold-empty-op.mlir | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 67e745e98ca52..e94f6ec7ec56e 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -86,7 +86,9 @@ func.func @pack_empty_dynamic(%arg0: tensor<?x?x?x?xf32>, %dim0: index, %dim1: i
}
// CHECK-LABEL: func.func @pack_empty_dynamic(
-// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index
// CHECK-NOT: tensor.pack
// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
@@ -112,7 +114,11 @@ func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: ind
}
// CHECK-LABEL: func.func @unpack_empty_dynamic(
-// CHECK-SAME: %[[T:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[T:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM2:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM3:[a-zA-Z0-9_]+]]: index
// CHECK-NOT: tensor.unpack
// CHECK: return %[[T]] : tensor<?x?xf32>
More information about the Mlir-commits
mailing list