[Mlir-commits] [mlir] [mlir][tensor] Enhance SimplifyPackToExpandShape for unit dim cases. (PR #79247)

lorenzo chelini llvmlistbot at llvm.org
Wed Jan 24 05:21:29 PST 2024


================
@@ -34,26 +41,57 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
                                                   reassociation);
   }
 
-  LogicalResult matchAndRewrite(PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    if (packOp.getPaddingValue())
-      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
-
+  /// Returns success() if it is only packing on the innermost dimension.
+  LogicalResult isPackOneInnerMostDim(RewriterBase &rewriter,
+                                      PackOp packOp) const {
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
       return rewriter.notifyMatchFailure(
           packOp,
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = packOp.getSourceType();
-    RankedTensorType destType = packOp.getDestType();
+    int64_t srcRank = packOp.getSourceRank();
     ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
-    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
       return rewriter.notifyMatchFailure(
           packOp, "expects packing at the innermost dimension");
     }
+    return success();
+  }
+
+  /// Returns success() if at most one dimension of the source has a size
+  /// greater than 1 and packing only happens on that dimension. It assumes
+  /// that the pack op does not have a padding value.
+  LogicalResult isPack1DSrc(RewriterBase &rewriter, PackOp packOp) const {
+    ArrayRef<int64_t> srcShape = packOp.getSourceType().getShape();
+    if (getNumGtOneDims(srcShape) > 1) {
+      return rewriter.notifyMatchFailure(
+          packOp, "expects source is not 1D tensor with unit dims");
+    }
 
+    // The pack op does not have padding value. Non-unit inner tile size must
+    // be used by the non-unit dimension.
+    SmallVector<int64_t> innerTiles = packOp.getStaticTiles();
+    if (getNumGtOneDims(innerTiles) > 1) {
+      return rewriter.notifyMatchFailure(
+          packOp, "expects has at most one non-unit inner tiles");
+    }
+
+    return success();
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    if (failed(isPackOneInnerMostDim(rewriter, packOp)) &&
+        failed(isPack1DSrc(rewriter, packOp)))
+      return failure();
----------------
chelini wrote:

nit: braces.

https://github.com/llvm/llvm-project/pull/79247


More information about the Mlir-commits mailing list