[Mlir-commits] [mlir] [mlir][vector] canonicalize vector.from_elements from ascending extracts (PR #139819)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu May 15 06:47:32 PDT 2025
================
@@ -0,0 +1,155 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.from_elements
+
+///===----------------------------------------------===//
+/// Tests of `rewriteFromElementsAsSplat`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+ // Extract from 0D: the only element is %a.
+ %0 = vector.from_elements %a : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
+
+ // Extract from 1D: [0] of a 1-element vector is %a; [4] of (a, b, a, a, b) is %b.
+ %2 = vector.from_elements %a : vector<1xf32>
+ %3 = vector.extract %2[0] : f32 from vector<1xf32>
+ %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+ %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+ // Extract from 2D: row 0 is (a, a, a) and row 1 is (b, b, b), so
+ // [0, 0] and [0, 1] yield %a while [1, 1] and [1, 2] yield %b.
+ %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+ %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+ %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+ %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+ %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+ // CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]]
+ return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+ // Row 0 holds only %a and row 1 holds only %b, so each extracted row is expected as a splat.
+ %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+ // CHECK: %[[SPLAT_A:.*]] = vector.splat %[[A]] : vector<3xf32>
+ %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+ // CHECK: %[[SPLAT_B:.*]] = vector.splat %[[B]] : vector<3xf32>
+ %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+ // CHECK: return %[[SPLAT_A]], %[[SPLAT_B]]
+ return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+ // Each extracted 2x2 slice mixes %a and %b, so the expected result is a
+ // smaller vector.from_elements rather than a vector.splat.
+ %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+ // Slice [0] covers elements 0-3: (a, a, a, b).
+ // CHECK: %[[FROM_EL_0:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
+ %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+ // Slice [1] covers elements 4-7: (b, b, b, a).
+ // CHECK: %[[FROM_EL_1:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32>
+ %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+ // CHECK: return %[[FROM_EL_0]], %[[FROM_EL_1]]
+ return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+ // All six elements are %a: expected to become a splat.
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+ %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+ // One element is %b: expected to stay a vector.from_elements.
+ // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+ %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+ // Single element in a 0-D vector: expected to become a splat as well.
+ // CHECK: %[[SPLAT_0D:.*]] = vector.splat %[[A]] : vector<f32>
+ %2 = vector.from_elements %a : vector<f32>
+ // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT_0D]]
+ return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
+
+// -----
+
+///===----------------------------------------------===//
+/// Tests of `FromElementsToShapeCast`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
+// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>)
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: return %[[shape_cast]] : vector<2xi8>
----------------
banach-space wrote:
[nit] Could you use caps for LIT variables? That's much more common. And, IMHO, easier to parse 😅
I suspect that you wanted to maintain consistency with the tests for `rewriteFromElementsAsSplat`? I would just update those as well (fortunately, there aren't that many LIT vars there)
https://github.com/llvm/llvm-project/pull/139819
More information about the Mlir-commits
mailing list