[Mlir-commits] [mlir] [mlir][tensor] Enhance SimplifyPackToExpandShape for unit dim cases. (PR #79247)
lorenzo chelini
llvmlistbot at llvm.org
Wed Jan 24 05:21:28 PST 2024
================
@@ -34,26 +41,57 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
reassociation);
}
- LogicalResult matchAndRewrite(PackOp packOp,
- PatternRewriter &rewriter) const override {
- if (packOp.getPaddingValue())
- return rewriter.notifyMatchFailure(packOp, "expects no padding value");
-
+ /// Returns success() if it is only packing on the innermost dimension.
+ LogicalResult isPackOneInnerMostDim(RewriterBase &rewriter,
+ PackOp packOp) const {
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
return rewriter.notifyMatchFailure(
packOp,
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = packOp.getSourceType();
- RankedTensorType destType = packOp.getDestType();
+ int64_t srcRank = packOp.getSourceRank();
ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
- if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
return rewriter.notifyMatchFailure(
packOp, "expects packing at the innermost dimension");
}
+ return success();
+ }
+
+ /// Returns success() if there is only 1 dimension size in source being
+ /// greater than 1 and packing only happens on the dimension. It assumes that
----------------
chelini wrote:
Since the method assumes there is no outer_dims_perm (or only an identity one), can we add an assert for that?
https://github.com/llvm/llvm-project/pull/79247
More information about the Mlir-commits
mailing list