[Mlir-commits] [mlir] [MLIR][Canonicalization] Added shape_cast folding patterns (PR #183061)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Feb 25 10:48:20 PST 2026
================
@@ -6660,13 +6667,36 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
+///
+/// BEFORE:
+/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
+/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
+/// AFTER:
+/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
+class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
+    if (!fromElements)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(
+        shapeCastOp, shapeCastOp.getResultVectorType(),
+        fromElements.getElements());
+    return success();
+  }
+};
+
----------------
banach-space wrote:
Instead of adding a new pattern, you should be able to simply extend this: https://github.com/llvm/llvm-project/blob/8f378ea7e6fa31179266c69368be56a866b631e1/mlir/lib/Dialect/Vector/IR/VectorOps.cpp?plain=1#L6481-L6521
IIUC, that's what Mehdi and Kunwar had in mind.
https://github.com/llvm/llvm-project/pull/183061
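The fold in the diff is sound because `vector.from_elements` lists its scalars in row-major order and `vector.shape_cast` only changes the shape without reordering elements, so `shape_cast(from_elements(xs) : S) to T` holds exactly the same data as `from_elements(xs) : T`. Below is a minimal standalone C++ sketch of that argument. It assumes a hypothetical toy value type (`VecValue`) and hypothetical helpers (`fromElements`, `shapeCast`, `at`, `foldShapeCastOfFromElements`). These are illustrations only, not MLIR's API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy model, NOT MLIR's API: a vector value is its shape plus
// its scalars stored in row-major order.
struct VecValue {
  std::vector<int64_t> shape;
  std::vector<float> elements;
};

// Element count implied by a shape (product of dims; 1 for a 0-D shape).
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Models vector.from_elements: scalars are listed in row-major order and
// their count must match the result shape.
std::optional<VecValue> fromElements(const std::vector<float> &elems,
                                     const std::vector<int64_t> &shape) {
  if (numElements(shape) != static_cast<int64_t>(elems.size()))
    return std::nullopt;
  return VecValue{shape, elems};
}

// Models vector.shape_cast: same element count, same row-major element
// order; only the shape changes.
std::optional<VecValue> shapeCast(const VecValue &src,
                                  const std::vector<int64_t> &shape) {
  if (numElements(shape) != numElements(src.shape))
    return std::nullopt;
  return VecValue{shape, src.elements};
}

// Reads one element by row-major linearization of a multi-dim index.
float at(const VecValue &v, const std::vector<int64_t> &idx) {
  assert(idx.size() == v.shape.size() && "index rank must match shape rank");
  int64_t linear = 0;
  for (std::size_t i = 0; i < v.shape.size(); ++i)
    linear = linear * v.shape[i] + idx[i];
  return v.elements[static_cast<std::size_t>(linear)];
}

// The rewrite from the diff, in toy form (hypothetical helper):
//   shape_cast(from_elements(xs) : S) to T  ==>  from_elements(xs) : T
// Sound because neither op reorders xs.
std::optional<VecValue>
foldShapeCastOfFromElements(const std::vector<float> &elems,
                            const std::vector<int64_t> &targetShape) {
  return fromElements(elems, targetShape);
}
```

For the example in the doc comment, `shapeCast(*fromElements({c1, c2, c3}, {3}), {1, 3})` and `foldShapeCastOfFromElements({c1, c2, c3}, {1, 3})` produce identical shapes and element sequences. The toy model intentionally ignores element types and the other checks MLIR's verifiers perform; it only captures the element-ordering argument that makes the fold valid.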