[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)
Mohammadreza Ameri Mahabadian
llvmlistbot at llvm.org
Fri Jul 11 03:55:21 PDT 2025
================
@@ -306,3 +306,92 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
}
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+ // CHECK-LABEL: @splat_vector_i32
+ spirv.func @splat_vector_i32() -> (vector<3xi32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
+ %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
+ spirv.ReturnValue %1 : vector<3xi32>
+ }
+
+ // CHECK-LABEL: @splat_array_of_i32
+ spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+ %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+ spirv.ReturnValue %1 : !spirv.array<3 x i32>
+ }
+
+ // CHECK-LABEL: @splat_array_of_vectors_of_i32
+ spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<3 x vector<2xi32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<3 x vector<2xi32>>
+ }
+
+ // CHECK-LABEL: @splat_array_of_splat_vectors_of_i32
+ spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+
+ // CHECK-LABEL: @splat_tensor_of_i32
+ spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
----------------
mahabadm wrote:
Great point, thank you. I now see your point about tensors and have removed their tests. I agree there should not be any tensors at this level, so I added a declarative assembly format and removed the manual parser/printer.
Also, many thanks for pointing out the multi-dimensional array cases. I realized I had to update the verifier and serializer to handle them correctly. This is now done, and I have added tests like the ones below:
```
spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
// CHECK: %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
%0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}
```
```
spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
// CHECK: %0 = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
%0 = spirv.EXT.ConstantCompositeReplicate [[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}
```
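The nested cases are easiest to reason about with a small model. The following is a minimal Python sketch, not the MLIR C++ verifier or serializer. It assumes, as the tests above suggest, that `spirv.EXT.ConstantCompositeReplicate` accepts a value matching the result's element type at *any* nesting depth, and that the result is equivalent to replicating that value down to the matching level. The classes `Scalar`, `Composite`, and `IntAttr` and the functions `conforms`, `verify_replicate`, and `materialize` are hypothetical names used only for illustration. Vectors and arrays are deliberately modelled the same way.

```python
import copy
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class Scalar:
    """A scalar element type such as i32 (hypothetical model, not MLIR's API)."""
    name: str


@dataclass(frozen=True)
class Composite:
    """A vector or array type holding `count` copies of `element`."""
    count: int
    element: "Type"


Type = Union[Scalar, Composite]


@dataclass(frozen=True)
class IntAttr:
    """A typed integer constant such as `3 : i32`."""
    value: int
    type: Scalar


Value = Union[IntAttr, List["Value"]]


def conforms(value: Value, ty: Type) -> bool:
    """Return True if `value` is a complete constant of type `ty`."""
    if isinstance(ty, Scalar):
        return isinstance(value, IntAttr) and value.type == ty
    return (
        isinstance(value, list)
        and len(value) == ty.count
        and all(conforms(v, ty.element) for v in value)
    )


def verify_replicate(value: Value, result_type: Type) -> bool:
    """Accept `value` if it matches the element type at some nesting depth."""
    if not isinstance(result_type, Composite):
        return False  # only composites can be built by replication
    ty: Type = result_type.element
    while True:
        if conforms(value, ty):
            return True
        if isinstance(ty, Scalar):
            return False  # reached the innermost scalar without a match
        ty = ty.element


def _expand(value: Value, ty: Composite) -> Value:
    """Replicate `value` into `ty`, nesting replicates until it fits."""
    elem = ty.element
    if conforms(value, elem):
        # Deep-copy so replicated list values do not share one object.
        return [copy.deepcopy(value) for _ in range(ty.count)]
    # Otherwise each slot is itself a replicate one level further down.
    return [_expand(value, elem) for _ in range(ty.count)]


def materialize(value: Value, result_type: Type) -> Value:
    """Expand a replicated constant into its full nested value."""
    if not verify_replicate(value, result_type):
        raise ValueError("value does not match any nested element type")
    return _expand(value, result_type)
```

For example, `materialize(IntAttr(3, Scalar("i32")), Composite(2, Composite(3, Scalar("i32"))))` yields a 2x3 nested list of `3 : i32`, which mirrors `@splat_array_of_splat_array_of_i32`. Passing `[1, 2, 3]` (as `IntAttr`s) instead yields two copies of that row, which mirrors `@splat_array_of_non_splat_array_of_i32`.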
I think this should address your comments, but please let me know if I'm missing anything.
Many thanks again for the great review.
https://github.com/llvm/llvm-project/pull/147067