[Mlir-commits] [mlir] [mlir][vector] Canonicalize vector.extract and vector.broadcast to vector.shape_cast (PR #174452)

Jakub Kuderski llvmlistbot at llvm.org
Mon Jan 5 10:19:38 PST 2026


================
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract to vector with fewer elements");
+
+    // Negative values in `position` mean that the extracted value is poison.
+    // There is a vector.extract folder for this.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
----------------
kuhar wrote:

I think you could also write this as:

```suggestion
    if (!llvm::all_of(extractOp.getMixedPosition(), isZeroValue))
```
(I don't remember whether `isZeroValue` is a helper in MLIR or in IREE.)
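
For reference, the suggestion relies on "any element is non-zero" being the same condition as "not all elements are zero". Below is a minimal sketch of that equivalence in plain C++ using the standard algorithms. It assumes integer positions instead of `OpFoldResult`, and `isZero`, `anyNonZero`, and `notAllZero` are hypothetical names used only for illustration, not MLIR or IREE APIs.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the helper mentioned above. The real helper
// would take an OpFoldResult, not a plain integer.
bool isZero(int64_t v) { return v == 0; }

// Shape of the check in the patch: true if any position is not zero.
bool anyNonZero(const std::vector<int64_t> &position) {
  return std::any_of(position.begin(), position.end(),
                     [](int64_t v) { return !isZero(v); });
}

// Shape of the suggested check: true unless every position is zero.
bool notAllZero(const std::vector<int64_t> &position) {
  return !std::all_of(position.begin(), position.end(), isZero);
}
```

Both functions agree on every input, including an empty position list, where both return false. The second form avoids the lambda because the predicate can be passed by name.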

https://github.com/llvm/llvm-project/pull/174452

