[Mlir-commits] [mlir] [mlir][Vector] Support efficient shape cast lowering for n-D vectors (PR #123497)

Han-Chung Wang llvmlistbot at llvm.org
Tue Jan 21 00:11:58 PST 2025


================
@@ -53,35 +54,52 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank < 2 || resRank != 1)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+      numElts *= sourceVectorType.getDimSize(dim);
+
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank - 1);
+    SmallVector<int64_t> resIdx(resRank);
----------------
hanhanW wrote:

[optional nit]: Interesting code: it initializes all the values to zero. I know it works and how it works, but IMO it is not a common pattern in the codebase. We usually pass the default value to the constructor. Would you mind explicitly adding the default value to the constructor?

https://github.com/llvm/llvm-project/pull/123497

