[Mlir-commits] [mlir] 7668e58 - [mlir][spirv] Fix spv.CompositeConstruct assembly and validation

Lei Zhang llvmlistbot at llvm.org
Wed Jul 27 16:21:18 PDT 2022


Author: Lei Zhang
Date: 2022-07-27T19:17:23-04:00
New Revision: 7668e58210776a15d5e74d91223e6ca541ba9ba8

URL: https://github.com/llvm/llvm-project/commit/7668e58210776a15d5e74d91223e6ca541ba9ba8
DIFF: https://github.com/llvm/llvm-project/commit/7668e58210776a15d5e74d91223e6ca541ba9ba8.diff

LOG: [mlir][spirv] Fix spv.CompositeConstruct assembly and validation

This commit fixes spv.CompositeConstruct's assembly format to list
operand types, which enables constructing a vector out of smaller
vectors. Validation is also fixed to properly check the cases for
vector construction.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D130669

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
    mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
    mlir/test/Target/SPIRV/composite-op.mlir
    mlir/test/Target/SPIRV/debug.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index dd6ee93b4bae9..3b9b844f4d42a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -64,6 +64,10 @@ def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> {
   let results = (outs
     SPV_Composite:$result
   );
+
+  let assemblyFormat = [{
+    $constituents attr-dict `:` `(` type(operands) `)` `->` type($result)
+  }];
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 373fe602562f8..e79d1d2c220f6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -1618,66 +1619,64 @@ LogicalResult spirv::BranchConditionalOp::verify() {
 // spv.CompositeConstruct
 //===----------------------------------------------------------------------===//
 
-ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser,
-                                               OperationState &state) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-
-  if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
-    return failure();
-  }
-  auto cType = type.dyn_cast<spirv::CompositeType>();
-  if (!cType) {
-    return parser.emitError(
-               loc, "result type must be a composite type, but provided ")
-           << type;
-  }
-
-  if (cType.hasCompileTimeKnownNumElements() &&
-      operands.size() != cType.getNumElements()) {
-    return parser.emitError(loc, "has incorrect number of operands: expected ")
-           << cType.getNumElements() << ", but provided " << operands.size();
-  }
-  // TODO: Add support for constructing a vector type from the vector operands.
-  // According to the spec: "for constructing a vector, the operands may
-  // also be vectors with the same component type as the Result Type component
-  // type".
-  SmallVector<Type, 4> elementTypes;
-  elementTypes.reserve(operands.size());
-  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
-    elementTypes.push_back(cType.getElementType(index));
-  }
-  state.addTypes(type);
-  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
-}
-
-void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) {
-  printer << " " << constituents() << " : " << getResult().getType();
-}
-
 LogicalResult spirv::CompositeConstructOp::verify() {
   auto cType = getType().cast<spirv::CompositeType>();
   operand_range constituents = this->constituents();
 
-  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
+  if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
     if (constituents.size() != 1)
-      return emitError("has incorrect number of operands: expected ")
+      return emitOpError("has incorrect number of operands: expected ")
              << "1, but provided " << constituents.size();
-  } else if (constituents.size() != cType.getNumElements()) {
-    return emitError("has incorrect number of operands: expected ")
-           << cType.getNumElements() << ", but provided "
-           << constituents.size();
+    if (coopType.getElementType() != constituents.front().getType())
+      return emitOpError("operand type mismatch: expected operand type ")
+             << coopType.getElementType() << ", but provided "
+             << constituents.front().getType();
+    return success();
   }
 
-  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
-    if (constituents[index].getType() != cType.getElementType(index)) {
-      return emitError("operand type mismatch: expected operand type ")
-             << cType.getElementType(index) << ", but provided "
-             << constituents[index].getType();
+  if (constituents.size() == cType.getNumElements()) {
+    for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+      if (constituents[index].getType() != cType.getElementType(index)) {
+        return emitOpError("operand type mismatch: expected operand type ")
+               << cType.getElementType(index) << ", but provided "
+               << constituents[index].getType();
+      }
     }
+    return success();
   }
 
+  // If not constructing a cooperative matrix type, then we must be constructing
+  // a vector type.
+  auto resultType = cType.dyn_cast<VectorType>();
+  if (!resultType)
+    return emitOpError(
+        "expected to return a vector or cooperative matrix when the number of "
+        "constituents is less than what the result needs");
+
+  SmallVector<unsigned> sizes;
+  for (Value component : constituents) {
+    if (!component.getType().isa<VectorType>() &&
+        !component.getType().isIntOrFloat())
+      return emitOpError("operand type mismatch: expected operand to have "
+                         "a scalar or vector type, but provided ")
+             << component.getType();
+
+    Type elementType = component.getType();
+    if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
+      sizes.push_back(vectorType.getNumElements());
+      elementType = vectorType.getElementType();
+    } else {
+      sizes.push_back(1);
+    }
+
+    if (elementType != resultType.getElementType())
+      return emitOpError("operand element type mismatch: expected to be ")
+             << resultType.getElementType() << ", but provided " << elementType;
+  }
+  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
+  if (totalCount != cType.getNumElements())
+    return emitOpError("has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided " << totalCount;
   return success();
 }
 

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
index 7c36312ea6547..67a3b8ca8eb6e 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -32,8 +32,8 @@ func.func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vect
 //  CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
 //       CHECK:   %[[SMASK:.+]] = spv.Constant -32768 : i16
 //       CHECK:   %[[VMASK:.+]] = spv.Constant 32767 : i16
-//       CHECK:   %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
-//       CHECK:   %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
+//       CHECK:   %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]]
+//       CHECK:   %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]]
 //       CHECK:   %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
 //       CHECK:   %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
 //       CHECK:   %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 94af7382e1c89..bde9b2f1f02d7 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -18,8 +18,8 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
 
 // CHECK-LABEL: @broadcast
 //  CHECK-SAME: %[[A:.*]]: f32
-//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
-//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
+//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
+//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]]
 func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
   %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
   %1 = vector.broadcast %arg0 : f32 to vector<2xf32>
@@ -182,7 +182,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
 
 // CHECK-LABEL: func @splat
 //  CHECK-SAME: (%[[A:.+]]: f32)
-//       CHECK:   %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+//       CHECK:   %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
 //       CHECK:   return %[[VAL]]
 func.func @splat(%f : f32) -> vector<4xf32> {
   %splat = vector.splat %f : vector<4xf32>
@@ -206,7 +206,7 @@ func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
 //  CHECK-SAME:  %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
 //       CHECK:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
 //       CHECK:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
-//       CHECK:    spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32>
+//       CHECK:    spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (f32, f32, f32, f32) -> vector<4xf32>
 func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
   %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
   return %shuffle : vector<4xf32>

diff  --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index a4ccd2412a192..fc9ba780c3815 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -5,48 +5,41 @@
 //===----------------------------------------------------------------------===//
 
 func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
-  // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
-  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
+  // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
   return %0: vector<3xf32>
 }
 
 // -----
 
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> {
-  // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4 x f32>, !spv.struct<(f32)>)>
-  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
+  // CHECK: spv.CompositeConstruct
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
   return %0: !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
 }
 
 // -----
 
-func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
-  // CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
-  %0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
+func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
+  // CHECK: spv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
+  return %0: vector<4xf32>
 }
 
 // -----
 
-func.func @composite_construct_empty_struct() -> !spv.struct<()> {
-  // CHECK: spv.CompositeConstruct : !spv.struct<()>
-  %0 = spv.CompositeConstruct : !spv.struct<()>
-  return %0: !spv.struct<()>
-}
-
-// -----
-
-func.func @composite_construct_invalid_num_of_elements(%arg0: f32) -> f32 {
-  // expected-error @+1 {{result type must be a composite type, but provided 'f32'}}
-  %0 = spv.CompositeConstruct %arg0 : f32
-  return %0: f32
+func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // CHECK: spv.CompositeConstruct {{%.*}} : (f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
 
 func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
   // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
-  %0 = spv.CompositeConstruct %arg0, %arg2 : vector<3xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg2 : (f32, f32) -> vector<3xf32>
   return %0: vector<3xf32>
 }
 
@@ -54,20 +47,52 @@ func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2
 
 func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
   // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
-  %0 = "spv.CompositeConstruct" (%arg0, %arg1, %arg2) : (f32, f32, f32) -> vector<3xi32>
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xi32>
   return %0: vector<3xi32>
 }
 
 // -----
 
-func.func @composite_construct_coopmatrix(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
-  %0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
+func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
+  %0 = spv.CompositeConstruct %arg0 : (i32) -> !spv.coopmatrix<8x16xf32, Subgroup>
   return %0: !spv.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
 
+func.func @composite_construct_array(%arg0: f32) -> !spv.array<4xf32> {
+  // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
+  %0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.array<4xf32>
+  return %0: !spv.array<4xf32>
+}
+
+// -----
+
+func.func @composite_construct_vector_wrong_element_type(%arg0: f32, %arg1: f32, %arg2 : vector<2xi32>) -> vector<4xf32> {
+  // expected-error @+1 {{operand element type mismatch: expected to be 'f32', but provided 'i32'}}
+  %0 = spv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xi32>, f32) -> vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
+  // expected-error @+1 {{op has incorrect number of operands: expected 4, but provided 3}}
+  %0 = spv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
index af6c1ce3ca5d5..cde1740448935 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -3,26 +3,26 @@
 spv.module Logical GLSL450 {
   spv.func @rewrite(%value0 : f32, %value1 : f32, %value2 : f32, %value3 : i32, %value4: !spv.array<3xf32>) -> vector<3xf32> "None" {
     %0 = spv.Undef : vector<3xf32>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
     %1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<3xf32>
     %2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<3xf32>
     %3 = spv.CompositeInsert %value2, %2[2 : i32] : f32 into vector<3xf32>
 
     %4 = spv.Undef : !spv.array<4xf32>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.array<4 x f32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32, f32) -> !spv.array<4 x f32>
     %5 = spv.CompositeInsert %value0, %4[0 : i32] : f32 into !spv.array<4xf32>
     %6 = spv.CompositeInsert %value1, %5[1 : i32] : f32 into !spv.array<4xf32>
     %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32>
     %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32>
 
     %9 = spv.Undef : !spv.struct<(f32, i32, f32)>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<(f32, i32, f32)>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, i32, f32) -> !spv.struct<(f32, i32, f32)>
     %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<(f32, i32, f32)>
     %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<(f32, i32, f32)>
     %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<(f32, i32, f32)>
 
     %13 = spv.Undef : !spv.struct<(f32, !spv.array<3xf32>)>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<(f32, !spv.array<3 x f32>)>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : (f32, !spv.array<3 x f32>) -> !spv.struct<(f32, !spv.array<3 x f32>)>
     %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<(f32, !spv.array<3xf32>)>
     %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<(f32, !spv.array<3xf32>)>
 

diff  --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
index 9b0462b5332f0..7192a078fa819 100644
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -7,8 +7,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     spv.ReturnValue %0: !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>
   }
   spv.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> "None" {
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
-    %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
+    %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
     spv.ReturnValue %0: vector<3xf32>
   }
   spv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {

diff  --git a/mlir/test/Target/SPIRV/debug.mlir b/mlir/test/Target/SPIRV/debug.mlir
index cc6d67936618e..fcd118588ceef 100644
--- a/mlir/test/Target/SPIRV/debug.mlir
+++ b/mlir/test/Target/SPIRV/debug.mlir
@@ -33,7 +33,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     // CHECK: loc({{".*debug.mlir"}}:34:10)
     %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>
     // CHECK: loc({{".*debug.mlir"}}:36:10)
-    %1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32>
+    %1 = spv.CompositeConstruct %arg2, %arg3 : (f32, f32) -> vector<2xf32>
     spv.Return
   }
 


        


More information about the Mlir-commits mailing list