[Mlir-commits] [mlir] [mlir][vector] Canonicalize vector.extract and vector.broadcast to vector.shape_cast (PR #174452)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jan 5 10:19:38 PST 2026
================
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return rewriter.notifyMatchFailure(
+ extractOp, "extract to vector with fewer elements");
+
+ // Negative values in `position` mean that the extracted value is poison.
+ // There is a vector.extract folder for this.
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
----------------
kuhar wrote:
I think you could also write this as:
```suggestion
if (!llvm::all_of(extractOp.getMixedPosition(), isZeroValue))
```
(I forget whether `isZeroValue` is a helper in MLIR or IREE.)
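For reference, the two spellings are equivalent by De Morgan's law: "any position is not zero" is the same as "not all positions are zero". Below is a minimal standalone sketch of that equivalence. It uses only the C++ standard library rather than MLIR, and everything in it is an assumption for illustration: `Position` is a hypothetical stand-in for `OpFoldResult`, and `isZeroValue` here is a local helper that only borrows the name from the suggestion, not the actual MLIR or IREE function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for OpFoldResult: a constant index, or std::nullopt
// for a dynamic (non-constant) position.
using Position = std::optional<int64_t>;

// Hypothetical helper that borrows the `isZeroValue` name from the suggestion.
// Returns true only for a constant position equal to zero.
bool isZeroValue(Position p) { return p.has_value() && *p == 0; }

// Form used in the patch: bail out if any position is not a constant zero.
bool bailsOutAnyOf(const std::vector<Position> &positions) {
  return std::any_of(positions.begin(), positions.end(),
                     [](Position p) { return !isZeroValue(p); });
}

// Suggested form: bail out unless all positions are constant zeros.
bool bailsOutNotAllOf(const std::vector<Position> &positions) {
  return !std::all_of(positions.begin(), positions.end(), isZeroValue);
}

// Sanity check: both forms agree on a few representative inputs, including
// the empty position list, a negative index, and a dynamic position.
void checkFormsAgree() {
  const std::vector<std::vector<Position>> cases = {
      {}, {0}, {0, 0}, {0, 1}, {-1}, {std::nullopt}};
  for (const auto &c : cases)
    assert(bailsOutAnyOf(c) == bailsOutNotAllOf(c));
}
```

The `!all_of` form drops the lambda and passes the predicate by name, which is the readability gain the suggestion is after. Both forms treat an empty position list the same way (no bail-out).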
https://github.com/llvm/llvm-project/pull/174452