[Mlir-commits] [mlir] 543446a - [mlir][vector] canonicalize vector.from_elements from ascending extracts (#139819)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 2 11:15:28 PDT 2025


Author: James Newling
Date: 2025-06-02T11:15:25-07:00
New Revision: 543446a35309d82919dc1b021125e978a7451b12

URL: https://github.com/llvm/llvm-project/commit/543446a35309d82919dc1b021125e978a7451b12
DIFF: https://github.com/llvm/llvm-project/commit/543446a35309d82919dc1b021125e978a7451b12.diff

LOG: [mlir][vector] canonicalize vector.from_elements from ascending extracts (#139819)

Example:
```mlir
%0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
```

becomes
```mlir
%2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
```

It was decided that we should spill canonicalization tests into new
files (see
[discussion](https://github.com/llvm/llvm-project/pull/135096#pullrequestreview-2760245596))
In view of this I added the new tests to a new file specifically for
canonicalization of from_elements. To be consistent in the location of
the tests, I moved existing tests `extract_scalar_from_from_element`,
`extract_1d_from_from_elements`, `extract_2d_from_from_elements` and
`from_elements_to_splat` from `canonicalize.mlir` to
`canonicalize/vector-from-elements.mlir`. In addition to moving I changed
the LIT variables to all be upper-case for consistency.

Added: 
    mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 890a5e9e5c9b4..fcfb401fd9867 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
@@ -2387,9 +2388,129 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+/// Rewrite from_elements on multiple scalar extracts as a shape_cast
+/// on a single extract. Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
+///   %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
+///
+/// The requirements for this to be valid are
+///
+///   i) The elements are extracted from the same vector (%source).
+///
+///  ii) The elements form a suffix of %source. Specifically, the number
+///      of elements is the same as the product of the last N dimension sizes
+///      of %source, for some N.
+///
+/// iii) The elements are extracted contiguously in ascending order.
+
+class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElements,
+                                PatternRewriter &rewriter) const override {
+
+    // Handled by `rewriteFromElementsAsSplat`
+    if (fromElements.getType().getNumElements() == 1)
+      return failure();
+
+    // The common source that all elements are extracted from, if one exists.
+    TypedValue<VectorType> source;
+    // The position of the combined extract operation, if one is created.
+    ArrayRef<int64_t> combinedPosition;
+    // The expected index of extraction of the current element in the loop, if
+    // elements are extracted contiguously in ascending order.
+    SmallVector<int64_t> expectedPosition;
+
+    for (auto [insertIndex, element] :
+         llvm::enumerate(fromElements.getElements())) {
+
+      // Check that the element is from a vector.extract operation.
+      auto extractOp =
+          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element not from vector.extract");
+      }
+
+      // Check condition (i) by checking that all elements have the same source
+      // as the first element.
+      if (insertIndex == 0) {
+        source = extractOp.getVector();
+      } else if (extractOp.getVector() != source) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element from different vector");
+      }
+
+      ArrayRef<int64_t> position = extractOp.getStaticPosition();
+      int64_t rank = position.size();
+      assert(rank == source.getType().getRank() &&
+             "scalar extract must have full rank position");
+
+      // Check condition (ii) by checking that the position that the first
+      // element is extracted from has sufficient trailing 0s. For example, in
+      //
+      //   %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
+      //   [...]
+      //   %elms = vector.from_elements %elm0, [...] : vector<12xi8>
+      //
+      // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
+      // elements, which is the number of elements of %elms, so this is valid.
+      if (insertIndex == 0) {
+        const int64_t numElms = fromElements.getType().getNumElements();
+        int64_t numSuffixElms = 1;
+        int64_t index = rank;
+        while (index > 0 && position[index - 1] == 0 &&
+               numSuffixElms < numElms) {
+          numSuffixElms *= source.getType().getDimSize(index - 1);
+          --index;
+        }
+        if (numSuffixElms != numElms) {
+          return rewriter.notifyMatchFailure(
+              fromElements, "elements do not form a suffix of source");
+        }
+        expectedPosition = llvm::to_vector(position);
+        combinedPosition = position.drop_back(rank - index);
+      }
+
+      // Check condition (iii).
+      else if (expectedPosition != position) {
+        return rewriter.notifyMatchFailure(
+            fromElements, "elements not in ascending order (static order)");
+      }
+      increment(expectedPosition, source.getType().getShape());
+    }
+
+    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
+        fromElements.getLoc(), source, combinedPosition);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        fromElements, fromElements.getType(), extracted);
+
+    return success();
+  }
+
+  /// Increments n-D `indices` by 1 starting from the innermost dimension.
+  static void increment(MutableArrayRef<int64_t> indices,
+                        ArrayRef<int64_t> shape) {
+    for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+      indices[dim] += 1;
+      if (indices[dim] < shape[dim])
+        break;
+      indices[dim] = 0;
+    }
+  }
+};
+
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
+  results.add<FromElementsToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6543aafd1c77..a06a9f67d54dc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2943,75 +2943,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
 
 // -----
 
-// CHECK-LABEL: func @extract_scalar_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
-  // Extract from 0D.
-  %0 = vector.from_elements %a : vector<f32>
-  %1 = vector.extract %0[] : f32 from vector<f32>
-
-  // Extract from 1D.
-  %2 = vector.from_elements %a : vector<1xf32>
-  %3 = vector.extract %2[0] : f32 from vector<1xf32>
-  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
-  %5 = vector.extract %4[4] : f32 from vector<5xf32>
-
-  // Extract from 2D.
-  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
-  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
-  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
-
-  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
-  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_1d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
-  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
-  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<3xf32>, vector<3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_2d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
-  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
-  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
-  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @from_elements_to_splat(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
-  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
-  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
-  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
-  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
-  %2 = vector.from_elements %a : vector<f32>
-  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
-  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
-}
-
-// -----
-
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
new file mode 100644
index 0000000000000..fdab2a8918a2e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.from_elements
+
+///===----------------------------------------------===//
+///  Tests of `rewriteFromElementsAsSplat`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+  // Extract from 0D.
+  %0 = vector.from_elements %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Extract from 1D.
+  %2 = vector.from_elements %a : vector<1xf32>
+  %3 = vector.extract %2[0] : f32 from vector<1xf32>
+  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+  %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+  // Extract from 2D.
+  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+  // CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]]
+  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
+  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32>
+  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: return %[[SPLAT1]], %[[SPLAT2]]
+  return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+  // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
+  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: %[[SPLAT2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32>
+  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: return %[[SPLAT1]], %[[SPLAT2]]
+  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+//  CHECK-SAME:      %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+  // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+  // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+  // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32>
+  %2 = vector.from_elements %a : vector<f32>
+  // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
+  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
+
+// -----
+
+///===----------------------------------------------===//
+///  Tests of `FromElementsToShapeCast`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
+//  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
+//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %4 = vector.from_elements %0, %1 : vector<2xi8>
+  return %4 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
+//  CHECK-SAME:       %[[A:.*]]: vector<8xi8>)
+//       CHECK:       %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8>
+//       CHECK:       return %[[SHAPE_CAST]] : vector<2x2x2xi8>
+func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
+  %0 = vector.extract %arg0[0] : i8 from vector<8xi8>
+  %1 = vector.extract %arg0[1] : i8 from vector<8xi8>
+  %2 = vector.extract %arg0[2] : i8 from vector<8xi8>
+  %3 = vector.extract %arg0[3] : i8 from vector<8xi8>
+  %4 = vector.extract %arg0[4] : i8 from vector<8xi8>
+  %5 = vector.extract %arg0[5] : i8 from vector<8xi8>
+  %6 = vector.extract %arg0[6] : i8 from vector<8xi8>
+  %7 = vector.extract %arg0[7] : i8 from vector<8xi8>
+  %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
+  return %8 : vector<2x2x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @source_larger_than_out(
+//  CHECK-SAME:      %[[A:.*]]: vector<2x3x4xi8>)
+//       CHECK:      %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8>
+//       CHECK:      %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
+//       CHECK:      return %[[SHAPE_CAST]] : vector<12xi8>
+func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
+  %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8>
+  %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8>
+  %2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8>
+  %3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8>
+  %4 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8>
+  %5 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8>
+  %6 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8>
+  %7 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8>
+  %8 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8>
+  %9 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8>
+  %10 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8>
+  %11 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8>
+  %12 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 : vector<12xi8>
+  return %12 : vector<12xi8>
+}
+
+// -----
+
+// This test is similar to `source_larger_than_out` except here the number of elements
+// extracted contiguously starting from the first position [0,0] could be 6 instead of 3
+// and the pattern would still match.
+// CHECK-LABEL: func @suffix_with_excess_zeros(
+//       CHECK:      %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8>
+//       CHECK:      return %[[EXT]] : vector<3xi8>
+func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
+  %2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
+  %3 = vector.from_elements %0, %1, %2 : vector<3xi8>
+  return %3 : vector<3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @large_source_with_shape_cast_required(
+//  CHECK-SAME:      %[[A:.*]]: vector<2x2x2x2xi8>)
+//       CHECK:      %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8>
+//       CHECK:      %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8>
+//       CHECK:      return %[[SHAPE_CAST]] : vector<1x4x1xi8>
+func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> {
+  %0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8>
+  %1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8>
+  %2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8>
+  %3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8>
+  %4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8>
+  return %4 : vector<1x4x1xi8>
+}
+
+//  -----
+
+// Could match, but handled by `rewriteFromElementsAsSplat`.
+// CHECK-LABEL: func @extract_single_elm(
+//  CHECK-NEXT:      vector.extract
+//  CHECK-NEXT:      vector.splat
+//  CHECK-NEXT:      return
+func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
+  %1 = vector.from_elements %0 : vector<1xi8>
+  return %1 : vector<1xi8>
+}
+
+// -----
+
+//   CHECK-LABEL: func @negative_source_contiguous_but_not_suffix(
+//     CHECK-NOT:      shape_cast
+//         CHECK:      from_elements
+func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
+  %1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
+  %2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8>
+  %3 = vector.from_elements %0, %1, %2 : vector<3xi8>
+  return %3 : vector<3xi8>
+}
+
+// -----
+
+// The extracted elements are recombined into a single vector, but in a new order.
+// CHECK-LABEL: func @negative_nonascending_order(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_nonstatic_extract(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_different_sources(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_source_not_suffix(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// The inserted elements are a subset of the extracted elements.
+// [0, 1, 2] -> [1, 1, 2]
+// CHECK-LABEL: func @negative_nobijection_order(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+  %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
+  %2 = vector.from_elements %0, %0, %1 : vector<3xi8>
+  return %2 : vector<3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_source_too_small(
+//   CHECK-NOT:      shape_cast
+//       CHECK:      from_elements
+func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> {
+  %0 = vector.extract %arg0[0] : i8 from vector<2xi8>
+  %1 = vector.extract %arg0[1] : i8 from vector<2xi8>
+  %2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8>
+  return %2 : vector<4xi8>
+}
+


        


More information about the Mlir-commits mailing list