[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 4 07:32:15 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Mohammadreza Ameri Mahabadian (mahabadm)

<details>
<summary>Changes</summary>

This patch introduces two new ops to the SPIR-V dialect:
- `spirv.EXT.ConstantCompositeReplicate`
- `spirv.EXT.SpecConstantCompositeReplicate`

These ops represent composite constants and composite specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to the following `SPV_EXT_replicated_composites` extension instructions:
- `OpConstantCompositeReplicateEXT`
- `OpSpecConstantCompositeReplicateEXT`

This patch does not introduce any transformations that produce these new ops.

This approach was chosen based on the discussion in the RFC: https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987

---

Patch is 36.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147067.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+13-2) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td (+86) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+119) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+20-3) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+90-3) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+39) 
- (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+36) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+46) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+12) 
- (modified) mlir/test/Target/SPIRV/constant.mlir (+82-1) 
- (modified) mlir/test/Target/SPIRV/spec-constant.mlir (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
 def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+// Extension providing OpConstantCompositeReplicateEXT and
+// OpSpecConstantCompositeReplicateEXT.
+def SPV_EXT_replicated_composites        : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
-      SPV_EXT_mesh_shader,
+      SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
       SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR                        : I32EnumAttrCase<"Coope
     MinVersion<SPIRV_V_1_6>
   ];
 }
+// Capability required by spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate; only available with the
+// SPV_EXT_replicated_composites extension.
+def SPIRV_C_ReplicatedCompositesEXT                     : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_replicated_composites]>,
+    MinVersion<SPIRV_V_1_0>
+  ];
+}
 def SPIRV_C_BitInstructions                             : I32EnumAttrCase<"BitInstructions", 6025> {
   list<Availability> availability = [
     Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
       SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
       SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
-      SPIRV_C_CooperativeMatrixKHR,
+      SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
       SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
       SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
       SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR       : I32EnumAttrCase<"OpCooperativeMa
 def SPIRV_OC_OpCooperativeMatrixStoreKHR      : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR     : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR     : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+// Opcodes introduced by SPV_EXT_replicated_composites.
+def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
 def SPIRV_OC_OpEmitMeshTasksEXT               : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
 def SPIRV_OC_OpSetMeshOutputsEXT              : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL         : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
       SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
       SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpConstantCompositeReplicateEXT,
+      SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
       SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
   let autogenSerialization = 0;
 }
 
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+  let summary = [{
+    Declare a new replicated composite constant op.
+  }];
+
+  let description = [{
+    This op declares a `spirv.EXT.ConstantCompositeReplicate` which represents a
+    splat composite constant, i.e. all elements of the composite constant have
+    the same value. This op will be serialized to SPIR-V
+    `OpConstantCompositeReplicateEXT`. The splat value must come from a
+    non-specialization constant instruction: either a `spirv.Constant` or
+    another `spirv.EXT.ConstantCompositeReplicate`.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.Constant 1 : i32
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+
+    %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+    %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+    %4 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xi32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    SPIRV_Type:$constant
+  );
+
+  let results = (outs
+    SPIRV_Composite:$replicated_constant
+  );
+
+  let autogenSerialization = 0;
+}
+
 // -----
 
 def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
 
 // -----
 
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+  let summary = "Declare a new replicated composite specialization constant op.";
+
+  let description = [{
+    This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which
+    represents a splat composite specialization constant, i.e. all elements of
+    the composite specialization constant have the same value. This op will be
+    serialized to SPIR-V `OpSpecConstantCompositeReplicateEXT`. The splat value
+    must be a symbol reference to a `spirv.SpecConstant` op.
+
+    #### Example:
+
+    ```mlir
+    spirv.SpecConstant @sc_i32_1 = 1 : i32
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefAttr:$constituent
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+}
+
+// -----
+
 def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
        Pure, InFunctionScope,
        SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,67 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXT.ConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+/// Parses `%splat : <composite-type>`. The splat operand's type is not written
+/// in the assembly; it is derived from the composite result type (its first
+/// element type, with any nested `!spirv.array` unwrapped).
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand constOperand;
+  Type compositeType;
+  if (parser.parseOperand(constOperand) ||
+      parser.parseColonType(compositeType)) {
+    return failure();
+  }
+
+  // When the first type is a tensor, a second `: type` follows and replaces it.
+  if (llvm::isa<TensorType>(compositeType)) {
+    if (parser.parseColonType(compositeType))
+      return failure();
+  }
+
+  // The operand type is derived from the result type, so a non-composite
+  // result type must be reported as a parse error rather than asserting in
+  // `cast<>`.
+  auto spirvCompositeType =
+      llvm::dyn_cast<spirv::CompositeType>(compositeType);
+  if (!spirvCompositeType) {
+    return parser.emitError(parser.getNameLoc(),
+                            "result type must be a composite type, but provided ")
+           << compositeType;
+  }
+
+  // NOTE(review): nested arrays are unwrapped, so a scalar can be accepted as
+  // the splat of an array of arrays; confirm this matches the operand-type
+  // rule of OpConstantCompositeReplicateEXT.
+  Type constType = spirvCompositeType.getElementType(0);
+  while (auto arrayType = llvm::dyn_cast<spirv::ArrayType>(constType))
+    constType = arrayType.getElementType();
+
+  if (parser.resolveOperand(constOperand, constType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(compositeType, result.types);
+}
+
+/// Prints `%splat : <composite-type>`.
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getConstant() << " : " << getType();
+}
+
+/// Checks that the result is a composite type and that the splat value is
+/// produced by a spirv.Constant or by another
+/// spirv.EXT.ConstantCompositeReplicate.
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+  if (!isa<spirv::CompositeType>(getType()))
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  Operation *constantDefiningOp = getConstant().getDefiningOp();
+  if (!constantDefiningOp)
+    return emitOpError("op defining the splat constant not found");
+
+  // No further checks are performed here on a nested replicated composite.
+  if (isa<spirv::EXTConstantCompositeReplicateOp>(constantDefiningOp))
+    return success();
+
+  auto constantOp = dyn_cast<spirv::ConstantOp>(constantDefiningOp);
+  if (!constantOp)
+    return emitOpError(
+        "op defining the splat constant is not a spirv.Constant or a "
+        "spirv.EXT.ConstantCompositeReplicate");
+
+  return verifyConstantType(constantOp, constantOp.getValueAttr(),
+                            constantOp.getType());
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ControlBarrierOp
 //===----------------------------------------------------------------------===//
@@ -1866,6 +1927,64 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXT.SpecConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+/// Parses `@sym_name (@constituent) : <composite-type>`.
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                                  OperationState &result) {
+  StringAttr compositeName;
+  const char *attrName = "spec_const";
+  FlatSymbolRefAttr specConstRef;
+  NamedAttrList attrs;
+  Type type;
+
+  // `attrs` only receives the parsed reference temporarily; it is re-added
+  // below under the op's `constituent` attribute name.
+  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      parser.parseLParen() ||
+      parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
+      parser.parseRParen() || parser.parseColonType(type))
+    return failure();
+
+  StringAttr compositeSpecConstituentName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+          result.name);
+  result.addAttribute(compositeSpecConstituentName, specConstRef);
+
+  StringAttr typeAttrName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
+
+  return success();
+}
+
+/// Prints `@sym_name (@constituent) : <composite-type>`.
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  printer << " (" << this->getConstituent() << ") : " << getType();
+}
+
+/// Checks that the result is a composite type and that the constituent symbol
+/// resolves to a spirv.SpecConstant whose type matches the composite's element
+/// type (and, for structs, every member type).
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  // The symbol may be missing or may name a different kind of op; both cases
+  // must be diagnosed instead of dereferencing a null op.
+  Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
+      (*this)->getParentOp(), this->getConstituent());
+  auto constituentSpecConstOp =
+      dyn_cast_or_null<spirv::SpecConstantOp>(constituentOp);
+  if (!constituentSpecConstOp)
+    return emitError("constituent must be a symbol reference to a "
+                     "spirv.SpecConstant, but provided ")
+           << this->getConstituent();
+
+  Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
+  Type compositeElemType = compositeType.getElementType(0);
+  if (constituentType != compositeElemType)
+    return emitError("constituent has incorrect type: expected ")
+           << compositeElemType << ", but provided " << constituentType;
+
+  // A struct can only be a splat of one value if every member has that type;
+  // member 0 was already checked above.
+  if (auto structType = dyn_cast<spirv::StructType>(getType())) {
+    for (unsigned i = 1, e = structType.getNumElements(); i < e; ++i) {
+      Type memberType = structType.getElementType(i);
+      if (memberType != constituentType)
+        return emitError("constituent has incorrect type: expected ")
+               << memberType << ", but provided " << constituentType;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                                constInfo->first);
   }
+  if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
+    auto constantId = constCompositeReplicateInfo->first;
+    auto element = getValue(constantId);
+    return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+        unknownLoc, constCompositeReplicateInfo->second, element);
+  }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
         unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
         SymbolRefAttr::get(constOp.getOperation()));
     return referenceOfOp.getReference();
   }
-  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+  if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, specConstCompositeOp.getType(),
+        SymbolRefAttr::get(specConstCompositeOp.getOperation()));
+    return referenceOfOp.getReference();
+  }
+  if (auto specConstCompositeReplicateOp =
+          getSpecConstantCompositeReplicate(id)) {
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
-        unknownLoc, constCompositeOp.getType(),
-        SymbolRefAttr::get(constCompositeOp.getOperation()));
+        unknownLoc, specConstCompositeReplicateOp.getType(),
+        SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
     return referenceOfOp.getReference();
   }
   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpConstantCompositeReplicateEXT:
+    return processConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantComposite:
     return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+    return processSpecConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantOp:
     return processSpecConstantOperation(operands);
   case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
   return constIt->getSecond();
 }
 
+/// Looks up a recorded OpConstantCompositeReplicateEXT by its result <id>.
+/// Returns the stored (splat constant <id>, result type) pair, or std::nullopt
+/// when `id` was not recorded as a replicated composite constant.
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+  if (auto entry = constantCompositeReplicateMap.find(id);
+      entry != constantCompositeReplicateMap.end())
+    return entry->second;
+  return std::nullopt;
+}
+
 std::optional<spirv::SpecConstOperationMaterializationInfo>
 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
   auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,58 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Handles OpConstantCompositeReplicateEXT, whose operands are
+/// [result type <id>, result <id>, splat constant <id>]. No op is created
+/// here: the (splat <id>, result type) pair is recorded and materialized on use
+/// in getValue().
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+  if (operands.size() != 3) {
+    return emitError(
+        unknownLoc,
+        "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+        "and only one parameter which is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  if (!isa<CompositeType>(resultType)) {
+    return emitError(unknownLoc, "result type from <id> ")
+           << operands[0] << " is not a composite type";
+  }
+
+  uint32_t resultID = operands[1];
+  uint32_t constantID = operands[2];
+
+  // The splat must be a normal constant or another replicated composite.
+  // NOTE(review): its type is not checked against the result's element type.
+  if (!getConstant(constantID) && !getConstantCompositeReplicate(constantID)) {
+    return emitError(unknownLoc,
+                     "OpConstantCompositeReplicateEXT operand <id> ")
+           << constantID
+           << " must come from a normal constant or a "
+              "OpConstantCompositeReplicateEXT";
+  }
+
+  constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   if (operands.size() < 2) {
-    return emitError(unknownLoc,
-                     "OpConstantComposite must have type <id> and result <id>");
+    return emitError(
+        unknownLoc,
+        "OpSpecConstantComposite must have type <id> and result <id>");
   }
   if (operands.size() < 3) {
     return emitError(unknownLoc,
-                     "OpConstantComposite must have at least 1 parameter");
+                     "OpSpecConstantComposite must have at least 1 parameter");
   }
 
   Type resultType = getType(operands[0]);
@@ -1589,6 +1640,42 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Handles OpSpecConstantCompositeReplicateEXT, whose operands are
+/// [result type <id>, result <id>, splat spec constant <id>]. Creates a
+/// spirv.EXT.SpecConstantCompositeReplicate op and records it by result <id>.
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+  if (operands.size() != 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT must have "
+                     "type <id> and result <id> and only one parameter which "
+                     "is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  if (!isa<CompositeType>(resultType)) {
+    return emitError(unknownLoc, "result type from <id> ")
+           << operands[0] << " is not a composite type";
+  }
+
+  uint32_t resultID = operands[1];
+  uint32_t constituentID = operands[2];
+
+  // The splat must resolve via getSpecConstant(); a null op would otherwise be
+  // passed straight to SymbolRefAttr::get below.
+  auto constituentSpecConstantOp = getSpecConstant(constituentID);
+  if (!constituentSpecConstantOp) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT operand <id> ")
+           << constituentID
+           << " must come from a normal specialization constant";
+  }
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+  auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      SymbolRefAttr::get(constituentSpecConstantOp));
+
+  specConstCompositeReplicateMap[resultID] = op;
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
   if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deseria...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/147067


More information about the Mlir-commits mailing list