[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)

Mohammadreza Ameri Mahabadian llvmlistbot at llvm.org
Fri Jul 11 03:55:21 PDT 2025


================
@@ -306,3 +306,92 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
   }
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+  // CHECK-LABEL: @splat_vector_i32
+  spirv.func @splat_vector_i32() -> (vector<3xi32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
+    spirv.ReturnValue %1 : vector<3xi32>
+  }
+
+  // CHECK-LABEL: @splat_array_of_i32
+  spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    spirv.ReturnValue %1 : !spirv.array<3 x i32>
+  }
+
+  // CHECK-LABEL: @splat_array_of_vectors_of_i32
+  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<3 x vector<2xi32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<3 x vector<2xi32>>
+  }
+
+  // CHECK-LABEL: @splat_array_of_splat_vectors_of_i32
+  spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  // CHECK-LABEL: @splat_tensor_of_i32
+  spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
----------------
mahabadm wrote:

Great point, thank you so much. I now understand the issue with tensors and have removed those tests. I agree that tensors should not appear at this level, so I added an assembly format and removed the manual parser/printer.

Also, many thanks for pointing out the multi-dimensional array cases. I realized I had to update the verifier and serializer to handle them correctly; this is now done, and I have added tests like the following:

```
spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
    // CHECK: %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
    %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}
```

```
spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
    // CHECK: %0 = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
    %0 = spirv.EXT.ConstantCompositeReplicate [[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}
```
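Judging from these tests, the multi-dimensional cases follow one rule. The replicated value's type must equal the element type at some nesting level of the result type, and every level above that is filled with copies of it. For example, `[3 : i32]` fills both levels of `!spirv.array<2 x !spirv.array<3 x i32>>`, while `[[1 : i32, 2 : i32, 3 : i32]]` fills only the outer level. The following is a minimal Python model of that rule, offered as a sketch under that assumption. `Scalar`, `Composite`, `replication_depth` and `expand` are hypothetical names, not the actual C++ verifier or serializer in MLIR.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Scalar:
    name: str  # e.g. "i32"; hypothetical stand-in for an MLIR scalar type


@dataclass(frozen=True)
class Composite:
    kind: str  # "vector" or "array"; hypothetical labels
    count: int  # number of elements at this level
    element: Type  # element type one level down


Type = Union[Scalar, Composite]


def replication_depth(result_type: Type, value_type: Type) -> int | None:
    """Number of composite levels the value is replicated across, or None if
    the value type matches no nested element type (i.e. verification fails)."""
    depth = 0
    t = result_type
    # Descend at least one level: a value cannot replicate into its own type.
    while isinstance(t, Composite):
        t = t.element
        depth += 1
        if t == value_type:
            return depth
    return None


def _fill(t: Type, value, value_type: Type):
    # Stop at the level whose type matches the replicated value.
    if t == value_type:
        return value
    # Otherwise build one copy per element at this composite level.
    return [_fill(t.element, value, value_type) for _ in range(t.count)]


def expand(result_type: Type, value, value_type: Type):
    """Expand a replicated constant into its fully materialized nested value."""
    if replication_depth(result_type, value_type) is None:
        raise ValueError("value type matches no nested element type")
    return _fill(result_type, value, value_type)
```

In this model, `expand(Composite("array", 2, Composite("array", 3, Scalar("i32"))), 3, Scalar("i32"))` yields `[[3, 3, 3], [3, 3, 3]]`. A serializer could plausibly emit one replicate instruction per level counted by `replication_depth`, though that is an assumption about the strategy, not a description of this PR's code.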
  
I think this should address your comments, but please let me know if I'm missing anything.
Many thanks again for the great review.
  

https://github.com/llvm/llvm-project/pull/147067

