[Mlir-commits] [mlir] [mlir][vector] Add linearization pattern for vector.splat (PR #137651)

Nishant Patel llvmlistbot at llvm.org
Thu May 1 11:09:11 PDT 2025


================
@@ -399,3 +399,20 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: linearize_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
----------------
nbpatel wrote:

Added the test.

https://github.com/llvm/llvm-project/pull/137651

