[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue May 27 09:50:52 PDT 2025
================
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType resultType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+ if (sourceType.isScalable() || resultType.isScalable())
return failure();
- // Special case for n-D / 1-D lowerings with better implementations.
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+ // Special case for n-D / 1-D lowerings with implementations that use
+ // extract_strided_slice / insert_strided_slice.
+ int64_t sourceRank = sourceType.getRank();
+ int64_t resultRank = resultType.getRank();
+ if ((sourceRank > 1 && resultRank == 1) ||
+ (sourceRank == 1 && resultRank > 1))
return failure();
- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
- // extract/insert chains.
- int64_t numElts = 1;
- for (int64_t r = 0; r < srcRank; r++)
- numElts *= sourceVectorType.getDimSize(r);
+ int64_t numExtracts = sourceType.getNumElements();
+ int64_t nbCommonInnerDims = 0;
----------------
banach-space wrote:
Do both the `num` and `nb` prefixes stand for "number"? If so, could you unify them (e.g. `numExtracts` and `numCommonInnerDims`)?
https://github.com/llvm/llvm-project/pull/140800
More information about the Mlir-commits
mailing list