[Mlir-commits] [mlir] 4b14205 - [mlir][tensor] Centralize pack/unpack related patterns. (#76603)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 30 11:40:44 PST 2023
Author: Han-Chung Wang
Date: 2023-12-30T11:40:40-08:00
New Revision: 4b14205bc0b8e91a8e94c63773e01f20a6505188
URL: https://github.com/llvm/llvm-project/commit/4b14205bc0b8e91a8e94c63773e01f20a6505188
DIFF: https://github.com/llvm/llvm-project/commit/4b14205bc0b8e91a8e94c63773e01f20a6505188.diff
LOG: [mlir][tensor] Centralize pack/unpack related patterns. (#76603)
This revision moves the pack/unpack-related patterns to
PackAndUnpackPatterns.cpp, following the same convention used for other
tensor ops.
It also renames `populateSimplifyTensorPack` to
`populateSimplifyPackAndUnpackPatterns` and adds a TODO item for the
tensor.unpack op.
Added:
mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 06642adda42b38..0a21c9922b223b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -163,9 +163,6 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
-/// Patterns to simplify tensor.pack.
-void populateSimplifyTensorPack(RewritePatternSet &patterns);
-
} // namespace tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 44b8377bd6aad9..35b519e790d1c3 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -74,6 +74,11 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
/// that it can be bufferized into a sequence of copies.
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that simplify `tensor.pack` and
+/// `tensor.unpack` operations.
+/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
+void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7c35dd4d953619..816e6ba8fed94e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3466,44 +3466,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
-namespace {
-
-/// Packing one-dimensional tensor can be expressed as an expand shape op.
-struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
- using OpRewritePattern<PackOp>::OpRewritePattern;
-
- Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
- Type newOperandType, ArrayAttr reassociation) const {
- if (operand.getType() == newOperandType)
- return operand;
- return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
- reassociation);
- }
-
- LogicalResult matchAndRewrite(PackOp packOp,
- PatternRewriter &rewriter) const override {
- RankedTensorType sourceType = packOp.getSourceType();
- RankedTensorType destType = packOp.getDestType();
- if (sourceType.getRank() != 1 || packOp.getPaddingValue())
- return failure();
- auto reassociation =
- getReassociationIndicesForReshape(sourceType, destType);
- if (!reassociation)
- return failure();
- Value expanded = insertExpand(
- rewriter, packOp.getLoc(), packOp.getSource(), destType,
- getReassociationIndicesAttribute(rewriter, *reassociation));
- rewriter.replaceOp(packOp, expanded);
- return success();
- }
-};
-
-} // namespace
-
-void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
- patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
-}
-
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index d233ab7a0e8974..cbc0d499d9d52c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -4,10 +4,10 @@ add_mlir_dialect_library(MLIRTensorTransforms
ConcatOpPatterns.cpp
EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
- FoldIntoPackAndUnpackPatterns.cpp
FoldTensorSubsetOps.cpp
IndependenceTransforms.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
+ PackAndUnpackPatterns.cpp
ReshapePatterns.cpp
RewriteAsConstant.cpp
SwapExtractSliceWithProducerPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
similarity index 80%
rename from mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
rename to mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e4509b331beeac..67651a2e38c82d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -21,6 +21,36 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
}
+/// Packing one-dimensional tensor can be expressed as an expand shape op.
+struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType, ArrayAttr reassociation) const {
+ if (operand.getType() == newOperandType)
+ return operand;
+ return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+ reassociation);
+ }
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ RankedTensorType sourceType = packOp.getSourceType();
+ RankedTensorType destType = packOp.getDestType();
+ if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+ return failure();
+ auto reassociation =
+ getReassociationIndicesForReshape(sourceType, destType);
+ if (!reassociation)
+ return failure();
+ Value expanded = insertExpand(
+ rewriter, packOp.getLoc(), packOp.getSource(), destType,
+ getReassociationIndicesAttribute(rewriter, *reassociation));
+ rewriter.replaceOp(packOp, expanded);
+ return success();
+ }
+};
+
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -150,5 +180,9 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
+ patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+}
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
similarity index 95%
rename from mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
rename to mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 75eb33ed033b9e..049076a67bae53 100644
--- a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
// CHECK: func.func @single_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 3e142155df8d9b..b907f77e910825 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -84,9 +84,9 @@ struct TestTensorTransforms
"the extract_slice of collapse_shape pattern"),
llvm::cl::init(false)};
- Option<bool> testSimplifyPackPatterns{
- *this, "test-simplify-pack-patterns",
- llvm::cl::desc("Test patterns to simplify tensor.pack"),
+ Option<bool> testSimplifyPackUnpackPatterns{
+ *this, "test-simplify-pack-unpack-patterns",
+ llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
llvm::cl::init(false)};
Option<bool> testTrackingListener{
@@ -137,9 +137,9 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
-static void applySimplifyPackPatterns(Operation *rootOp) {
+static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
- tensor::populateSimplifyTensorPack(patterns);
+ tensor::populateSimplifyPackAndUnpackPatterns(patterns);
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
@@ -376,8 +376,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
void TestTensorTransforms::runOnOperation() {
Operation *rootOp = getOperation();
- if (testSimplifyPackPatterns)
- applySimplifyPackPatterns(rootOp);
+ if (testSimplifyPackUnpackPatterns)
+ applySimplifyPackUnpackPatterns(rootOp);
if (testFoldConstantExtractSlice)
applyFoldConstantExtractSlicePatterns(rootOp);
if (testFoldConsecutiveInsertExtractSlice)
More information about the Mlir-commits
mailing list