[Mlir-commits] [mlir] [mlir][Vector] Support efficient shape cast lowering for n-D vectors (PR #123497)

Han-Chung Wang llvmlistbot at llvm.org
Tue Jan 21 00:11:59 PST 2025


================
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank != 1 || resRank < 2)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < resRank - 1; ++dim)
+      numElts *= resultVectorType.getDimSize(dim);
+
+    // Compute the indices of each 1-D vector element of the source slice
+    // extraction and destination insertion and generate such instructions.
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank - 1);
----------------
hanhanW wrote:

[optional] ditto, if the above suggestion is applied, please also update the code for consistency.

https://github.com/llvm/llvm-project/pull/123497
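The quoted hunk generalizes the 1-D to n-D upcast. It computes numElts as the product of every result dimension except the innermost. It then pairs each source slice offset with a result index over the leading dimensions. The standalone C++ sketch below models only that index bookkeeping, on plain integers rather than MLIR values. It is an illustration under stated assumptions, not the PR's code. `incrementIndex`, `SliceMove`, and `planUpcast` are hypothetical names, not MLIR APIs. The sketch returns the list of (source offset, result index) pairs that a rewrite would turn into one extract/insert pair each. It does not emit any vector ops.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): advances `idx` through `shape` in
// row-major order, like an odometer. Returns false after the last position.
bool incrementIndex(std::vector<int64_t> &idx,
                    const std::vector<int64_t> &shape) {
  for (int64_t d = static_cast<int64_t>(idx.size()) - 1; d >= 0; --d) {
    if (++idx[d] < shape[d])
      return true;
    idx[d] = 0;  // carry into the next outer dimension
  }
  return false;
}

// One step of the (hypothetical) lowering: take `sliceSize` contiguous
// elements starting at `srcOffset` in the 1-D source and place them at
// `resIdx`, an index over all result dims except the innermost.
struct SliceMove {
  int64_t srcOffset;
  std::vector<int64_t> resIdx;
};

// Hypothetical planner modelling a 1-D -> n-D (n >= 2) shape_cast upcast.
std::vector<SliceMove> planUpcast(int64_t srcSize,
                                  const std::vector<int64_t> &resShape) {
  const int64_t resRank = static_cast<int64_t>(resShape.size());
  assert(resRank >= 2 && "result must be at least 2-D");
  const int64_t sliceSize = resShape.back();

  // Number of 1-D vectors: product of all result dims but the innermost.
  int64_t numElts = 1;
  for (int64_t dim = 0; dim < resRank - 1; ++dim)
    numElts *= resShape[dim];
  assert(numElts * sliceSize == srcSize && "element counts must match");

  std::vector<int64_t> leadingShape(resShape.begin(), resShape.end() - 1);
  std::vector<int64_t> resIdx(resRank - 1, 0);
  std::vector<SliceMove> plan;
  plan.reserve(numElts);
  for (int64_t i = 0; i < numElts; ++i) {
    // The source is 1-D, so its slice offset simply advances by sliceSize.
    plan.push_back({i * sliceSize, resIdx});
    incrementIndex(resIdx, leadingShape);
  }
  return plan;
}
```

For example, casting a 24-element source to a 2x3x4 result yields six moves. They have source offsets 0, 4, ..., 20 and result indices [0, 0], [0, 1], ..., [1, 2]. Advancing the result index odometer-style avoids a division per slice to delinearize `i`. That is one plausible reason to keep explicit srcIdx/resIdx vectors, as the hunk does.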

