[Mlir-commits] [mlir] [mlir][tensor] Enhance SimplifyPackToExpandShape for unit dim cases. (PR #79247)

lorenzo chelini llvmlistbot at llvm.org
Wed Jan 24 05:21:28 PST 2024


================
@@ -34,26 +41,57 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
                                                   reassociation);
   }
 
-  LogicalResult matchAndRewrite(PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    if (packOp.getPaddingValue())
-      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
-
+  /// Returns success() if it is only packing on the innermost dimension.
+  LogicalResult isPackOneInnerMostDim(RewriterBase &rewriter,
+                                      PackOp packOp) const {
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
       return rewriter.notifyMatchFailure(
           packOp,
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = packOp.getSourceType();
-    RankedTensorType destType = packOp.getDestType();
+    int64_t srcRank = packOp.getSourceRank();
     ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
-    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
       return rewriter.notifyMatchFailure(
           packOp, "expects packing at the innermost dimension");
     }
+    return success();
+  }
+
+  /// Returns success() if there is only 1 dimension size in source being
+  /// greater than 1 and packing only happens on the dimension. It assumes that
----------------
chelini wrote:

Since the method assumes there is no outer dims permutation (outer_dims_perm), can we add an assert for that?

https://github.com/llvm/llvm-project/pull/79247


More information about the Mlir-commits mailing list