[Mlir-commits] [mlir] [mlir][vector] vector.from_elements canonicalizer when elements are ascending extracts (PR #139819)
James Newling
llvmlistbot at llvm.org
Tue May 13 17:47:05 PDT 2025
================
@@ -2385,9 +2385,69 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+///
+/// The rewrite applies when every operand is a scalar vector.extract from one
+/// common, fixed-size source vector, taken at static positions whose row-major
+/// linearized offsets are 0, 1, 2, ... in operand order, and the operands
+/// cover every element of that source.
+///
+/// Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+///
+/// Returns success() if `fromElementsOp` was replaced, and failure() (leaving
+/// the IR untouched) otherwise.
+static LogicalResult
+rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
+                               PatternRewriter &rewriter) {
+
+  // Number of scalars assembled by `fromElementsOp`.
+  const int64_t numElements =
+      static_cast<int64_t>(fromElementsOp.getElements().size());
+
+  // The common source of the vector.extract operations (if one exists), and
+  // its shape. Both are set in the first iteration of the loop over the
+  // operands (elements) of `fromElementsOp`.
+  Value source;
+  ArrayRef<int64_t> shape;
+
+  for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
+
+    // Check that the element is defined by an extract operation, and that
+    // the extract is on the same vector as all preceding elements.
+    auto extractOp = element.getDefiningOp<vector::ExtractOp>();
+    if (!extractOp)
+      return failure();
+    Value currentSource = extractOp.getVector();
+    if (index == 0) {
+      VectorType sourceType = extractOp.getSourceVectorType();
+      // A shape_cast preserves the number of elements, so the operands must
+      // cover the whole source, not just a prefix of it. A scalable source
+      // cannot be cast to the fixed-size result of from_elements.
+      if (sourceType.isScalable() ||
+          sourceType.getNumElements() != numElements)
+        return failure();
+      source = currentSource;
+      shape = sourceType.getShape();
+    } else if (currentSource != source) {
+      return failure();
+    }
+
+    // Check that the (linearized) index of extraction is the same as the index
+    // in the result of `fromElementsOp`.
+    ArrayRef<int64_t> position = extractOp.getStaticPosition();
+    assert(position.size() == shape.size() &&
+           "scalar extract must have full rank position");
+    int64_t stride{1};
+    int64_t offset{0};
+    // Walk from the innermost dimension outwards, accumulating the row-major
+    // offset of the extracted element.
+    for (auto [pos, size] :
+         llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
+      // A dynamic position cannot be matched against a static index.
+      if (pos == ShapedType::kDynamic)
+        return failure();
+      offset += pos * stride;
+      stride *= size;
+    }
+    if (offset != static_cast<int64_t>(index))
+      return failure();
+  }
+
+  rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
+                                           fromElementsOp.getType(), source);
+  return success();
+}
+
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
+ results.add(rewriteFromElementsAsShapeCast);
----------------
newling wrote:
Just copying the design pattern of `rewriteFromElementsAsSplat`, but this may be complex enough to warrant an actual Pattern with `notifyMatchFailure`s.
https://github.com/llvm/llvm-project/pull/139819
More information about the Mlir-commits
mailing list