[Mlir-commits] [mlir] [mlir][vector] canonicalize vector.from_elements from ascending extracts (PR #139819)
James Newling
llvmlistbot at llvm.org
Wed May 14 12:47:14 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/139819
>From c252a3dd92ab1145dc80a847729423366a8f1dea Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 13 May 2025 16:06:45 -0700
Subject: [PATCH 1/6] first commit
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 58 +++++++++++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 69 ------------------
.../canonicalize/vector-from-elements.mlir | 72 +++++++++++++++++++
3 files changed, 130 insertions(+), 69 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..71844e62baba7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2385,6 +2385,64 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}
+static LogicalResult
+rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
+ PatternRewriter &rewriter) {
+
+ mlir::OperandRange elements = fromElementsOp.getElements();
+ const size_t nbElements = elements.size();
+ assert(nbElements > 0 && "must be at least one element");
+
+ // https://en.wikipedia.org/wiki/List_of_prime_numbers
+ const int prime = 5387;
+ bool pseudoRandomOrder = nbElements < prime;
+
+ Value source;
+ ArrayRef<int64_t> shape;
+ for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) {
+
+ // Rather than iterating through the elements in ascending order, we might
+ // be able to exit quickly if we go through in pseudo-random order. Use
+ // fact that (i * p) % a is a bijection for i in [0, a) if p is prime and
+ // a < p.
+ int currentIndex =
+ pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements;
+ Value element = elements[currentIndex];
+
+ // From an extract on the same source as the other elements.
+ auto extractOp =
+ dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+ if (!extractOp)
+ return failure();
+ Value currentSource = extractOp.getVector();
+ if (!source) {
+ source = currentSource;
+ shape = extractOp.getSourceVectorType().getShape();
+ } else if (currentSource != source) {
+ return failure();
+ }
+
+ ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ assert(position.size() == shape.size());
+
+ int64_t stride{1};
+ int64_t offset{0};
+ for (auto [pos, size] :
+ llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
+ if (pos == ShapedType::kDynamic)
+ return failure();
+ offset += pos * stride;
+ stride *= size;
+ }
+ if (offset != currentIndex)
+ return failure();
+ }
+
+ // Can replace with a shape_cast.
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
+ fromElementsOp.getType(), source);
+}
+
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..6af517d988360 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
// -----
-// CHECK-LABEL: func @extract_scalar_from_from_elements(
-// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
- // Extract from 0D.
- %0 = vector.from_elements %a : vector<f32>
- %1 = vector.extract %0[] : f32 from vector<f32>
-
- // Extract from 1D.
- %2 = vector.from_elements %a : vector<1xf32>
- %3 = vector.extract %2[0] : f32 from vector<1xf32>
- %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
- %5 = vector.extract %4[4] : f32 from vector<5xf32>
-
- // Extract from 2D.
- %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
- %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
- %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
- %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
- %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
-
- // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
- return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_1d_from_from_elements(
-// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
- %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
- // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
- %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
- // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
- %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
- // CHECK: return %[[splat1]], %[[splat2]]
- return %1, %2 : vector<3xf32>, vector<3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_2d_from_from_elements(
-// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
- %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
- // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
- %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
- // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
- %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
- // CHECK: return %[[splat1]], %[[splat2]]
- return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @from_elements_to_splat(
-// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
- // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
- %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
- // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
- %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
- // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
- %2 = vector.from_elements %a : vector<f32>
- // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
- return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
-}
-
-// -----
-
// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
new file mode 100644
index 0000000000000..21ce71473a3cd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.from_elements
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+ // Extract from 0D.
+ %0 = vector.from_elements %a : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
+
+ // Extract from 1D.
+ %2 = vector.from_elements %a : vector<1xf32>
+ %3 = vector.extract %2[0] : f32 from vector<1xf32>
+ %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+ %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+ // Extract from 2D.
+ %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+ %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+ %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+ %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+ %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+ // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+ return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+ %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+ // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+ %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+ // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+ %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+ // CHECK: return %[[splat1]], %[[splat2]]
+ return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+ %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+ // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+ %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+ // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+ %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+ // CHECK: return %[[splat1]], %[[splat2]]
+ return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+ // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+ %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+ // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+ %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+ // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+ %2 = vector.from_elements %a : vector<f32>
+ // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+ return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
+
+// -----
>From 7f40da6728bb5b197548d3466818869d45ce9720 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 13 May 2025 17:33:41 -0700
Subject: [PATCH 2/6] improvements
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 46 ++++++------
.../canonicalize/vector-from-elements.mlir | 73 +++++++++++++++++++
2 files changed, 98 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 71844e62baba7..e0ce41e5d6245 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2385,45 +2385,49 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}
+
+/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible.
+///
+/// Example:
+/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+/// %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
static LogicalResult
rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
PatternRewriter &rewriter) {
- mlir::OperandRange elements = fromElementsOp.getElements();
- const size_t nbElements = elements.size();
- assert(nbElements > 0 && "must be at least one element");
-
- // https://en.wikipedia.org/wiki/List_of_prime_numbers
- const int prime = 5387;
- bool pseudoRandomOrder = nbElements < prime;
-
+ // The common source of vector.extract operations (if one exists), as well
+ // as its shape and rank. Set in the first iteration of the loop over the
+ // operands of `fromElementsOp`.
Value source;
ArrayRef<int64_t> shape;
- for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) {
+ int64_t rank;
- // Rather than iterating through the elements in ascending order, we might
- // be able to exit quickly if we go through in pseudo-random order. Use
- // fact that (i * p) % a is a bijection for i in [0, a) if p is prime and
- // a < p.
- int currentIndex =
- pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements;
- Value element = elements[currentIndex];
+ for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
- // From an extract on the same source as the other elements.
+ // Check that the element is defined by an extract operation, and that
+ // the extract is on the same vector as all preceding elements.
auto extractOp =
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
if (!extractOp)
return failure();
Value currentSource = extractOp.getVector();
- if (!source) {
+ if (index == 0) {
source = currentSource;
shape = extractOp.getSourceVectorType().getShape();
+ rank = shape.size();
} else if (currentSource != source) {
return failure();
}
+ // Check that the (linearized) index of extraction is the same as the index
+ // in the result of `fromElementsOp`.
ArrayRef<int64_t> position = extractOp.getStaticPosition();
- assert(position.size() == shape.size());
+ if (position.size() != rank)
+ return failure();
int64_t stride{1};
int64_t offset{0};
@@ -2434,11 +2438,10 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
offset += pos * stride;
stride *= size;
}
- if (offset != currentIndex)
+ if (offset != index)
return failure();
}
- // Can replace with a shape_cast.
rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
fromElementsOp.getType(), source);
}
@@ -2446,6 +2449,7 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
+ results.add(rewriteFromElementsAsShapeCast);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index 21ce71473a3cd..fafac4419d719 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -2,6 +2,10 @@
// This file contains some tests of folding/canonicalizing vector.from_elements
+///===----------------------------------------------===//
+/// Tests of `rewriteFromElementsAsSplat`
+///===----------------------------------------------===//
+
// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
@@ -70,3 +74,72 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
}
// -----
+
+
+///===----------------------------------------------===//
+/// Tests of `rewriteFromElementsAsShapeCast`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
+// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>)
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: return %[[shape_cast]] : vector<2xi8>
+func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+ %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+ %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+ %4 = vector.from_elements %0, %1 : vector<2xi8>
+ return %4 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
+// CHECK-SAME: %[[a:.*]]: vector<8xi8>)
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
+// CHECK: return %[[shape_cast]] : vector<2x2x2xi8>
+func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
+ %0 = vector.extract %arg0[0] : i8 from vector<8xi8>
+ %1 = vector.extract %arg0[1] : i8 from vector<8xi8>
+ %2 = vector.extract %arg0[2] : i8 from vector<8xi8>
+ %3 = vector.extract %arg0[3] : i8 from vector<8xi8>
+ %4 = vector.extract %arg0[4] : i8 from vector<8xi8>
+ %5 = vector.extract %arg0[5] : i8 from vector<8xi8>
+ %6 = vector.extract %arg0[6] : i8 from vector<8xi8>
+ %7 = vector.extract %arg0[7] : i8 from vector<8xi8>
+ %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
+ return %8 : vector<2x2x2xi8>
+}
+
+// -----
+
+// The extracted elements are recombined into a single vector, but in a new order.
+// CHECK-LABEL: func @negative_nonascending_order(
+// CHECK-NOT: shape_cast
+func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+ %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+ %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+ %2 = vector.from_elements %0, %1 : vector<2xi8>
+ return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_nonstatic_extract(
+// CHECK-NOT: shape_cast
+func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
+ %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
+ %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
+ %2 = vector.from_elements %0, %1 : vector<2xi8>
+ return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_different_sources(
+// CHECK-NOT: shape_cast
+func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
+ %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+ %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
+ %2 = vector.from_elements %0, %1 : vector<2xi8>
+ return %2 : vector<2xi8>
+}
>From 3ce0713d33fef03141239a00033f5bf5c896fc72 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 13 May 2025 17:41:11 -0700
Subject: [PATCH 3/6] apply some polish
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e0ce41e5d6245..e0b406ce0bc46 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2385,8 +2385,7 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}
-
-/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible.
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
///
/// Example:
/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
@@ -2400,8 +2399,8 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
PatternRewriter &rewriter) {
// The common source of vector.extract operations (if one exists), as well
- // as its shape and rank. Set in the first iteration of the loop over the
- // operands of `fromElementsOp`.
+ // as its shape and rank. These are set in the first iteration of the loop
+ // over the operands (elements) of `fromElementsOp`.
Value source;
ArrayRef<int64_t> shape;
int64_t rank;
@@ -2426,9 +2425,8 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
// Check that the (linearized) index of extraction is the same as the index
// in the result of `fromElementsOp`.
ArrayRef<int64_t> position = extractOp.getStaticPosition();
- if (position.size() != rank)
- return failure();
-
+ assert(position.size() == rank &&
+ "scalar extract must have full rank position");
int64_t stride{1};
int64_t offset{0};
for (auto [pos, size] :
>From 94c5d8c7b3919975f46437c481f9f49e076814f8 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 11:28:40 -0700
Subject: [PATCH 4/6] fix blindspot where source is larger
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 116 +++++++++++-------
.../canonicalize/vector-from-elements.mlir | 11 ++
2 files changed, 84 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e0b406ce0bc46..7b7f014480ccd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
@@ -2394,60 +2395,89 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
///
/// becomes
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
-static LogicalResult
-rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
- PatternRewriter &rewriter) {
+///
+/// The requirements for this to be valid are
+/// i) all elements are extracted from the same vector (source),
+/// ii) source and from_elements result have the same number of elements,
+/// iii) element i is extracted from (linearized) position i of source.
+///
+/// It might be possible to rewrite vector.from_elements as a single
+/// vector.extract if (ii) is not satisfied, or in some cases as a single
+/// vector.extract_strided_slice if (ii) and (iii) are not satisfied; this is
+/// left for future consideration.
+class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
- // The common source of vector.extract operations (if one exists), as well
- // as its shape and rank. These are set in the first iteration of the loop
- // over the operands (elements) of `fromElementsOp`.
- Value source;
- ArrayRef<int64_t> shape;
- int64_t rank;
+ LogicalResult matchAndRewrite(FromElementsOp fromElements,
+ PatternRewriter &rewriter) const override {
- for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
+ mlir::OperandRange elements = fromElements.getElements();
+ assert(!elements.empty() && "must be at least 1 element");
- // Check that the element is defined by an extract operation, and that
- // the extract is on the same vector as all preceding elements.
- auto extractOp =
- dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
- if (!extractOp)
- return failure();
- Value currentSource = extractOp.getVector();
- if (index == 0) {
- source = currentSource;
- shape = extractOp.getSourceVectorType().getShape();
- rank = shape.size();
- } else if (currentSource != source) {
- return failure();
+ Value firstElement = elements.front();
+ ExtractOp extractOp =
+ dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
+ if (!extractOp) {
+ return rewriter.notifyMatchFailure(
+ fromElements, "first element not from vector.extract");
}
+ VectorType sourceType = extractOp.getSourceVectorType();
+ Value source = extractOp.getVector();
- // Check that the (linearized) index of extraction is the same as the index
- // in the result of `fromElementsOp`.
- ArrayRef<int64_t> position = extractOp.getStaticPosition();
- assert(position.size() == rank &&
- "scalar extract must have full rank position");
- int64_t stride{1};
- int64_t offset{0};
- for (auto [pos, size] :
- llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
- if (pos == ShapedType::kDynamic)
- return failure();
- offset += pos * stride;
- stride *= size;
+ // Check condition (ii).
+ if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
+ return rewriter.notifyMatchFailure(fromElements,
+ "number of elements differ");
}
- if (offset != index)
- return failure();
- }
- rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
- fromElementsOp.getType(), source);
-}
+ for (auto [index, element] : llvm::enumerate(elements)) {
+ // Visit every element, including the first, so that the first element's
+ // position is also checked (e.g. reject [x[1], x[1]] from vector<2xT>).
+ extractOp =
+ dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+ if (!extractOp) {
+ return rewriter.notifyMatchFailure(fromElements,
+ "element not from vector.extract");
+ }
+ Value currentSource = extractOp.getVector();
+ // Check condition (i).
+ if (currentSource != source) {
+ return rewriter.notifyMatchFailure(fromElements,
+ "element from different vector");
+ }
+
+ ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
+ "scalar extract must have full rank position");
+ int64_t stride{1};
+ int64_t offset{0};
+ for (auto [pos, size] : llvm::zip(llvm::reverse(position),
+ llvm::reverse(sourceType.getShape()))) {
+ if (pos == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ fromElements, "elements not in ascending order (dynamic order)");
+ }
+ offset += pos * stride;
+ stride *= size;
+ }
+ // Check condition (iii).
+ if (offset != static_cast<int64_t>(index)) {
+ return rewriter.notifyMatchFailure(
+ fromElements, "elements not in ascending order (static order)");
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
+ fromElements.getType(), source);
+ return success();
+ }
+};
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
- results.add(rewriteFromElementsAsShapeCast);
+ results.add<FromElementsToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fafac4419d719..2899abb07c97c 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -143,3 +143,14 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi
%2 = vector.from_elements %0, %1 : vector<2xi8>
return %2 : vector<2xi8>
}
+
+// -----
+
+// CHECK-LABEL: func @negative_source_too_large(
+// CHECK-NOT: shape_cast
+func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
+ %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
+ %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+ %2 = vector.from_elements %0, %1 : vector<2xi8>
+ return %2 : vector<2xi8>
+}
>From 28fccebe260d0ddf9bdb48ebde5efe36d7967516 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 12:40:11 -0700
Subject: [PATCH 5/6] spacing nit
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7b7f014480ccd..1080263ed3eb6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2414,8 +2414,8 @@ class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
mlir::OperandRange elements = fromElements.getElements();
assert(!elements.empty() && "must be at least 1 element");
-
Value firstElement = elements.front();
+
ExtractOp extractOp =
dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
if (!extractOp) {
>From e985a7ef3fcfda4fc435218fba6349611e5deae0 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 12:47:25 -0700
Subject: [PATCH 6/6] fix test grouping title
---
.../Vector/canonicalize/vector-from-elements.mlir | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index 2899abb07c97c..14bf5d9df4783 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-// This file contains some tests of folding/canonicalizing vector.from_elements
+// This file contains some tests of folding/canonicalizing vector.from_elements
///===----------------------------------------------===//
/// Tests of `rewriteFromElementsAsSplat`
@@ -75,9 +75,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// -----
-
///===----------------------------------------------===//
-/// Tests of `rewriteFromElementsAsShapeCast`
+/// Tests of `FromElementsToShapeCast`
///===----------------------------------------------===//
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
@@ -112,7 +111,7 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8>
// -----
-// The extracted elements are recombined into a single vector, but in a new order.
+// The extracted elements are recombined into a single vector, but in a new order.
// CHECK-LABEL: func @negative_nonascending_order(
// CHECK-NOT: shape_cast
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
@@ -122,7 +121,7 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
return %2 : vector<2xi8>
}
-// -----
+// -----
// CHECK-LABEL: func @negative_nonstatic_extract(
// CHECK-NOT: shape_cast
More information about the Mlir-commits
mailing list