[Mlir-commits] [mlir] [mlir][vector] canonicalize vector.from_elements from ascending extracts (PR #139819)

James Newling llvmlistbot at llvm.org
Mon Jun 2 08:08:11 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/139819

>From c252a3dd92ab1145dc80a847729423366a8f1dea Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 13 May 2025 16:06:45 -0700
Subject: [PATCH 01/10] first commit

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 58 +++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 69 ------------------
 .../canonicalize/vector-from-elements.mlir    | 72 +++++++++++++++++++
 3 files changed, 130 insertions(+), 69 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..71844e62baba7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2385,6 +2385,64 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+static LogicalResult
+rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
+                               PatternRewriter &rewriter) {
+
+  mlir::OperandRange elements = fromElementsOp.getElements();
+  const size_t nbElements = elements.size();
+  assert(nbElements > 0 && "must be at least one element");
+
+  // https://en.wikipedia.org/wiki/List_of_prime_numbers
+  const int prime = 5387;
+  bool pseudoRandomOrder = nbElements < prime;
+
+  Value source;
+  ArrayRef<int64_t> shape;
+  for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) {
+
+    // Rather than iterating through the elements in ascending order, we might
+    // be able to exit quickly if we go through in pseudo-random order. Use
+    // fact that (i * p) % a is a bijection for i in [0, a) if p is prime and
+    // a < p.
+    int currentIndex =
+        pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements;
+    Value element = elements[currentIndex];
+
+    // From an extract on the same source as the other elements.
+    auto extractOp =
+        dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+    if (!extractOp)
+      return failure();
+    Value currentSource = extractOp.getVector();
+    if (!source) {
+      source = currentSource;
+      shape = extractOp.getSourceVectorType().getShape();
+    } else if (currentSource != source) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> position = extractOp.getStaticPosition();
+    assert(position.size() == shape.size());
+
+    int64_t stride{1};
+    int64_t offset{0};
+    for (auto [pos, size] :
+         llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
+      if (pos == ShapedType::kDynamic)
+        return failure();
+      offset += pos * stride;
+      stride *= size;
+    }
+    if (offset != currentIndex)
+      return failure();
+  }
+
+  // Can replace with a shape_cast.
+  rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
+                                           fromElementsOp.getType(), source);
+}
+
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..6af517d988360 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
 
 // -----
 
-// CHECK-LABEL: func @extract_scalar_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
-  // Extract from 0D.
-  %0 = vector.from_elements %a : vector<f32>
-  %1 = vector.extract %0[] : f32 from vector<f32>
-
-  // Extract from 1D.
-  %2 = vector.from_elements %a : vector<1xf32>
-  %3 = vector.extract %2[0] : f32 from vector<1xf32>
-  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
-  %5 = vector.extract %4[4] : f32 from vector<5xf32>
-
-  // Extract from 2D.
-  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
-  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
-  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
-
-  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
-  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_1d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
-  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
-  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<3xf32>, vector<3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_2d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
-  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
-  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
-  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @from_elements_to_splat(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
-  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
-  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
-  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
-  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
-  %2 = vector.from_elements %a : vector<f32>
-  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
-  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
-}
-
-// -----
-
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
new file mode 100644
index 0000000000000..21ce71473a3cd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.from_elements 
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+  // Extract from 0D.
+  %0 = vector.from_elements %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Extract from 1D.
+  %2 = vector.from_elements %a : vector<1xf32>
+  %3 = vector.extract %2[0] : f32 from vector<1xf32>
+  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+  %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+  // Extract from 2D.
+  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+  %2 = vector.from_elements %a : vector<f32>
+  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
+
+// -----

>From 7f40da6728bb5b197548d3466818869d45ce9720 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 13 May 2025 17:33:41 -0700
Subject: [PATCH 02/10] improvements

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 46 ++++++------
 .../canonicalize/vector-from-elements.mlir    | 73 +++++++++++++++++++
 2 files changed, 98 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 71844e62baba7..e0ce41e5d6245 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2385,45 +2385,49 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+
+/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible. 
+///
+/// Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
 static LogicalResult
 rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
                                PatternRewriter &rewriter) {
 
-  mlir::OperandRange elements = fromElementsOp.getElements();
-  const size_t nbElements = elements.size();
-  assert(nbElements > 0 && "must be at least one element");
-
-  // https://en.wikipedia.org/wiki/List_of_prime_numbers
-  const int prime = 5387;
-  bool pseudoRandomOrder = nbElements < prime;
-
+  // The common source of vector.extract operations (if one exists), as well
+  // as its shape and rank. Set in the first iteration of the loop over the
+  // operands of `fromElementsOp`.
   Value source;
   ArrayRef<int64_t> shape;
-  for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) {
+  int64_t rank;
 
-    // Rather than iterating through the elements in ascending order, we might
-    // be able to exit quickly if we go through in pseudo-random order. Use
-    // fact that (i * p) % a is a bijection for i in [0, a) if p is prime and
-    // a < p.
-    int currentIndex =
-        pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements;
-    Value element = elements[currentIndex];
+  for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
 
-    // From an extract on the same source as the other elements.
+    // Check that the element is defined by an extract operation, and that
+    // the extract is on the same vector as all preceding elements.
     auto extractOp =
         dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
     if (!extractOp)
       return failure();
     Value currentSource = extractOp.getVector();
-    if (!source) {
+    if (index == 0) {
       source = currentSource;
       shape = extractOp.getSourceVectorType().getShape();
+      rank = shape.size();
     } else if (currentSource != source) {
       return failure();
     }
 
+    // Check that the (linearized) index of extraction is the same as the index
+    // in the result of `fromElementsOp`.
     ArrayRef<int64_t> position = extractOp.getStaticPosition();
-    assert(position.size() == shape.size());
+    if (position.size() != rank)
+      return failure();
 
     int64_t stride{1};
     int64_t offset{0};
@@ -2434,11 +2438,10 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
       offset += pos * stride;
       stride *= size;
     }
-    if (offset != currentIndex)
+    if (offset != index)
       return failure();
   }
 
-  // Can replace with a shape_cast.
   rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
                                            fromElementsOp.getType(), source);
 }
@@ -2446,6 +2449,7 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
+  results.add(rewriteFromElementsAsShapeCast);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index 21ce71473a3cd..fafac4419d719 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -2,6 +2,10 @@
 
 // This file contains some tests of folding/canonicalizing vector.from_elements 
 
+///===----------------------------------------------===//
+///  Tests of `rewriteFromElementsAsSplat`
+///===----------------------------------------------===//
+
 // CHECK-LABEL: func @extract_scalar_from_from_elements(
 //  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
 func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
@@ -70,3 +74,72 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 }
 
 // -----
+
+
+///===----------------------------------------------===//
+///  Tests of `rewriteFromElementsAsShapeCast`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
+//  CHECK-SAME:       %[[a:.*]]: vector<1x2xi8>)
+//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[shape_cast]] : vector<2xi8>
+func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %4 = vector.from_elements %0, %1 : vector<2xi8>
+  return %4 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
+//  CHECK-SAME:       %[[a:.*]]: vector<8xi8>)
+//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
+//       CHECK:       return %[[shape_cast]] : vector<2x2x2xi8>
+func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
+  %0 = vector.extract %arg0[0] : i8 from vector<8xi8>
+  %1 = vector.extract %arg0[1] : i8 from vector<8xi8>
+  %2 = vector.extract %arg0[2] : i8 from vector<8xi8>
+  %3 = vector.extract %arg0[3] : i8 from vector<8xi8>
+  %4 = vector.extract %arg0[4] : i8 from vector<8xi8>
+  %5 = vector.extract %arg0[5] : i8 from vector<8xi8>
+  %6 = vector.extract %arg0[6] : i8 from vector<8xi8>
+  %7 = vector.extract %arg0[7] : i8 from vector<8xi8>
+  %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
+  return %8 : vector<2x2x2xi8>
+}
+
+// -----
+
+// The extracted elements are recombined into a single vector, but in a new order. 
+// CHECK-LABEL: func @negative_nonascending_order(
+//   CHECK-NOT: shape_cast
+func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// ----- 
+
+// CHECK-LABEL: func @negative_nonstatic_extract(
+//   CHECK-NOT: shape_cast
+func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_different_sources(
+//   CHECK-NOT: shape_cast
+func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}

>From 3ce0713d33fef03141239a00033f5bf5c896fc72 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 13 May 2025 17:41:11 -0700
Subject: [PATCH 03/10] apply some polish

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e0ce41e5d6245..e0b406ce0bc46 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2385,8 +2385,7 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
-
-/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible. 
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
 ///
 /// Example:
 ///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
@@ -2400,8 +2399,8 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
                                PatternRewriter &rewriter) {
 
   // The common source of vector.extract operations (if one exists), as well
-  // as its shape and rank. Set in the first iteration of the loop over the
-  // operands of `fromElementsOp`.
+  // as its shape and rank. These are set in the first iteration of the loop
+  // over the operands (elements) of `fromElementsOp`.
   Value source;
   ArrayRef<int64_t> shape;
   int64_t rank;
@@ -2426,9 +2425,8 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
     // Check that the (linearized) index of extraction is the same as the index
     // in the result of `fromElementsOp`.
     ArrayRef<int64_t> position = extractOp.getStaticPosition();
-    if (position.size() != rank)
-      return failure();
-
+    assert(position.size() == rank &&
+           "scalar extract must have full rank position");
     int64_t stride{1};
     int64_t offset{0};
     for (auto [pos, size] :

>From 94c5d8c7b3919975f46437c481f9f49e076814f8 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 11:28:40 -0700
Subject: [PATCH 04/10] fix blindspot where source is larger

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 116 +++++++++++-------
 .../canonicalize/vector-from-elements.mlir    |  11 ++
 2 files changed, 84 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e0b406ce0bc46..7b7f014480ccd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
@@ -2394,60 +2395,89 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
 ///
 /// becomes
 ///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
-static LogicalResult
-rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
-                               PatternRewriter &rewriter) {
+///
+/// The requirements for this to be valid are
+/// i) all elements are extracted from the same vector (source),
+/// ii) source and from_elements result have the same number of elements,
+/// iii) the elements are extracted in ascending order.
+///
+/// It might be possible to rewrite vector.from_elements as a single
+/// vector.extract if (ii) is not satisifed, or in some cases as a
+/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied,
+/// this is left for future consideration.
+class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
 
-  // The common source of vector.extract operations (if one exists), as well
-  // as its shape and rank. These are set in the first iteration of the loop
-  // over the operands (elements) of `fromElementsOp`.
-  Value source;
-  ArrayRef<int64_t> shape;
-  int64_t rank;
+  LogicalResult matchAndRewrite(FromElementsOp fromElements,
+                                PatternRewriter &rewriter) const override {
 
-  for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
+    mlir::OperandRange elements = fromElements.getElements();
+    assert(!elements.empty() && "must be at least 1 element");
 
-    // Check that the element is defined by an extract operation, and that
-    // the extract is on the same vector as all preceding elements.
-    auto extractOp =
-        dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
-    if (!extractOp)
-      return failure();
-    Value currentSource = extractOp.getVector();
-    if (index == 0) {
-      source = currentSource;
-      shape = extractOp.getSourceVectorType().getShape();
-      rank = shape.size();
-    } else if (currentSource != source) {
-      return failure();
+    Value firstElement = elements.front();
+    ExtractOp extractOp =
+        dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
+    if (!extractOp) {
+      return rewriter.notifyMatchFailure(
+          fromElements, "first element not from vector.extract");
     }
+    VectorType sourceType = extractOp.getSourceVectorType();
+    Value source = extractOp.getVector();
 
-    // Check that the (linearized) index of extraction is the same as the index
-    // in the result of `fromElementsOp`.
-    ArrayRef<int64_t> position = extractOp.getStaticPosition();
-    assert(position.size() == rank &&
-           "scalar extract must have full rank position");
-    int64_t stride{1};
-    int64_t offset{0};
-    for (auto [pos, size] :
-         llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
-      if (pos == ShapedType::kDynamic)
-        return failure();
-      offset += pos * stride;
-      stride *= size;
+    // Check condition (ii).
+    if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
+      return rewriter.notifyMatchFailure(fromElements,
+                                         "number of elements differ");
     }
-    if (offset != index)
-      return failure();
-  }
 
-  rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
-                                           fromElementsOp.getType(), source);
-}
+    for (auto [indexMinusOne, element] :
+         llvm::enumerate(elements.drop_front(1))) {
+
+      extractOp =
+          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element not from vector.extract");
+      }
+      Value currentSource = extractOp.getVector();
+      // Check condition (i).
+      if (currentSource != source) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element from different vector");
+      }
+
+      ArrayRef<int64_t> position = extractOp.getStaticPosition();
+      assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
+             "scalar extract must have full rank position");
+      int64_t stride{1};
+      int64_t offset{0};
+      for (auto [pos, size] : llvm::zip(llvm::reverse(position),
+                                        llvm::reverse(sourceType.getShape()))) {
+        if (pos == ShapedType::kDynamic) {
+          return rewriter.notifyMatchFailure(
+              fromElements, "elements not in ascending order (dynamic order)");
+        }
+        offset += pos * stride;
+        stride *= size;
+      }
+      // Check condition (iii).
+      if (offset != static_cast<int64_t>(indexMinusOne + 1)) {
+        return rewriter.notifyMatchFailure(
+            fromElements, "elements not in ascending order (static order)");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
+                                             fromElements.getType(), source);
+    return success();
+  }
+};
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
-  results.add(rewriteFromElementsAsShapeCast);
+  results.add<FromElementsToShapCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fafac4419d719..2899abb07c97c 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -143,3 +143,14 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi
   %2 = vector.from_elements %0, %1 : vector<2xi8>
   return %2 : vector<2xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @negative_source_too_large(
+//   CHECK-NOT: shape_cast
+func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}

>From 28fccebe260d0ddf9bdb48ebde5efe36d7967516 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 12:40:11 -0700
Subject: [PATCH 05/10] spacing nit

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7b7f014480ccd..1080263ed3eb6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2414,8 +2414,8 @@ class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
 
     mlir::OperandRange elements = fromElements.getElements();
     assert(!elements.empty() && "must be at least 1 element");
-
     Value firstElement = elements.front();
+
     ExtractOp extractOp =
         dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
     if (!extractOp) {

>From e985a7ef3fcfda4fc435218fba6349611e5deae0 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 12:47:25 -0700
Subject: [PATCH 06/10] fix test grouping title

---
 .../Vector/canonicalize/vector-from-elements.mlir        | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index 2899abb07c97c..14bf5d9df4783 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// This file contains some tests of folding/canonicalizing vector.from_elements 
+// This file contains some tests of folding/canonicalizing vector.from_elements
 
 ///===----------------------------------------------===//
 ///  Tests of `rewriteFromElementsAsSplat`
@@ -75,9 +75,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // -----
 
-
 ///===----------------------------------------------===//
-///  Tests of `rewriteFromElementsAsShapeCast`
+///  Tests of `FromElementsToShapeCast`
 ///===----------------------------------------------===//
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
@@ -112,7 +111,7 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8>
 
 // -----
 
-// The extracted elements are recombined into a single vector, but in a new order. 
+// The extracted elements are recombined into a single vector, but in a new order.
 // CHECK-LABEL: func @negative_nonascending_order(
 //   CHECK-NOT: shape_cast
 func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
@@ -122,7 +121,7 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   return %2 : vector<2xi8>
 }
 
-// ----- 
+// -----
 
 // CHECK-LABEL: func @negative_nonstatic_extract(
 //   CHECK-NOT: shape_cast

>From 74d74f9596c86754675e35509f2aaa2106d5be80 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 16 May 2025 09:12:59 -0700
Subject: [PATCH 07/10] initial pass of review comments. Add additional test

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 81 ++++++++++---------
 .../canonicalize/vector-from-elements.mlir    | 56 ++++++++-----
 2 files changed, 79 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1080263ed3eb6..e671dddad3a8f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2397,13 +2397,13 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
 ///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
 ///
 /// The requirements for this to be valid are
-/// i) all elements are extracted from the same vector (source),
-/// ii) source and from_elements result have the same number of elements,
+/// i) source and from_elements result have the same number of elements,
+/// ii) all elements are extracted from the same vector (%source),
 /// iii) the elements are extracted in ascending order.
 ///
 /// It might be possible to rewrite vector.from_elements as a single
-/// vector.extract if (ii) is not satisifed, or in some cases as a
-/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied,
+/// vector.extract if (i) is not satisifed, or in some cases as a
+/// a single vector_extract_strided_slice if (i) and (iii) are not satisfied,
 /// this is left for future consideration.
 class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
 public:
@@ -2412,64 +2412,71 @@ class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
   LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                 PatternRewriter &rewriter) const override {
 
-    mlir::OperandRange elements = fromElements.getElements();
-    assert(!elements.empty() && "must be at least 1 element");
-    Value firstElement = elements.front();
+    // The source of the first element. This is initialized in the first
+    // iteration of the loop over elements.
+    TypedValue<VectorType> firstElementSource;
 
-    ExtractOp extractOp =
-        dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
-    if (!extractOp) {
-      return rewriter.notifyMatchFailure(
-          fromElements, "first element not from vector.extract");
-    }
-    VectorType sourceType = extractOp.getSourceVectorType();
-    Value source = extractOp.getVector();
-
-    // Check condition (ii).
-    if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
-      return rewriter.notifyMatchFailure(fromElements,
-                                         "number of elements differ");
-    }
+    for (auto [insertIndex, element] :
+         llvm::enumerate(fromElements.getElements())) {
 
-    for (auto [indexMinusOne, element] :
-         llvm::enumerate(elements.drop_front(1))) {
-
-      extractOp =
+      // Check that the element is from a vector.extract operation.
+      auto extractOp =
           dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
       if (!extractOp) {
         return rewriter.notifyMatchFailure(fromElements,
                                            "element not from vector.extract");
       }
+
+      // Check condition (i) on the first element. As we will check that all
+      // elements have the same source, we don't need to check condition (i) for
+      // any other elements.
+      if (insertIndex == 0) {
+        firstElementSource = extractOp.getVector();
+        if (static_cast<size_t>(
+                firstElementSource.getType().getNumElements()) !=
+            fromElements.getType().getNumElements()) {
+          return rewriter.notifyMatchFailure(fromElements,
+                                             "number of elements differ");
+        }
+      }
+
+      // Check condition (ii), by checking that all elements have same source as
+      // the first element.
       Value currentSource = extractOp.getVector();
-      // Check condition (i).
-      if (currentSource != source) {
+      if (currentSource != firstElementSource) {
         return rewriter.notifyMatchFailure(fromElements,
                                            "element from different vector");
       }
 
+      // Check condition (iii).
+      // First, get the index that the element is extracted from.
+      int64_t extractIndex{0};
+      int64_t stride{1};
       ArrayRef<int64_t> position = extractOp.getStaticPosition();
-      assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
+      assert(position.size() ==
+                 static_cast<size_t>(firstElementSource.getType().getRank()) &&
              "scalar extract must have full rank position");
-      int64_t stride{1};
-      int64_t offset{0};
-      for (auto [pos, size] : llvm::zip(llvm::reverse(position),
-                                        llvm::reverse(sourceType.getShape()))) {
+      for (auto [pos, size] :
+           llvm::zip(llvm::reverse(position),
+                     llvm::reverse(firstElementSource.getType().getShape()))) {
         if (pos == ShapedType::kDynamic) {
           return rewriter.notifyMatchFailure(
               fromElements, "elements not in ascending order (dynamic order)");
         }
-        offset += pos * stride;
+        extractIndex += pos * stride;
         stride *= size;
       }
-      // Check condition (iii).
-      if (offset != static_cast<int64_t>(indexMinusOne + 1)) {
+
+      // Second, check that the index of extraction from source and insertion in
+      // from_elements are the same.
+      if (extractIndex != static_cast<int64_t>(insertIndex)) {
         return rewriter.notifyMatchFailure(
             fromElements, "elements not in ascending order (static order)");
       }
     }
 
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
-                                             fromElements.getType(), source);
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        fromElements, fromElements.getType(), firstElementSource);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index 14bf5d9df4783..49641eced607f 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -7,7 +7,7 @@
 ///===----------------------------------------------===//
 
 // CHECK-LABEL: func @extract_scalar_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
   // Extract from 0D.
   %0 = vector.from_elements %a : vector<f32>
@@ -26,50 +26,50 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
   %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
   %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
 
-  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+  // CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]]
   return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
 }
 
 // -----
 
 // CHECK-LABEL: func @extract_1d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
   %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+  // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
   %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+  // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32>
   %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
+  // CHECK: return %[[SPLAT1]], %[[SPLAT2]]
   return %1, %2 : vector<3xf32>, vector<3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @extract_2d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
   %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
-  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+  // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
   %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+  // CHECK: %[[SPLAT2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32>
   %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
+  // CHECK: return %[[SPLAT1]], %[[SPLAT2]]
   return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @from_elements_to_splat(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
-  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+  // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
   %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
-  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
   %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+  // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32>
   %2 = vector.from_elements %a : vector<f32>
-  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+  // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
 }
 
@@ -80,9 +80,9 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 ///===----------------------------------------------===//
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
-//  CHECK-SAME:       %[[a:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
-//       CHECK:       return %[[shape_cast]] : vector<2xi8>
+//  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
+//       CHECK:       %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SHAPE_CAST]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
@@ -93,9 +93,9 @@ func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
 // -----
 
 // CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
-//  CHECK-SAME:       %[[a:.*]]: vector<8xi8>)
-//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
-//       CHECK:       return %[[shape_cast]] : vector<2x2x2xi8>
+//  CHECK-SAME:       %[[A:.*]]: vector<8xi8>)
+//       CHECK:       %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8>
+//       CHECK:       return %[[SHAPE_CAST]] : vector<2x2x2xi8>
 func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
   %0 = vector.extract %arg0[0] : i8 from vector<8xi8>
   %1 = vector.extract %arg0[1] : i8 from vector<8xi8>
@@ -153,3 +153,17 @@ func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
   %2 = vector.from_elements %0, %1 : vector<2xi8>
   return %2 : vector<2xi8>
 }
+
+// -----
+
+// The inserted elements are a subset of the extracted elements.
+// [0, 1, 2] -> [1, 1, 2]
+// CHECK-LABEL: func @negative_nobijection_order(
+//   CHECK-NOT: shape_cast
+func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+  %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
+  %2 = vector.from_elements %0, %0, %1 : vector<3xi8>
+  return %2 : vector<3xi8>
+}
+

>From c584ac8b13076132af105f592a91875b5cb3fc6d Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 19 May 2025 09:24:40 -0700
Subject: [PATCH 08/10] allow source larger than result (extract + shape_cast); additional tests

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 108 ++++++++++++------
 .../canonicalize/vector-from-elements.mlir    |  36 +++++-
 2 files changed, 104 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e671dddad3a8f..ddc01dfd9f1fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2379,6 +2379,7 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
 /// ==> rewrite to vector.splat %a : vector<3xf32>
 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                                                 PatternRewriter &rewriter) {
+
   if (!llvm::all_equal(fromElementsOp.getElements()))
     return failure();
   rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
@@ -2386,35 +2387,44 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
-/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+/// Rewrite vector.from_elements(vector.extract, vector.extract, ...) as 
+///         vector.shape_cast(vector.extact) if possible. 
 ///
 /// Example:
-///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
-///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
 ///   %2 = vector.from_elements %0, %1 : vector<2xi8>
 ///
 /// becomes
-///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+///   %1 = vector.extract %source[0] : vector<2xi8> from vector<2x2xi8>
+///   %2 = vector.shape_cast %1 : vector<2xi8> to vector<2xi8>
 ///
 /// The requirements for this to be valid are
-/// i) source and from_elements result have the same number of elements,
-/// ii) all elements are extracted from the same vector (%source),
-/// iii) the elements are extracted in ascending order.
 ///
-/// It might be possible to rewrite vector.from_elements as a single
-/// vector.extract if (i) is not satisifed, or in some cases as a
-/// a single vector_extract_strided_slice if (i) and (iii) are not satisfied,
-/// this is left for future consideration.
-class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
+/// i) all elements are extracted from the same vector (%source)
+/// ii) the elements form a suffix of %source
+/// iii) the elements are extracted contiguously in ascending order
+
+class FromElementsToShapeCast
+    : public OpRewritePattern<FromElementsOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                 PatternRewriter &rewriter) const override {
 
-    // The source of the first element. This is initialized in the first
-    // iteration of the loop over elements.
+    // Left for `rewriteFromElementsAsSplat` to avoid divergent
+    // canonicalizations
+    if (fromElements.getType().getNumElements() == 1) {
+      return failure();
+    }
+
+    // The source of the first element, the position (N-d vector) that the first
+    // element is extracted from, and the flattened position (index). These are
+    // all obtained in the first iteration of the loop over elements.
     TypedValue<VectorType> firstElementSource;
+    ArrayRef<int64_t> firstElementExtractPosition;
+    int64_t firstElementExtractIndex;
 
     for (auto [insertIndex, element] :
          llvm::enumerate(fromElements.getElements())) {
@@ -2427,32 +2437,19 @@ class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
                                            "element not from vector.extract");
       }
 
-      // Check condition (i) on the first element. As we will check that all
-      // elements have the same source, we don't need to check condition (i) for
-      // any other elements.
+      // Check condition (i) by checking that all elements have same source as
+      // the first element.
       if (insertIndex == 0) {
         firstElementSource = extractOp.getVector();
-        if (static_cast<size_t>(
-                firstElementSource.getType().getNumElements()) !=
-            fromElements.getType().getNumElements()) {
-          return rewriter.notifyMatchFailure(fromElements,
-                                             "number of elements differ");
-        }
-      }
-
-      // Check condition (ii), by checking that all elements have same source as
-      // the first element.
-      Value currentSource = extractOp.getVector();
-      if (currentSource != firstElementSource) {
+      } else if (extractOp.getVector() != firstElementSource) {
         return rewriter.notifyMatchFailure(fromElements,
                                            "element from different vector");
       }
 
-      // Check condition (iii).
-      // First, get the index that the element is extracted from.
+      // Obtain the flattened index of extraction from the N-d position.
+      ArrayRef<int64_t> position = extractOp.getStaticPosition();
       int64_t extractIndex{0};
       int64_t stride{1};
-      ArrayRef<int64_t> position = extractOp.getStaticPosition();
       assert(position.size() ==
                  static_cast<size_t>(firstElementSource.getType().getRank()) &&
              "scalar extract must have full rank position");
@@ -2467,16 +2464,51 @@ class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
         stride *= size;
       }
 
-      // Second, check that the index of extraction from source and insertion in
-      // from_elements are the same.
-      if (extractIndex != static_cast<int64_t>(insertIndex)) {
+      // Check condition (ii) using the extraction index of the first element.
+      // We check that the position that the first element is extracted
+      // from has sufficient trailing 0s. For example, in
+      // ```
+      // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
+      // [...]
+      // %n = vector.from_elements %elm0, [...] : vector<12xi8>
+      // ```
+      // The 2 trailing 0s in the position of extraction of %0 cover 3*4 = 12
+      // elements, which is the number of elements of %n, so this is valid.
+      if (insertIndex == 0) {
+        const int64_t numFinalElements =
+            fromElements.getType().getNumElements();
+        int64_t numElementsInSourceSuffix = 1;
+        int index = position.size();
+        while (index > 0 && position[index - 1] == 0 &&
+               numElementsInSourceSuffix < numFinalElements) {
+          numElementsInSourceSuffix *=
+              firstElementSource.getType().getDimSize(index - 1);
+          --index;
+        }
+        if (numElementsInSourceSuffix != numFinalElements) {
+          return rewriter.notifyMatchFailure(
+              fromElements, "elements do not form a suffix of source");
+        }
+        firstElementExtractIndex = extractIndex;
+        firstElementExtractPosition =
+            position.drop_back(position.size() - index);
+      }
+
+      // Check condition (iii) by checking the index of extraction relative
+      // the first element.
+      else if (static_cast<int64_t>(insertIndex) + firstElementExtractIndex !=
+               extractIndex) {
         return rewriter.notifyMatchFailure(
             fromElements, "elements not in ascending order (static order)");
       }
     }
 
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        fromElements, fromElements.getType(), firstElementSource);
+    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
+        fromElements.getLoc(), firstElementSource, firstElementExtractPosition);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        fromElements, fromElements.getType(), extracted);
+
     return success();
   }
 };
@@ -2484,7 +2516,7 @@ class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
-  results.add<FromElementsToShapCast>(context);
+  results.add<FromElementsToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index 49641eced607f..ef7bfacdf1a5e 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt --mlir-disable-threading %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // This file contains some tests of folding/canonicalizing vector.from_elements
 
@@ -109,6 +109,39 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8>
   return %8 : vector<2x2x2xi8>
 }
 
+
+// -----
+
+//   func.func @bar(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
+//     %0 = vector.extract %arg0[1] : vector<3x4xi8> from vector<2x3x4xi8>
+//     %1 = vector.shape_cast %0 : vector<3x4xi8> to vector<12xi8>
+//     return %1 : vector<12xi8>
+
+// CHECK-LABEL: func @source_larger_than_out(
+//  CHECK-SAME:     %[[A:.*]]: vector<2x3x4xi8>)
+//       CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]] [1] : vector<3x4xi8> from vector<2x3x4xi8>
+//       CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
+//       CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
+
+func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
+  %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8>
+  %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8>
+  %2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8>
+  %3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8>
+  %4 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8>
+  %5 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8>
+  %6 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8>
+  %7 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8>
+  %8 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8>
+  %9 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8>
+  %10 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8>
+  %11 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8>
+  %12 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 : vector<12xi8>
+  return %12 : vector<12xi8>
+}
+
+// TODO(newling): add more tests where the source has a different size than the result.
+
 // -----
 
 // The extracted elements are recombined into a single vector, but in a new order.
@@ -166,4 +199,3 @@ func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
   %2 = vector.from_elements %0, %0, %1 : vector<3xi8>
   return %2 : vector<3xi8>
 }
-

>From e154b2248a89d34a6936476761be07e606a226b0 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 11:29:12 -0700
Subject: [PATCH 09/10] generalize to extract cast

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 117 ++++++++---------
 .../canonicalize/vector-from-elements.mlir    | 121 ++++++++++++++----
 2 files changed, 149 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ddc01dfd9f1fa..311c2b6387433 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2387,10 +2387,8 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
-/// Rewrite vector.from_elements(vector.extract, vector.extract, ...) as 
-///         vector.shape_cast(vector.extact) if possible. 
-///
-/// Example:
+/// Rewrite from_elements on multiple scalar extracts as a shape_cast
+/// on a single extract. Example:
 ///   %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
 ///   %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
 ///   %2 = vector.from_elements %0, %1 : vector<2xi8>
@@ -2401,30 +2399,32 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
 ///
 /// The requirements for this to be valid are
 ///
-/// i) all elements are extracted from the same vector (%source)
-/// ii) the elements form a suffix of %source
-/// iii) the elements are extracted contiguously in ascending order
+///   i) The elements are extracted from the same vector (%source).
+///
+///  ii) The elements form a suffix of %source. Specifically, the number
+///      of elements is the same as the product of the last N dimension sizes
+///      of %source, for some N.
+///
+/// iii) The elements are extracted contiguously in ascending order.
+
+class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
 
-class FromElementsToShapeCast
-    : public OpRewritePattern<FromElementsOp> {
-public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                 PatternRewriter &rewriter) const override {
 
-    // Left for `rewriteFromElementsAsSplat` to avoid divergent
-    // canonicalizations
-    if (fromElements.getType().getNumElements() == 1) {
+    // Handled by `rewriteFromElementsAsSplat`
+    if (fromElements.getType().getNumElements() == 1)
       return failure();
-    }
 
-    // The source of the first element, the position (N-d vector) that the first
-    // element is extracted from, and the flattened position (index). These are
-    // all obtained in the first iteration of the loop over elements.
-    TypedValue<VectorType> firstElementSource;
-    ArrayRef<int64_t> firstElementExtractPosition;
-    int64_t firstElementExtractIndex;
+    // The common source that all elements are extracted from, if one exists.
+    TypedValue<VectorType> source;
+    // The position of the combined extract operation, if one is created.
+    ArrayRef<int64_t> combinedPosition;
+    // The expected N-d extraction position of the current element in the loop,
+    // if elements are extracted contiguously in ascending order.
+    SmallVector<int64_t> expectedPosition;
 
     for (auto [insertIndex, element] :
          llvm::enumerate(fromElements.getElements())) {
@@ -2440,77 +2440,70 @@ class FromElementsToShapeCast
       // Check condition (i) by checking that all elements have same source as
       // the first element.
       if (insertIndex == 0) {
-        firstElementSource = extractOp.getVector();
-      } else if (extractOp.getVector() != firstElementSource) {
+        source = extractOp.getVector();
+      } else if (extractOp.getVector() != source) {
         return rewriter.notifyMatchFailure(fromElements,
                                            "element from different vector");
       }
 
-      // Obtain the flattened index of extraction from the N-d position.
       ArrayRef<int64_t> position = extractOp.getStaticPosition();
-      int64_t extractIndex{0};
-      int64_t stride{1};
-      assert(position.size() ==
-                 static_cast<size_t>(firstElementSource.getType().getRank()) &&
+      int64_t rank = position.size();
+      assert(rank == source.getType().getRank() &&
              "scalar extract must have full rank position");
-      for (auto [pos, size] :
-           llvm::zip(llvm::reverse(position),
-                     llvm::reverse(firstElementSource.getType().getShape()))) {
-        if (pos == ShapedType::kDynamic) {
-          return rewriter.notifyMatchFailure(
-              fromElements, "elements not in ascending order (dynamic order)");
-        }
-        extractIndex += pos * stride;
-        stride *= size;
-      }
 
-      // Check condition (ii) using the extraction index of the first element.
-      // We check that the position that the first element is extracted
-      // from has sufficient trailing 0s. For example, in
-      // ```
-      // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
-      // [...]
-      // %n = vector.from_elements %elm0, [...] : vector<12xi8>
-      // ```
-      // The 2 trailing 0s in the position of extraction of %0 cover 3*4 = 12
+      // Check condition (ii) by checking that the position that the first
+      // element is extracted from has sufficient trailing 0s. For example, in
+      //
+      //   %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
+      //   [...]
+      //   %elms = vector.from_elements %elm0, [...] : vector<12xi8>
+      //
+      // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
       // elements, which is the number of elements of %n, so this is valid.
       if (insertIndex == 0) {
-        const int64_t numFinalElements =
-            fromElements.getType().getNumElements();
-        int64_t numElementsInSourceSuffix = 1;
-        int index = position.size();
+        const int64_t numElms = fromElements.getType().getNumElements();
+        int64_t numSuffixElms = 1;
+        int64_t index = rank;
         while (index > 0 && position[index - 1] == 0 &&
-               numElementsInSourceSuffix < numFinalElements) {
-          numElementsInSourceSuffix *=
-              firstElementSource.getType().getDimSize(index - 1);
+               numSuffixElms < numElms) {
+          numSuffixElms *= source.getType().getDimSize(index - 1);
           --index;
         }
-        if (numElementsInSourceSuffix != numFinalElements) {
+        if (numSuffixElms != numElms) {
           return rewriter.notifyMatchFailure(
               fromElements, "elements do not form a suffix of source");
         }
-        firstElementExtractIndex = extractIndex;
-        firstElementExtractPosition =
-            position.drop_back(position.size() - index);
+        expectedPosition = llvm::to_vector(position);
+        combinedPosition = position.drop_back(rank - index);
       }
 
-      // Check condition (iii) by checking the index of extraction relative
-      // the first element.
-      else if (static_cast<int64_t>(insertIndex) + firstElementExtractIndex !=
-               extractIndex) {
+      // Check condition (iii).
+      else if (expectedPosition != position) {
         return rewriter.notifyMatchFailure(
             fromElements, "elements not in ascending order (static order)");
       }
+      increment(expectedPosition, source.getType().getShape());
     }
 
     auto extracted = rewriter.createOrFold<vector::ExtractOp>(
-        fromElements.getLoc(), firstElementSource, firstElementExtractPosition);
+        fromElements.getLoc(), source, combinedPosition);
 
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         fromElements, fromElements.getType(), extracted);
 
     return success();
   }
+
+  /// Increments n-D `indices` by 1 starting from the innermost dimension.
+  static void increment(MutableArrayRef<int64_t> indices,
+                        ArrayRef<int64_t> shape) {
+    for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+      indices[dim] += 1;
+      if (indices[dim] < shape[dim])
+        break;
+      indices[dim] = 0;
+    }
+  }
 };
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index ef7bfacdf1a5e..fdab2a8918a2e 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --mlir-disable-threading %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // This file contains some tests of folding/canonicalizing vector.from_elements
 
@@ -7,7 +7,7 @@
 ///===----------------------------------------------===//
 
 // CHECK-LABEL: func @extract_scalar_from_from_elements(
-//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
   // Extract from 0D.
   %0 = vector.from_elements %a : vector<f32>
@@ -33,7 +33,7 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
 // -----
 
 // CHECK-LABEL: func @extract_1d_from_from_elements(
-//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
   %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
   // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
@@ -47,7 +47,7 @@ func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, ve
 // -----
 
 // CHECK-LABEL: func @extract_2d_from_from_elements(
-//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
   %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
   // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
@@ -61,7 +61,7 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>,
 // -----
 
 // CHECK-LABEL: func @from_elements_to_splat(
-//  CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32)
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
 func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
   // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
   %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
-//       CHECK:       return %[[SHAPE_CAST]] : vector<2xi8>
+//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
@@ -109,20 +109,13 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8>
   return %8 : vector<2x2x2xi8>
 }
 
-
 // -----
 
-//   func.func @bar(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
-//     %0 = vector.extract %arg0[1] : vector<3x4xi8> from vector<2x3x4xi8>
-//     %1 = vector.shape_cast %0 : vector<3x4xi8> to vector<12xi8>
-//     return %1 : vector<12xi8>
-
 // CHECK-LABEL: func @source_larger_than_out(
-//  CHECK-SAME:     %[[A:.*]]: vector<2x3x4xi8>)
-//       CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]] [1] : vector<3x4xi8> from vector<2x3x4xi8>
-//       CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
-//       CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
-
+//  CHECK-SAME:      %[[A:.*]]: vector<2x3x4xi8>)
+//       CHECK:      %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8>
+//       CHECK:      %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
+//       CHECK:      return %[[SHAPE_CAST]] : vector<12xi8>
 func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
   %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8>
   %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8>
@@ -140,13 +133,70 @@ func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
   return %12 : vector<12xi8>
 }
 
-// TODO(newling) add more tests where the source is not the same size as out. 
+// -----
+
+// This test is similar to `source_larger_than_out`, except that here the number of
+// elements extracted contiguously starting from the first position [0,0] could be 6
+// instead of 3 and the pattern would still match.
+// CHECK-LABEL: func @suffix_with_excess_zeros(
+//       CHECK:      %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8>
+//       CHECK:      return %[[EXT]] : vector<3xi8>
+func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
+  %2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
+  %3 = vector.from_elements %0, %1, %2 : vector<3xi8>
+  return %3 : vector<3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @large_source_with_shape_cast_required(
+//  CHECK-SAME:      %[[A:.*]]: vector<2x2x2x2xi8>)
+//       CHECK:      %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8>
+//       CHECK:      %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8>
+//       CHECK:      return %[[SHAPE_CAST]] : vector<1x4x1xi8>
+func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> {
+  %0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8>
+  %1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8>
+  %2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8>
+  %3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8>
+  %4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8>
+  return %4 : vector<1x4x1xi8>
+}
+
+// -----
+
+// Could also match this pattern, but is handled by `rewriteFromElementsAsSplat`.
+// CHECK-LABEL: func @extract_single_elm(
+//  CHECK-NEXT:      vector.extract
+//  CHECK-NEXT:      vector.splat
+//  CHECK-NEXT:      return
+func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
+  %1 = vector.from_elements %0 : vector<1xi8>
+  return %1 : vector<1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
+  %1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
+  %2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8>
+  %3 = vector.from_elements %0, %1, %2 : vector<3xi8>
+  return %3 : vector<3xi8>
+}
 
 // -----
 
 // The extracted elements are recombined into a single vector, but in a new order.
 // CHECK-LABEL: func @negative_nonascending_order(
-//   CHECK-NOT: shape_cast
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
 func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
@@ -157,7 +207,8 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
 // -----
 
 // CHECK-LABEL: func @negative_nonstatic_extract(
-//   CHECK-NOT: shape_cast
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
 func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
@@ -168,7 +219,8 @@ func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 :
 // -----
 
 // CHECK-LABEL: func @negative_different_sources(
-//   CHECK-NOT: shape_cast
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
 func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
@@ -178,9 +230,10 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi
 
 // -----
 
-// CHECK-LABEL: func @negative_source_too_large(
-//   CHECK-NOT: shape_cast
-func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
+// CHECK-LABEL: func @negative_source_not_suffix(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
   %2 = vector.from_elements %0, %1 : vector<2xi8>
@@ -189,13 +242,27 @@ func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
 
 // -----
 
-// The inserted elements are are a subset of the extracted elements.
+// The inserted elements are a subset of the extracted elements.
 // [0, 1, 2] -> [1, 1, 2]
 // CHECK-LABEL: func @negative_nobijection_order(
-//   CHECK-NOT: shape_cast
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
 func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
   %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
   %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
   %2 = vector.from_elements %0, %0, %1 : vector<3xi8>
   return %2 : vector<3xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @negative_source_too_small(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> {
+  %0 = vector.extract %arg0[0] : i8 from vector<2xi8>
+  %1 = vector.extract %arg0[1] : i8 from vector<2xi8>
+  %2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8>
+  return %2 : vector<4xi8>
+}
+

>From dd3e231eefc4a9c6e9bf000af149537d87fa8b4d Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 29 May 2025 09:57:43 -0700
Subject: [PATCH 10/10] uber refinement (empty line removal and definite
 article use..)

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 311c2b6387433..9b6e17cd5abbc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2379,7 +2379,6 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
 /// ==> rewrite to vector.splat %a : vector<3xf32>
 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                                                 PatternRewriter &rewriter) {
-
   if (!llvm::all_equal(fromElementsOp.getElements()))
     return failure();
   rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
@@ -2437,8 +2436,8 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
                                            "element not from vector.extract");
       }
 
-      // Check condition (i) by checking that all elements have same source as
-      // the first element.
+      // Check condition (i) by checking that all elements have the same source
+      // as the first element.
       if (insertIndex == 0) {
         source = extractOp.getVector();
       } else if (extractOp.getVector() != source) {



More information about the Mlir-commits mailing list