[Mlir-commits] [mlir] 4f9adb6 - [mlir][vector] Relax operand type restrictions for `vector.splat` (#145517)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 25 01:45:35 PDT 2025


Author: Matthias Springer
Date: 2025-06-25T10:45:31+02:00
New Revision: 4f9adb6889419a556ae972119c23c842e7bf4092

URL: https://github.com/llvm/llvm-project/commit/4f9adb6889419a556ae972119c23c842e7bf4092
DIFF: https://github.com/llvm/llvm-project/commit/4f9adb6889419a556ae972119c23c842e7bf4092.diff

LOG: [mlir][vector] Relax operand type restrictions for `vector.splat` (#145517)

The vector type allows element types that implement the
`VectorElementTypeInterface`. `vector.splat` should allow any element
type that is supported by the vector type.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e62930a742d..d58ee84bee63d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
   ]> {
   let summary = "vector splat or broadcast operation";
   let description = [{
-    Broadcast the operand to all elements of the result vector. The operand is
-    required to be of integer/index/float type.
+    Broadcast the operand to all elements of the result vector. The type of the
+    operand must match the element type of the vector type.
 
     Example:
 
@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
+  let arguments = (ins AnyType:$input);
   let results = (outs AnyVectorOfAnyRank:$aggregate);
 
   let builders = [

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ec7cee7b2c641..4935ec8ba8e61 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1975,6 +1975,15 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
 
 // -----
 
+// expected-note @+1 {{prior use here}}
+func.func @vector_splat_type_mismatch(%a: f32) {
+  // expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
+  %0 = vector.splat %a : vector<1xi32>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.load
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c59f7bd001905..0121bcdbbba45 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -149,7 +149,7 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
 }
 
 // CHECK-LABEL: @vector_broadcast
-func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
+func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>) {
   // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
   %0 = vector.broadcast %a : f32 to vector<f32>
   // CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
@@ -162,7 +162,9 @@ func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: ve
   %4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
   %5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+  // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
+  %6 = vector.broadcast %f : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
+  return
 }
 
 // CHECK-LABEL: @shuffle0D
@@ -959,13 +961,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
 }
 
 // CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
-  // CHECK: vector.splat [[S]] : vector<8xf32>
+// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
+func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
+  // CHECK: vector.splat %[[s]] : vector<8xf32>
   %v = vector.splat %s : vector<8xf32>
 
-  // CHECK: vector.splat [[S]] : vector<4xf32>
+  // CHECK: vector.splat %[[s]] : vector<4xf32>
   %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+
+  // CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
+  %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
   return
 }
 


        


More information about the Mlir-commits mailing list