[Mlir-commits] [mlir] [mlir][vector] Add linearization pattern for vector.splat (PR #137651)

Andrzej Warzyński llvmlistbot at llvm.org
Thu May 1 08:53:05 PDT 2025


================
@@ -399,3 +399,20 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: linearize_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
----------------
banach-space wrote:

This should also work for scalable vectors, e.g.:
```mlir
  %0 = vector.splat %arg0 : vector<4x[2]xi32>
```

Could you try it and, if it works, add a test? Thanks!

https://github.com/llvm/llvm-project/pull/137651

