[Mlir-commits] [mlir] [mlir][Vector] Support efficient shape cast lowering for n-D vectors (PR #123497)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Jan 21 00:11:59 PST 2025
================
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
-
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();
- if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+ int64_t srcRank = sourceVectorType.getRank();
+ int64_t resRank = resultVectorType.getRank();
+ if (srcRank != 1 || resRank < 2)
return failure();
+ // Compute the number of 1-D vector elements involved in the reshape.
+ int64_t numElts = 1;
+ for (int64_t dim = 0; dim < resRank - 1; ++dim)
+ numElts *= resultVectorType.getDimSize(dim);
+
+ // Compute the indices of each 1-D vector element of the source slice
+ // extraction and destination insertion and generate such instructions.
auto loc = op.getLoc();
- Value desc = rewriter.create<arith::ConstantOp>(
+ SmallVector<int64_t> srcIdx(srcRank);
+ SmallVector<int64_t> resIdx(resRank - 1);
----------------
hanhanW wrote:
[optional] Ditto: if the above suggestion is applied, please also update this code for consistency.
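For reference, here is a plain C++ sketch (outside MLIR) of the index bookkeeping that the new comments in the hunk describe. It is a minimal sketch under stated assumptions, not the code in the PR. The names `SliceCopy` and `enumerateSliceCopies` are hypothetical. It assumes a row-major result whose innermost dimension is filled by one contiguous slice of the 1-D source. The number of rows is the product of the leading result dimensions (the `numElts` loop in the diff). Row `i` starts at source offset `i * rowSize`, and its position in the result is `i` delinearized over the leading dimensions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical record: one contiguous slice of the 1-D source and the
// position (one index per leading result dimension) it is inserted at.
struct SliceCopy {
  int64_t srcOffset;           // start of the slice in the 1-D source
  std::vector<int64_t> resIdx; // index of the 1-D row in the n-D result
};

// Hypothetical helper: enumerate the slice copies for a shape cast from a
// 1-D vector to a result of shape `resShape` (rank >= 2, row-major).
std::vector<SliceCopy> enumerateSliceCopies(const std::vector<int64_t> &resShape) {
  assert(resShape.size() >= 2 && "expected a result rank of at least 2");
  int64_t resRank = static_cast<int64_t>(resShape.size());
  int64_t rowSize = resShape[resRank - 1];

  // Number of 1-D rows: product of all but the innermost result dimension.
  int64_t numElts = 1;
  for (int64_t dim = 0; dim < resRank - 1; ++dim)
    numElts *= resShape[dim];

  std::vector<SliceCopy> copies;
  copies.reserve(numElts);
  std::vector<int64_t> resIdx(resRank - 1, 0);
  for (int64_t i = 0; i < numElts; ++i) {
    copies.push_back(SliceCopy{i * rowSize, resIdx});
    // Advance resIdx in row-major order: the last leading dimension varies
    // fastest, carrying into outer dimensions when it wraps around.
    for (int64_t dim = resRank - 2; dim >= 0; --dim) {
      if (++resIdx[dim] < resShape[dim])
        break;
      resIdx[dim] = 0;
    }
  }
  return copies;
}
```

For example, `enumerateSliceCopies({2, 3, 4})` (a cast from vector<24xf32> to vector<2x3x4xf32>) yields six copies: source offsets 0, 4, 8, 12, 16, 20 landing at result indices [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]. In the rewrite pattern, each copy would presumably become a vector.extract_strided_slice of the source followed by a vector.insert into the result. That mapping is an assumption, not a reading of the full patch.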
https://github.com/llvm/llvm-project/pull/123497