[Mlir-commits] [mlir] [mlir][tensor] Centralize pack/unpack related patterns. (PR #76603)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 29 21:44:29 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

This revision moves the pack/unpack-related patterns into PackAndUnpackPatterns.cpp, following the same convention already used for other tensor ops.

---
Full diff: https://github.com/llvm/llvm-project/pull/76603.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (-3) 
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+4) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (-38) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1-1) 
- (renamed) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+34) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 06642adda42b38..0a21c9922b223b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -163,9 +163,6 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
-/// Patterns to simplify tensor.pack.
-void populateSimplifyTensorPack(RewritePatternSet &patterns);
-
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 44b8377bd6aad9..b7d29a508ff95c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -74,6 +74,10 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
 /// that it can be bufferized into a sequence of copies.
 void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that simplify `tensor.pack` and
+/// `tensor.unpack` operations.
+void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7c35dd4d953619..816e6ba8fed94e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3466,44 +3466,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-/// Packing one-dimensional tensor can be expressed as an expand shape op.
-struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
-
-  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
-                     Type newOperandType, ArrayAttr reassociation) const {
-    if (operand.getType() == newOperandType)
-      return operand;
-    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
-                                                  reassociation);
-  }
-
-  LogicalResult matchAndRewrite(PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    RankedTensorType sourceType = packOp.getSourceType();
-    RankedTensorType destType = packOp.getDestType();
-    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
-      return failure();
-    auto reassociation =
-        getReassociationIndicesForReshape(sourceType, destType);
-    if (!reassociation)
-      return failure();
-    Value expanded = insertExpand(
-        rewriter, packOp.getLoc(), packOp.getSource(), destType,
-        getReassociationIndicesAttribute(rewriter, *reassociation));
-    rewriter.replaceOp(packOp, expanded);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
-  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
-}
-
 template <typename OpTy>
 static LogicalResult
 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index d233ab7a0e8974..cbc0d499d9d52c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -4,10 +4,10 @@ add_mlir_dialect_library(MLIRTensorTransforms
   ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
-  FoldIntoPackAndUnpackPatterns.cpp
   FoldTensorSubsetOps.cpp
   IndependenceTransforms.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
+  PackAndUnpackPatterns.cpp
   ReshapePatterns.cpp
   RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
similarity index 80%
rename from mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
rename to mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e4509b331beeac..67651a2e38c82d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -21,6 +21,36 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
       ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
 }
 
+/// Packing one-dimensional tensor can be expressed as an expand shape op.
+struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+                     Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+                                                  reassociation);
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = packOp.getSourceType();
+    RankedTensorType destType = packOp.getDestType();
+    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+      return failure();
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value expanded = insertExpand(
+        rewriter, packOp.getLoc(), packOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(packOp, expanded);
+    return success();
+  }
+};
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -150,5 +180,9 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
+  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+}
+
 } // namespace tensor
 } // namespace mlir

``````````

</details>


https://github.com/llvm/llvm-project/pull/76603


More information about the Mlir-commits mailing list