[Mlir-commits] [mlir] [mlir][vector] canonicalize vector.from_elements from ascending extracts (PR #139819)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu May 15 06:47:31 PDT 2025
================
@@ -2385,9 +2386,98 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+///
+/// Example:
+/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+/// %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+///
+/// The requirements for this to be valid are
+/// i) all elements are extracted from the same vector (source),
+/// ii) source and from_elements result have the same number of elements,
+/// iii) the elements are extracted in ascending order.
+///
+/// It might be possible to rewrite vector.from_elements as a single
+/// vector.extract if (ii) is not satisfied, or in some cases as a
+/// single vector.extract_strided_slice if (ii) and (iii) are not satisfied;
+/// this is left for future consideration.
+class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FromElementsOp fromElements,
+ PatternRewriter &rewriter) const override {
+
+ mlir::OperandRange elements = fromElements.getElements();
+ assert(!elements.empty() && "must be at least 1 element");
+ Value firstElement = elements.front();
+
+ ExtractOp extractOp =
+ dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
+ if (!extractOp) {
+ return rewriter.notifyMatchFailure(
+ fromElements, "first element not from vector.extract");
+ }
+ VectorType sourceType = extractOp.getSourceVectorType();
+ Value source = extractOp.getVector();
+
+ // Check condition (ii).
+ if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
+ return rewriter.notifyMatchFailure(fromElements,
+ "number of elements differ");
+ }
----------------
banach-space wrote:
[nit] I would "rebrand" this as "Condition (i)" (it's the first condition to be checked) and move it all the way to the top - it feels like a fairly high-level condition that deserves a special place :)
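
For illustration only, here is a minimal sketch of the ordering suggested in this comment, with the element-count check renumbered as condition (i) and performed first. It models the checks on plain data rather than the MLIR API: `ExtractedElement`, `linearize`, `canRewriteAsShapeCast`, and the integer `sourceId` are hypothetical names that do not appear in the patch, and the sketch assumes every extract position is static and in bounds.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one operand of vector.from_elements.
struct ExtractedElement {
  int sourceId;                  // Illustrative id of the vector extracted from.
  std::vector<int64_t> position; // Static position, e.g. {0, 1} for [0, 1].
};

// Row-major linear index of `position` in a vector of shape `shape`.
// Assumes equal rank and that every index is in bounds.
int64_t linearize(const std::vector<int64_t> &position,
                  const std::vector<int64_t> &shape) {
  assert(position.size() == shape.size() && "rank mismatch");
  int64_t index = 0;
  for (size_t i = 0; i < shape.size(); ++i)
    index = index * shape[i] + position[i];
  return index;
}

// Returns true if the modelled from_elements could become a shape_cast.
bool canRewriteAsShapeCast(const std::vector<ExtractedElement> &elements,
                           int sourceId,
                           const std::vector<int64_t> &sourceShape) {
  // Condition (i), checked first: source and result element counts match.
  int64_t numSourceElements = 1;
  for (int64_t dim : sourceShape)
    numSourceElements *= dim;
  if (static_cast<size_t>(numSourceElements) != elements.size())
    return false;

  for (size_t i = 0; i < elements.size(); ++i) {
    const ExtractedElement &element = elements[i];
    // Condition (ii): every element comes from the same source.
    if (element.sourceId != sourceId)
      return false;
    // A scalar extract needs a position of full rank.
    if (element.position.size() != sourceShape.size())
      return false;
    // Condition (iii): ascending order, so operand i is source element i.
    if (linearize(element.position, sourceShape) != static_cast<int64_t>(i))
      return false;
  }
  return true;
}
```

Doing the count check first lets the cheapest test reject a mismatch before any per-element work; whether the patch adopts this numbering is up to the author.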
https://github.com/llvm/llvm-project/pull/139819