[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)

Mohammadreza Ameri Mahabadian llvmlistbot at llvm.org
Fri Jul 11 03:43:38 PDT 2025


https://github.com/mahabadm updated https://github.com/llvm/llvm-project/pull/147067

>From 955f819da1389d7e038fcea643bd9af287368f3d Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Sun, 29 Jun 2025 07:25:48 +0100
Subject: [PATCH 1/8] [mlir][spirv] Add basic support for
 SPV_EXT_replicated_composites

This patch introduces two new ops to the SPIR-V dialect:
- `spirv.EXT.ConstantCompositeReplicate`
- `spirv.EXT.SpecConstantCompositeReplicate`

These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to `SPV_EXT_replicated_composites` extension instructions:
- `OpConstantCompositeReplicatedEXT`
- `OpSpecConstantCompositeReplicatedEXT`

No transformations that produce these new ops are introduced in this patch.

This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  15 ++-
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     |  86 +++++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 119 ++++++++++++++++++
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  23 +++-
 .../SPIRV/Deserialization/Deserializer.cpp    |  93 +++++++++++++-
 .../SPIRV/Deserialization/Deserializer.h      |  39 ++++++
 .../SPIRV/Serialization/SerializeOps.cpp      |  36 ++++++
 .../Target/SPIRV/Serialization/Serializer.cpp |  46 +++++++
 .../Target/SPIRV/Serialization/Serializer.h   |  12 ++
 mlir/test/Target/SPIRV/constant.mlir          |  83 +++++++++++-
 mlir/test/Target/SPIRV/spec-constant.mlir     |  27 ++++
 11 files changed, 570 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
 def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+def SPV_EXT_replicated_composites        : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
-      SPV_EXT_mesh_shader,
+      SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
       SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR                        : I32EnumAttrCase<"Coope
     MinVersion<SPIRV_V_1_6>
   ];
 }
+def SPIRV_C_ReplicatedCompositesEXT                     : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_replicated_composites]>,
+    MinVersion<SPIRV_V_1_0>
+  ];
+}
 def SPIRV_C_BitInstructions                             : I32EnumAttrCase<"BitInstructions", 6025> {
   list<Availability> availability = [
     Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
       SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
       SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
-      SPIRV_C_CooperativeMatrixKHR,
+      SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
       SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
       SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
       SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR       : I32EnumAttrCase<"OpCooperativeMa
 def SPIRV_OC_OpCooperativeMatrixStoreKHR      : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR     : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR     : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpConstantCompositeReplicateEXT  : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
 def SPIRV_OC_OpEmitMeshTasksEXT               : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
 def SPIRV_OC_OpSetMeshOutputsEXT              : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL         : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
       SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
       SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpConstantCompositeReplicateEXT,
+      SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
       SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
   let autogenSerialization = 0;
 }
 
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+  let summary = [{
+    Declare a new replicated composite constant op.
+  }];
+
+  let description = [{
+    This op declares a `spirv.EXT.ConstantCompositeReplicate` which represents a
+    splat composite constant, i.e. all elements of the composite constant have
+    the same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
+    The splat value must come from a non-specialization constant instruction.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.Constant 1 : i32
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+
+    %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+    %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+    %4 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+    %5 = spirv.EXT.ConstantCompositeReplicate %4 : !spirv.array<2 x vector<2xi32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    SPIRV_Type:$constant
+  );
+
+  let results = (outs
+    SPIRV_Composite:$replicated_constant
+  );
+
+  let autogenSerialization = 0;
+}
+
 // -----
 
 def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
 
 // -----
 
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+  let summary = "Declare a new replicated composite specialization constant op.";
+
+  let description = [{
+    This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
+    a splat specialization composite constant, i.e. all elements of the
+    specialization composite constant have the same value. This op will be
+    serialized to SPIR-V `OpSpecConstantCompositeReplicateEXT`. The splat value
+    must come from a symbol reference to a specialization constant instruction.
+
+    #### Example:
+
+    ```mlir
+    spirv.SpecConstant @sc_i32_1 = 1 : i32
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefAttr:$constituent
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+}
+
+// -----
+
 def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
        Pure, InFunctionScope,
        SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,99 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXT.ConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+/// Returns the type shared by every top-level constituent of `compositeType`,
+/// i.e. the type a replicated splat value must have. Returns a null type when
+/// no single such type exists (an empty struct or a heterogeneous struct).
+static Type getReplicatedConstituentType(spirv::CompositeType compositeType) {
+  if (auto structType = dyn_cast<spirv::StructType>(compositeType)) {
+    if (structType.getNumElements() == 0 ||
+        !llvm::all_equal(structType.getElementTypes()))
+      return {};
+  }
+  return compositeType.getElementType(0);
+}
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand constOperand;
+  Type compositeType;
+  if (parser.parseOperand(constOperand) ||
+      parser.parseColonType(compositeType)) {
+    return failure();
+  }
+
+  if (llvm::isa<TensorType>(compositeType)) {
+    if (parser.parseColonType(compositeType))
+      return failure();
+  }
+
+  // The operand type is derived from the result type: it is the type of one
+  // top-level constituent (e.g. `i32` for `vector<3xi32>`, `vector<2xi32>` for
+  // `!spirv.array<2 x vector<2xi32>>`). Nested composites are not looked
+  // through because the serialized instruction replicates exactly one level.
+  auto spirvCompositeType = dyn_cast<spirv::CompositeType>(compositeType);
+  if (!spirvCompositeType)
+    return parser.emitError(parser.getNameLoc(),
+                            "result type must be a composite type, but "
+                            "provided ")
+           << compositeType;
+
+  Type constType = getReplicatedConstituentType(spirvCompositeType);
+  if (!constType)
+    return parser.emitError(parser.getNameLoc(),
+                            "result type must have constituents of a single "
+                            "type, but provided ")
+           << compositeType;
+
+  if (parser.resolveOperand(constOperand, constType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  Operation *constantDefiningOp = getConstant().getDefiningOp();
+  if (!constantDefiningOp)
+    return emitOpError("op defining the splat constant not found");
+
+  // Only compile-time constants can be replicated: a plain spirv.Constant or
+  // another (nested) spirv.EXT.ConstantCompositeReplicate.
+  if (!isa<spirv::ConstantOp, spirv::EXTConstantCompositeReplicateOp>(
+          constantDefiningOp))
+    return emitOpError(
+        "op defining the splat constant is not a spirv.Constant or a "
+        "spirv.EXT.ConstantCompositeReplicate");
+
+  // The splat value populates every top-level constituent, so its type must
+  // be exactly the constituent type.
+  Type constituentType = getReplicatedConstituentType(compositeType);
+  if (!constituentType)
+    return emitOpError("result type must have constituents of a single "
+                       "type, but provided ")
+           << getType();
+
+  Type splatType = getConstant().getType();
+  if (splatType != constituentType)
+    return emitOpError("splat constant has incorrect type: expected ")
+           << constituentType << ", but provided " << splatType;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ControlBarrierOp
 //===----------------------------------------------------------------------===//
@@ -1866,6 +1959,71 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXT.SpecConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                                  OperationState &result) {
+  StringAttr compositeName;
+  FlatSymbolRefAttr specConstRef;
+  Type type;
+
+  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      parser.parseLParen() || parser.parseAttribute(specConstRef) ||
+      parser.parseRParen() || parser.parseColonType(type))
+    return failure();
+
+  StringAttr compositeSpecConstituentName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+          result.name);
+  result.addAttribute(compositeSpecConstituentName, specConstRef);
+
+  StringAttr typeAttrName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
+
+  return success();
+}
+
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  printer << " (" << this->getConstituent() << ") : " << getType();
+}
+
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  // The referenced symbol may be missing or may not be a spirv.SpecConstant;
+  // diagnose both instead of dereferencing a null op.
+  auto constituentSpecConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
+      SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
+                                           getConstituent()));
+  if (!constituentSpecConstOp)
+    return emitOpError("splat spec constant reference defining constituent "
+                       "not found or not a spirv.SpecConstant: ")
+           << getConstituent();
+
+  auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
+  Type compositeElemType = getReplicatedConstituentType(compositeType);
+  if (!compositeElemType)
+    return emitOpError("result type must have constituents of a single "
+                       "type, but provided ")
+           << getType();
+
+  if (constituentType != compositeElemType)
+    return emitError("constituent has incorrect type: expected ")
+           << compositeElemType << ", but provided " << constituentType;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                                constInfo->first);
   }
+  if (auto replicateInfo = getConstantCompositeReplicate(id)) {
+    auto [splatID, compositeType] = *replicateInfo;
+    Value splat = getValue(splatID);
+    return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+        unknownLoc, compositeType, splat);
+  }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
         unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
         SymbolRefAttr::get(constOp.getOperation()));
     return referenceOfOp.getReference();
   }
-  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+  if (auto compositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, compositeOp.getType(),
+        SymbolRefAttr::get(compositeOp.getOperation()));
+    return referenceOfOp.getReference();
+  }
+  if (auto replicateOp = getSpecConstantCompositeReplicate(id)) {
+    // Replicated spec composites are module-level symbols; refer to them.
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
-        unknownLoc, constCompositeOp.getType(),
-        SymbolRefAttr::get(constCompositeOp.getOperation()));
+        unknownLoc, replicateOp.getType(),
+        SymbolRefAttr::get(replicateOp.getOperation()));
     return referenceOfOp.getReference();
   }
   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpConstantCompositeReplicateEXT:
+    return processConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantComposite:
     return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+    return processSpecConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantOp:
     return processSpecConstantOperation(operands);
   case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
   return constIt->getSecond();
 }
 
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+  auto constIt = constantCompositeReplicateMap.find(id);
+  if (constIt == constantCompositeReplicateMap.end())
+    return std::nullopt;
+  return constIt->getSecond();
+}
+
 std::optional<spirv::SpecConstOperationMaterializationInfo>
 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
   auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,56 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+  if (operands.size() != 3) {
+    return emitError(
+        unknownLoc,
+        "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+        "and only one parameter which is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  if (!isa<CompositeType>(resultType)) {
+    return emitError(unknownLoc, "result type from <id> ")
+           << operands[0] << " is not a composite type";
+  }
+
+  uint32_t resultID = operands[1];
+  uint32_t constantID = operands[2];
+
+  // The splat value must be a (non-spec) constant that has already been
+  // recorded: either a plain constant or another replicated composite.
+  if (!getConstant(constantID) && !getConstantCompositeReplicate(constantID)) {
+    return emitError(unknownLoc,
+                     "OpConstantCompositeReplicateEXT operand <id> ")
+           << constantID
+           << " must come from a normal constant or an "
+              "OpConstantCompositeReplicateEXT";
+  }
+
+  // Materialization is deferred to the first use (see getValue), so only the
+  // splat <id> and the result type are recorded here.
+  constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   if (operands.size() < 2) {
-    return emitError(unknownLoc,
-                     "OpConstantComposite must have type <id> and result <id>");
+    return emitError(
+        unknownLoc,
+        "OpSpecConstantComposite must have type <id> and result <id>");
   }
   if (operands.size() < 3) {
     return emitError(unknownLoc,
-                     "OpConstantComposite must have at least 1 parameter");
+                     "OpSpecConstantComposite must have at least 1 parameter");
   }
 
   Type resultType = getType(operands[0]);
@@ -1589,6 +1638,49 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+  if (operands.size() != 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT must have "
+                     "type <id> and result <id> and only one parameter which "
+                     "is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  if (!isa<CompositeType>(resultType)) {
+    return emitError(unknownLoc, "result type from <id> ")
+           << operands[0] << " is not a composite type";
+  }
+
+  uint32_t resultID = operands[1];
+  uint32_t constituentID = operands[2];
+
+  // The splat value must be a spec constant already recorded in specConstMap;
+  // building a SymbolRefAttr from a null op would crash.
+  spirv::SpecConstantOp constituentSpecConstantOp =
+      getSpecConstant(constituentID);
+  if (!constituentSpecConstantOp) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT operand <id> ")
+           << constituentID << " must come from a normal spec constant";
+  }
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+  auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      SymbolRefAttr::get(constituentSpecConstantOp));
+
+  specConstCompositeReplicateMap[resultID] = op;
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
   if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deserializer {
   /// Gets the constant's attribute and type associated with the given <id>.
   std::optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
 
+  /// Gets the <id> of the splat constant (a `spirv.Constant` or another
+  /// replicated composite) and the type of the resulting
+  /// `spirv.EXT.ConstantCompositeReplicate` for the given result <id>.
+  std::optional<std::pair<uint32_t, Type>>
+  getConstantCompositeReplicate(uint32_t id);
+
   /// Gets the info needed to materialize the spec constant operation op
   /// associated with the given <id>.
   std::optional<SpecConstOperationMaterializationInfo>
@@ -220,6 +226,13 @@ class Deserializer {
     return specConstCompositeMap.lookup(id);
   }
 
+  /// Gets the replicated composite specialization constant with the given
+  /// result <id>.
+  spirv::EXTSpecConstantCompositeReplicateOp
+  getSpecConstantCompositeReplicate(uint32_t id) {
+    return specConstCompositeReplicateMap.lookup(id);
+  }
+
   /// Creates a spirv::SpecConstantOp.
   spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
                                            TypedAttr defaultValue);
@@ -313,10 +326,20 @@ class Deserializer {
   /// `operands`.
   LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with
+  /// the given `operands`.
+  LogicalResult
+  processConstantCompositeReplicateEXT(ArrayRef<uint32_t> operands);
+
   /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
   /// `operands`.
   LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with
+  /// the given `operands`.
+  LogicalResult
+  processSpecConstantCompositeReplicateEXT(ArrayRef<uint32_t> operands);
+
   /// Processes a SPIR-V OpSpecConstantOp instruction with the given
   /// `operands`.
   LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
@@ -549,12 +572,28 @@ class Deserializer {
   /// (and type) here. Later when it's used, we materialize the constant.
   DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
 
+  /// Result <id> to replicated constant id and type mapping.
+  ///
+  /// In the SPIR-V binary format, OpConstantCompositeReplicateEXT is placed in
+  /// the module and shared by instructions at module level and in subsequent
+  /// functions. But in the SPIR-V dialect, this is materialized to where
+  /// it's used in the function. So when seeing a
+  /// OpConstantCompositeReplicateEXT in the binary format, we don't immediately
+  /// emit a `spirv.EXT.ConstantCompositeReplicate` op into the module, we keep
+  /// the id of its operand (the splat constant) and its type here. Later when
+  /// it's used, we materialize the `spirv.EXT.ConstantCompositeReplicate`.
+  DenseMap<uint32_t, std::pair<uint32_t, Type>> constantCompositeReplicateMap;
+
   // Result <id> to spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
 
   // Result <id> to composite spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
 
+  // Result <id> to replicated composite spec constant mapping.
+  DenseMap<uint32_t, spirv::EXTSpecConstantCompositeReplicateOp>
+      specConstCompositeReplicateMap;
+
   /// Result <id> to info needed to materialize an OpSpecConstantOp
   /// mapping.
   DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..89edf6c82b090 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -66,6 +66,15 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
   return failure();
 }
 
+LogicalResult Serializer::processConstantCompositeReplicateOp(
+    spirv::EXTConstantCompositeReplicateOp op) {
+  if (auto resultID = prepareConstantCompositeReplicate(op)) {
+    valueIDMap[op.getResult()] = resultID;
+    return success();
+  }
+  return failure();
+}
+
 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
   if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
                                             /*isSpec=*/true)) {
@@ -118,6 +127,41 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
   return processName(resultID, op.getSymName());
 }
 
+LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
+    spirv::EXTSpecConstantCompositeReplicateOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+    return failure();
+  }
+
+  // Spec constant <id>s are looked up by plain symbol name (specConstIDMap),
+  // so only a flat reference can be resolved; dyn_cast yields null otherwise.
+  auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
+  if (!constituent) {
+    return op.emitError("expected flat symbol reference for replicated spec "
+                        "constant constituent, but provided ")
+           << op.getConstituent();
+  }
+
+  auto constituentName = constituent.getValue();
+  auto constituentID = getSpecConstID(constituentName);
+  if (!constituentID) {
+    return op.emitError("unknown result <id> for replicated spec constant ")
+           << constituentName;
+  }
+
+  auto resultID = getNextID();
+  SmallVector<uint32_t> operands = {typeID, resultID, constituentID};
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
+                        operands);
+
+  specConstIDMap[op.getSymName()] = resultID;
+
+  return processName(resultID, op.getSymName());
+}
+
 LogicalResult
 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
   uint32_t typeID = 0;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index ebebd2d283afa..2e767cd822617 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1109,6 +1109,53 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
   return resultID;
 }
 
+uint32_t Serializer::prepareConstantCompositeReplicate(
+    spirv::EXTConstantCompositeReplicateOp op) {
+  // An op whose result already has an <id> (e.g. the operand of an enclosing
+  // replicated composite) is emitted only once; reuse the recorded <id>.
+  if (auto id = getValueID(op.getResult())) {
+    return id;
+  }
+
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+    return 0;
+  }
+
+  Operation *definingOp = op.getConstant().getDefiningOp();
+  if (!definingOp) {
+    emitError(op.getLoc(), "op defining splat value not found");
+    return 0;
+  }
+
+  uint32_t operandID = 0;
+  if (auto constantOp = dyn_cast<spirv::ConstantOp>(definingOp)) {
+    operandID = getConstantID(constantOp.getValue());
+  } else if (auto replicateOp =
+                 dyn_cast<spirv::EXTConstantCompositeReplicateOp>(definingOp)) {
+    operandID = prepareConstantCompositeReplicate(replicateOp);
+  } else {
+    emitError(op.getLoc(), "operand op type not supported");
+    return 0;
+  }
+
+  // A zero <id> means the splat value was not (or could not be) serialized;
+  // encoding it as an operand would silently produce invalid SPIR-V.
+  if (!operandID) {
+    emitError(op.getLoc(), "splat value <id> not found");
+    return 0;
+  }
+
+  uint32_t resultID = getNextID();
+  SmallVector<uint32_t> operands = {typeID, resultID, operandID};
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpConstantCompositeReplicateEXT,
+                        operands);
+
+  return resultID;
+}
+
 //===----------------------------------------------------------------------===//
 // Control flow
 //===----------------------------------------------------------------------===//
@@ -1328,6 +1375,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
         return processBranchConditionalOp(op);
       })
       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
+      .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
+        return processConstantCompositeReplicateOp(op);
+      })
       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
       .Case([&](spirv::GlobalVariableOp op) {
         return processGlobalVariableOp(op);
@@ -1339,6 +1389,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       .Case([&](spirv::SpecConstantCompositeOp op) {
         return processSpecConstantCompositeOp(op);
       })
+      .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
+        return processSpecConstantCompositeReplicateOp(op);
+      })
       .Case([&](spirv::SpecConstantOperationOp op) {
         return processSpecConstantOperationOp(op);
       })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..183af41172be7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -108,11 +108,17 @@ class Serializer {
 
   LogicalResult processConstantOp(spirv::ConstantOp op);
 
+  LogicalResult processConstantCompositeReplicateOp(
+      spirv::EXTConstantCompositeReplicateOp op);
+
   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
 
   LogicalResult
   processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
 
+  LogicalResult processSpecConstantCompositeReplicateOp(
+      spirv::EXTSpecConstantCompositeReplicateOp op);
+
   LogicalResult
   processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
 
@@ -230,6 +236,12 @@ class Serializer {
   uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
                              bool isSpec = false);
 
+  /// Prepares `spirv.EXT.ConstantCompositeReplicate` serialization. This
+  /// method emits OpConstantCompositeReplicateEXT and returns the result <id>
+  /// associated with it, or 0 on failure.
+  uint32_t
+  prepareConstantCompositeReplicate(spirv::EXTConstantCompositeReplicateOp op);
+
   //===--------------------------------------------------------------------===//
   // Control flow
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 8d4e53418b70f..0b9b28ae536ea 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module -split-input-file --test-spirv-roundtrip %s | FileCheck %s
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK-LABEL: @bool_const
@@ -306,3 +306,84 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
   }
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+  // CHECK-LABEL: @splat_vector_i32
+  spirv.func @splat_vector_i32() -> (vector<3xi32>) "None" {
+    // CHECK: spirv.Constant 1 : i32
+    %0 = spirv.Constant 1 : i32
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xi32>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xi32>
+    spirv.ReturnValue %1 : vector<3xi32>
+  }
+
+  // CHECK-LABEL: @splat_array_of_i32
+  spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+    // CHECK: spirv.Constant 1 : i32
+    %0 = spirv.Constant 1 : i32
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x i32>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x i32>
+    spirv.ReturnValue %1 : !spirv.array<3 x i32>
+  }
+
+  // CHECK-LABEL: @splat_array_of_vectors_of_i32
+  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+    // CHECK: spirv.Constant dense<[1, 2]> : vector<2xi32>
+    %0 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %1 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  // CHECK-LABEL: @splat_array_of_splat_vector_i32
+  spirv.func @splat_array_of_splat_vector_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+    // CHECK: spirv.Constant 2 : i32
+    %0 = spirv.Constant 2 : i32
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xi32>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
+    %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %2 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  // CHECK-LABEL: @splat_vector_f32
+  spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
+    // CHECK: spirv.Constant 1.000000e+00 : f32
+    %0 = spirv.Constant 1.0 : f32
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xf32>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xf32>
+    spirv.ReturnValue %1 : vector<3xf32>
+  }
+
+  // CHECK-LABEL: @splat_array_of_f32
+  spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+    // CHECK: spirv.Constant 1.000000e+00 : f32
+    %0 = spirv.Constant 1.0 : f32
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x f32>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x f32>
+    spirv.ReturnValue %1 : !spirv.array<3 x f32>
+  }
+
+  // CHECK-LABEL: @splat_array_of_vectors_of_f32
+  spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+    // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>
+    %0 = spirv.Constant dense<[1.0, 2.0]> : vector<2xf32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %1 : !spirv.array<2 x vector<2xf32>>
+  }
+
+  // CHECK-LABEL: @splat_array_of_splat_vector_f32
+  spirv.func @splat_array_of_splat_vector_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+    // CHECK: spirv.Constant 2.000000e+00 : f32
+    %0 = spirv.Constant 2.0 : f32
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xf32>
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
+    %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %2 : !spirv.array<2 x vector<2xf32>>
+  }
+}
diff --git a/mlir/test/Target/SPIRV/spec-constant.mlir b/mlir/test/Target/SPIRV/spec-constant.mlir
index 078d77125b3fd..f434956ab34a3 100644
--- a/mlir/test/Target/SPIRV/spec-constant.mlir
+++ b/mlir/test/Target/SPIRV/spec-constant.mlir
@@ -88,6 +88,33 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
 // -----
 
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+  spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+  spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+  spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+  spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3 x i32>
+
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+  spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+  spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+  spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3 x f32>
+}
+
+// -----
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
   spirv.SpecConstant @sc_i32_1 = 1 : i32

>From 63dd835a3e47ef4e8ce910cfe65c5b39d736a2c0 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Mon, 7 Jul 2025 08:50:06 +0100
Subject: [PATCH 2/8] Addressing code review comments

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     | 15 +++++-------
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  6 ++---
 .../SPIRV/Deserialization/Deserializer.cpp    | 24 +++++++++----------
 .../SPIRV/Serialization/SerializeOps.cpp      |  8 +++----
 mlir/test/Target/SPIRV/constant.mlir          |  2 +-
 5 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 0a5b01fe9e8d0..4f5c5e7e0dd48 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -144,10 +144,9 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
   }];
 
   let description = [{
-    This op declares a `spiv.EXT.ConstantCompositeReplicate` which represents a
-    splat composite constant i.e. all element of composite constant have the
-    same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
-    The splat value must come from a non-specialization constant instruction."
+    Represents a splat composite constant i.e., all element of composite constant 
+    have the same value. The splat value must come from a non-specialization constant
+    instruction.
 
     #### Example:
 
@@ -739,11 +738,9 @@ def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantC
   let summary = "Declare a new replicated composite specialization constant op.";
 
   let description = [{
-    This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
-    a splat specialization composite constant i.e. all element of specialization
-    composite constant have the same value. This op will be serialized to SPIR-V
-    `OpSpecConstantCompositeReplicateEXT`. The splat value must come from a
-    symbol reference of specialization constant instruction.
+    Represents a splat spec composite constant i.e., all element of spec composite
+    constant have the same value. The splat value must come from a symbol reference
+    of spec constant instruction.
 
     #### Example:
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c42b2d45d53a9..b5c30051192f5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -779,13 +779,13 @@ spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
     return failure();
   }
 
-  if (llvm::isa<TensorType>(compositeType)) {
+  if (isa<TensorType>(compositeType)) {
     if (parser.parseColonType(compositeType))
       return failure();
   }
 
-  auto constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
-  while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType)) {
+  Type constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+  while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
     constType = type.getElementType();
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 2163ccff93c83..40cc8f90cfee5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -680,10 +680,10 @@ spirv::Deserializer::getConstant(uint32_t id) {
 
 std::optional<std::pair<uint32_t, Type>>
 spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
-  auto constIt = constantCompositeReplicateMap.find(id);
-  if (constIt == constantCompositeReplicateMap.end())
-    return std::nullopt;
-  return constIt->getSecond();
+  if (auto it = constantCompositeReplicateMap.find(id);
+      it != constantCompositeReplicateMap.end())
+    return it->second;
+  return std::nullopt;
 }
 
 std::optional<spirv::SpecConstOperationMaterializationInfo>
@@ -1564,7 +1564,6 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
 
 LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
     ArrayRef<uint32_t> operands) {
-
   if (operands.size() != 3) {
     return emitError(
         unknownLoc,
@@ -1585,11 +1584,12 @@ LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
            << operands[0];
   }
 
-  auto resultID = operands[1];
-  auto constantID = operands[2];
+  uint32_t resultID = operands[1];
+  uint32_t constantID = operands[2];
 
-  auto constantInfo = getConstant(constantID);
-  auto replicatedConstantCompositeInfo =
+  std::optional<std::pair<Attribute, Type>> constantInfo =
+      getConstant(constantID);
+  std::optional<std::pair<uint32_t, Type>> replicatedConstantCompositeInfo =
       getConstantCompositeReplicate(constantID);
   if (!constantInfo && !replicatedConstantCompositeInfo) {
     return emitError(unknownLoc,
@@ -1642,7 +1642,6 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
 
 LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
     ArrayRef<uint32_t> operands) {
-
   if (operands.size() != 3) {
     return emitError(unknownLoc,
                      "OpSpecConstantCompositeReplicateEXT must have "
@@ -1663,10 +1662,11 @@ LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
            << operands[0];
   }
 
-  auto resultID = operands[1];
+  uint32_t resultID = operands[1];
 
   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
-  auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+  spirv::SpecConstantOp constituentSpecConstantOp =
+      getSpecConstant(operands[2]);
   auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
       unknownLoc, TypeAttr::get(resultType), symName,
       SymbolRefAttr::get(constituentSpecConstantOp));
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 89edf6c82b090..02dc25ac92e8c 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -68,7 +68,7 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
 
 LogicalResult Serializer::processConstantCompositeReplicateOp(
     spirv::EXTConstantCompositeReplicateOp op) {
-  if (auto resultID = prepareConstantCompositeReplicate(op)) {
+  if (uint32_t resultID = prepareConstantCompositeReplicate(op)) {
     valueIDMap[op.getResult()] = resultID;
     return success();
   }
@@ -135,14 +135,14 @@ LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
   }
 
   auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
-  auto constituentName = constituent.getValue();
-  auto constituentID = getSpecConstID(constituentName);
+  StringRef constituentName = constituent.getValue();
+  uint32_t constituentID = getSpecConstID(constituentName);
   if (!constituentID) {
     return op.emitError("unknown result <id> for replicated spec constant ")
            << constituentName;
   }
 
-  auto resultID = getNextID();
+  uint32_t resultID = getNextID();
   SmallVector<uint32_t> operands = {typeID, resultID, constituentID};
 
   encodeInstructionInto(typesGlobalValues,
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 0b9b28ae536ea..f521deebe0bb8 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate --no-implicit-module -split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK-LABEL: @bool_const

>From 8c0b6cb6917cb4af5eb483094c15296a8610e029 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Mon, 7 Jul 2025 16:47:44 +0100
Subject: [PATCH 3/8] Addressing further code review comments

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     |  1 -
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 32 +++++----
 .../SPIRV/Deserialization/DeserializeOps.cpp  | 12 ++--
 .../Target/SPIRV/Serialization/Serializer.cpp |  4 +-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  | 14 ++++
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 65 +++++++++++++++++++
 6 files changed, 108 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 4f5c5e7e0dd48..e6801159931e9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -767,7 +767,6 @@ def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantC
   let results = (outs);
 
   let autogenSerialization = 0;
-
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index b5c30051192f5..c34003cf0ad7d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -773,18 +773,22 @@ ParseResult
 spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
   OpAsmParser::UnresolvedOperand constOperand;
-  Type compositeType;
-  if (parser.parseOperand(constOperand) ||
-      parser.parseColonType(compositeType)) {
+  Type resultType;
+  if (parser.parseOperand(constOperand) || parser.parseColonType(resultType)) {
     return failure();
   }
 
-  if (isa<TensorType>(compositeType)) {
-    if (parser.parseColonType(compositeType))
+  if (isa<TensorType>(resultType)) {
+    if (parser.parseColonType(resultType))
       return failure();
   }
 
-  Type constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(resultType);
+  if (!compositeType)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "result is not a composite type");
+
+  Type constType = compositeType.getElementType(0);
   while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
     constType = type.getElementType();
   }
@@ -805,7 +809,7 @@ LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
     return emitError("result type must be a composite type, but provided ")
            << getType();
 
-  auto constantDefiningOp = getConstant().getDefiningOp();
+  Operation *constantDefiningOp = getConstant().getDefiningOp();
   if (!constantDefiningOp)
     return this->emitOpError("op defining the splat constant not found");
 
@@ -1972,12 +1976,16 @@ LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
     return emitError("result type must be a composite type, but provided ")
            << getType();
 
-  auto constituentSpecConstOp =
-      dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
-          (*this)->getParentOp(), this->getConstituent()));
+  Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
+      (*this)->getParentOp(), this->getConstituent());
+  if (!constituentOp)
+    return emitError(
+        "splat spec constant reference defining constituent not found");
+
+  auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
 
-  auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
-  auto compositeElemType = compositeType.getElementType(0);
+  Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
+  Type compositeElemType = compositeType.getElementType(0);
   if (constituentType != compositeElemType)
     return emitError("constituent has incorrect type: expected ")
            << compositeElemType << ", but provided " << constituentType;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5f52308b4be35..6c4d7b8cf3f37 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,11 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                                constInfo->first);
   }
-  if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
-    auto constantId = constCompositeReplicateInfo->first;
-    auto element = getValue(constantId);
+  if (std::optional<std::pair<uint32_t, Type>> constCompositeReplicateInfo =
+          getConstantCompositeReplicate(id)) {
+    uint32_t constantId = constCompositeReplicateInfo->first;
+    Value constantValue = getValue(constantId);
     return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
-        unknownLoc, constCompositeReplicateInfo->second, element);
+        unknownLoc, constCompositeReplicateInfo->second, constantValue);
   }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
@@ -62,7 +63,8 @@ Value spirv::Deserializer::getValue(uint32_t id) {
         SymbolRefAttr::get(constOp.getOperation()));
     return referenceOfOp.getReference();
   }
-  if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+  if (SpecConstantCompositeOp specConstCompositeOp =
+          getSpecConstantComposite(id)) {
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
         unknownLoc, specConstCompositeOp.getType(),
         SymbolRefAttr::get(specConstCompositeOp.getOperation()));
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 2e767cd822617..e9dd6da6530f7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1111,7 +1111,7 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
 
 uint32_t Serializer::prepareConstantCompositeReplicate(
     spirv::EXTConstantCompositeReplicateOp op) {
-  if (auto id = getValueID(op.getResult())) {
+  if (uint32_t id = getValueID(op.getResult())) {
     return id;
   }
 
@@ -1120,7 +1120,7 @@ uint32_t Serializer::prepareConstantCompositeReplicate(
     return 0;
   }
 
-  auto definingOp = op.getConstant().getDefiningOp();
+  Operation *definingOp = op.getConstant().getDefiningOp();
   if (!definingOp) {
     emitError(op.getLoc(), "op defining splat value not found");
     return 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..329f5a8eed9ae 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -278,3 +278,17 @@ func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () {
   spirv.EXT.SetMeshOutputs %0, %1 : i32, i32
   spirv.Return
 }
+
+//===----------------------------------------------------------------------===//
+// Replicated Composite Constant op
+//===----------------------------------------------------------------------===//
+// CHECK-LABEL: constant_composite_replicate
+func.func @constant_composite_replicate() -> () {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_EXT_replicated_composites] ]
+  // CHECK: capabilities: [ [ReplicatedCompositesEXT] ]
+  %0 = spirv.Constant 1 : i32
+  %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+  spirv.Return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 207549afdda94..39f7df10be634 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -165,6 +165,45 @@ func.func @coop_matrix_const_wrong_type() -> () {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.EXT.ConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+  spirv.func @ccr() -> i32 "None" {
+    %0 = spirv.Constant 1 : i32
+    // expected-error @+2 {{result is not a composite type}}
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : i32
+    spirv.ReturnValue %1: i32
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @ccr() -> vector<2xf32> "None" {
+    %0 = spirv.Constant 1 : i32
+    // expected-note at -1 {{prior use here}}
+    // expected-error @+1 {{use of value '%0' expects different type than prior uses: 'f32' vs 'i32'}}
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+    spirv.ReturnValue %1: vector<2xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @ccr() -> vector<2xi32> "None" {
+    %c1 = spirv.Constant 1 : i32
+    %0 = spirv.IAdd %c1, %c1 : i32
+    // expected-error @+1 {{op defining the splat constant is not a spirv.Constant or a spirv.EXT.ConstantCompositeReplicate}}
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+    spirv.ReturnValue %1: vector<2xi32>
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.EntryPoint
 //===----------------------------------------------------------------------===//
@@ -854,6 +893,32 @@ spirv.module Logical GLSL450 {
   spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device, MatrixA>
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXT.SpecConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+// -----
+
+spirv.module Logical GLSL450 {
+  // expected-error @+1 {{result type must be a composite type, but provided 'i32'}}
+  spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_i32_1) : i32
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  // expected-error @+1 {{splat spec constant reference defining constituent not found}}
+  spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.array<3 x i32>
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+  // expected-error @+1 {{constituent has incorrect type: expected 'i32', but provided 'f32'}}
+  spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.array<3 x i32>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SpecConstantOperation
 //===----------------------------------------------------------------------===//

>From 9a6b19125f95d140f8fd93fd9d5e49da1c0bdfba Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Wed, 9 Jul 2025 15:45:56 +0100
Subject: [PATCH 4/8] Addressing further code review comments

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     | 18 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 74 +++++++--------
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  7 +-
 .../SPIRV/Deserialization/Deserializer.cpp    | 27 +++---
 .../SPIRV/Deserialization/Deserializer.h      | 15 ++--
 .../SPIRV/Serialization/SerializeOps.cpp      |  3 +-
 .../Target/SPIRV/Serialization/Serializer.cpp | 35 ++++----
 .../Target/SPIRV/Serialization/Serializer.h   | 12 ++-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  3 +-
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 35 +++-----
 mlir/test/Target/SPIRV/constant.mlir          | 90 +++++++++----------
 11 files changed, 145 insertions(+), 174 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index e6801159931e9..bfe2eac3b89fc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -144,21 +144,17 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
   }];
 
   let description = [{
-    Represents a splat composite constant i.e., all element of composite constant 
-    have the same value. The splat value must come from a non-specialization constant
-    instruction.
+    Represents a splat composite constant i.e., all elements of composite constant
+    have the same value.
 
     #### Example:
 
     ```mlir
-    %0 = spirv.Constant 1 : i32
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
 
-    %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
-    %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<2 x vector<2xi32>>
 
-    %5 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
-    %6 = spirv.EXT.ConstantCompositeReplicate %5 : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     ```
   }];
 
@@ -170,7 +166,7 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
   ];
 
   let arguments = (ins
-    SPIRV_Type:$constant
+    AnyAttr:$value
   );
 
   let results = (outs
@@ -738,7 +734,7 @@ def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantC
   let summary = "Declare a new replicated composite specialization constant op.";
 
   let description = [{
-    Represents a splat spec composite constant i.e., all element of spec composite
+    Represents a splat spec composite constant i.e., all elements of spec composite
     constant have the same value. The splat value must come from a symbol reference
     of spec constant instruction.
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c34003cf0ad7d..b6f1e1411ddfb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -772,60 +772,46 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
 ParseResult
 spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  OpAsmParser::UnresolvedOperand constOperand;
+
+  Attribute value;
+  StringRef valueAttrName =
+      spirv::EXTConstantCompositeReplicateOp::getValueAttrName(result.name);
   Type resultType;
-  if (parser.parseOperand(constOperand) || parser.parseColonType(resultType)) {
+
+  if (parser.parseLSquare() ||
+      parser.parseAttribute(value, valueAttrName, result.attributes) ||
+      parser.parseRSquare() || parser.parseColonType(resultType))
     return failure();
-  }
 
-  if (isa<TensorType>(resultType)) {
+  if (isa<NoneType, TensorType>(resultType))
     if (parser.parseColonType(resultType))
       return failure();
-  }
 
-  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(resultType);
-  if (!compositeType)
-    return parser.emitError(parser.getCurrentLocation(),
-                            "result is not a composite type");
-
-  Type constType = compositeType.getElementType(0);
-  while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
-    constType = type.getElementType();
-  }
-
-  if (parser.resolveOperand(constOperand, constType, result.operands))
-    return failure();
+  if (isa<TensorArmType>(resultType))
+    if (parser.parseOptionalColon().succeeded())
+      if (parser.parseType(resultType))
+        return failure();
 
-  return parser.addTypeToList(compositeType, result.types);
+  return parser.addTypeToList(resultType, result.types);
 }
 
 void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
-  printer << ' ' << getConstant() << " : " << getType();
+  printer << " [" << getValue() << "] : " << getType();
 }
 
 LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
-  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
-  if (!compositeType)
-    return emitError("result type must be a composite type, but provided ")
-           << getType();
-
-  Operation *constantDefiningOp = getConstant().getDefiningOp();
-  if (!constantDefiningOp)
-    return this->emitOpError("op defining the splat constant not found");
-
-  auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(constantDefiningOp);
-  auto constantCompositeReplicateOp =
-      dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
-          constantDefiningOp);
-
-  if (!constantOp && !constantCompositeReplicateOp)
-    return this->emitOpError(
-        "op defining the splat constant is not a spirv.Constant or a "
-        "spirv.EXT.ConstantCompositeReplicate");
+  Type valueType = dyn_cast<TypedAttr>(getValue()).getType();
+  Type compositeElementType =
+      dyn_cast<spirv::CompositeType>(getType()).getElementType(0);
+  while (compositeElementType != valueType &&
+         isa<spirv::CompositeType>(compositeElementType)) {
+    compositeElementType =
+        cast<spirv::CompositeType>(compositeElementType).getElementType(0);
+  }
 
-  if (constantOp)
-    return verifyConstantType(constantOp, constantOp.getValueAttr(),
-                              constantOp.getType());
+  if (compositeElementType != valueType)
+    return emitError("expected splat element type")
+           << compositeElementType << ", but got: " << valueType;
 
   return success();
 }
@@ -1940,8 +1926,8 @@ spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
                                                   OperationState &result) {
 
   StringAttr compositeName;
-  const char *attrName = "spec_const";
   FlatSymbolRefAttr specConstRef;
+  const char *attrName = "spec_const";
   NamedAttrList attrs;
   Type type;
 
@@ -1985,10 +1971,10 @@ LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
   auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
 
   Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
-  Type compositeElemType = compositeType.getElementType(0);
-  if (constituentType != compositeElemType)
+  Type compositeElementType = compositeType.getElementType(0);
+  if (constituentType != compositeElementType)
     return emitError("constituent has incorrect type: expected ")
-           << compositeElemType << ", but provided " << constituentType;
+           << compositeElementType << ", but provided " << constituentType;
 
   return success();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 6c4d7b8cf3f37..9fa03725d05ee 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,12 +45,11 @@ Value spirv::Deserializer::getValue(uint32_t id) {
     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                                constInfo->first);
   }
-  if (std::optional<std::pair<uint32_t, Type>> constCompositeReplicateInfo =
+  if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
           getConstantCompositeReplicate(id)) {
-    uint32_t constantId = constCompositeReplicateInfo->first;
-    Value constantValue = getValue(constantId);
     return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
-        unknownLoc, constCompositeReplicateInfo->second, constantValue);
+        unknownLoc, constCompositeReplicateInfo->second,
+        constCompositeReplicateInfo->first);
   }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 40cc8f90cfee5..e4dab971dcc14 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,7 +678,7 @@ spirv::Deserializer::getConstant(uint32_t id) {
   return constIt->getSecond();
 }
 
-std::optional<std::pair<uint32_t, Type>>
+std::optional<std::pair<Attribute, Type>>
 spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
   if (auto it = constantCompositeReplicateMap.find(id);
       it != constantCompositeReplicateMap.end())
@@ -1589,19 +1589,24 @@ LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
 
   std::optional<std::pair<Attribute, Type>> constantInfo =
       getConstant(constantID);
-  std::optional<std::pair<uint32_t, Type>> replicatedConstantCompositeInfo =
-      getConstantCompositeReplicate(constantID);
-  if (!constantInfo && !replicatedConstantCompositeInfo) {
-    return emitError(unknownLoc,
-                     "OpConstantCompositeReplicateEXT operand <id> ")
-           << constantID
-           << " must come from a normal constant or a "
-              "OpConstantCompositeReplicateEXT";
+  if (constantInfo.has_value()) {
+    constantCompositeReplicateMap.try_emplace(
+        resultID, constantInfo.value().first, resultType);
+    return success();
   }
 
-  constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+  std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
+      getConstantCompositeReplicate(constantID);
+  if (replicatedConstantCompositeInfo.has_value()) {
+    constantCompositeReplicateMap.try_emplace(
+        resultID, replicatedConstantCompositeInfo.value().first, resultType);
+    return success();
+  }
 
-  return success();
+  return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
+         << constantID
+         << " must come from a normal constant or a "
+            "OpConstantCompositeReplicateEXT";
 }
 
 LogicalResult
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1fdecc3e6fe0d..20482bd2bf501 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,10 +190,9 @@ class Deserializer {
   /// Gets the constant's attribute and type associated with the given <id>.
   std::optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
 
-  /// Gets the pair of id of `spirv.Constant` generating a
-  /// replicated composite and the type of resulting
-  /// `spirv.EXT.ConstantCompositeReplicate` given its <id>.
-  std::optional<std::pair<uint32_t, Type>>
+  /// Gets the replicated composite constant's attribute and type associated
+  /// with the given <id>.
+  std::optional<std::pair<Attribute, Type>>
   getConstantCompositeReplicate(uint32_t id);
 
   /// Gets the info needed to materialize the spec constant operation op
@@ -572,7 +571,7 @@ class Deserializer {
   /// (and type) here. Later when it's used, we materialize the constant.
   DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
 
-  // Result <id> to replicated constant id and type mapping.
+  // Result <id> to replicated constant attribute and type mapping.
   ///
   /// In the SPIR-V binary format, OpConstantCompositeReplicateEXT is placed in
   /// the module and shared by instructions at module level and in subsequent
@@ -580,9 +579,9 @@ class Deserializer {
   /// it's used in the function. So when seeing a
   /// OpConstantCompositeReplicateEXT in the binary format, we don't immediately
   /// emit a `spirv.EXT.ConstantCompositeReplicate` op into the module, we keep
-  /// the id of its operand (the splat constant) and type) here. Later when it's
-  /// used, we materialize the `spirv.EXT.ConstantCompositeReplicate`.
-  DenseMap<uint32_t, std::pair<uint32_t, Type>> constantCompositeReplicateMap;
+  /// the id of its value and type here. Later when it's used, we materialize
+  /// the `spirv.EXT.ConstantCompositeReplicate`.
+  DenseMap<uint32_t, std::pair<Attribute, Type>> constantCompositeReplicateMap;
 
   // Result <id> to spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 02dc25ac92e8c..b722933ecdade 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -68,7 +68,8 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
 
 LogicalResult Serializer::processConstantCompositeReplicateOp(
     spirv::EXTConstantCompositeReplicateOp op) {
-  if (uint32_t resultID = prepareConstantCompositeReplicate(op)) {
+  if (uint32_t resultID = prepareConstantCompositeReplicate(
+          op.getLoc(), op.getType(), op.getValue())) {
     valueIDMap[op.getResult()] = resultID;
     return success();
   }
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index e9dd6da6530f7..e5ea01f6ce2a9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1109,43 +1109,38 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
   return resultID;
 }
 
-uint32_t Serializer::prepareConstantCompositeReplicate(
-    spirv::EXTConstantCompositeReplicateOp op) {
-  if (uint32_t id = getValueID(op.getResult())) {
+uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
+                                                       Type resultType,
+                                                       Attribute valueAttr) {
+
+  std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
+  if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
     return id;
   }
 
   uint32_t typeID = 0;
-  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+  if (failed(processType(loc, resultType, typeID))) {
     return 0;
   }
 
-  Operation *definingOp = op.getConstant().getDefiningOp();
-  if (!definingOp) {
-    emitError(op.getLoc(), "op defining splat value not found");
-    return 0;
-  }
+  auto elementType = dyn_cast<CompositeType>(resultType).getElementType(0);
+  Type valueType = dyn_cast<TypedAttr>(valueAttr).getType();
 
-  uint32_t operandID;
-  if (auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(definingOp)) {
-    operandID = getConstantID(constantOp.getValue());
-
-  } else if (auto constantCompositeReplicateOp =
-                 dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
-                     definingOp)) {
-    operandID = prepareConstantCompositeReplicate(constantCompositeReplicateOp);
+  uint32_t constandID;
+  if (elementType == valueType) {
+    constandID = prepareConstant(loc, elementType, valueAttr);
   } else {
-    emitError(op.getLoc(), "operand op type not supported");
-    return 0;
+    constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
   }
 
   uint32_t resultID = getNextID();
-  SmallVector<uint32_t> operands = {typeID, resultID, operandID};
+  SmallVector<uint32_t> operands = {typeID, resultID, constandID};
 
   encodeInstructionInto(typesGlobalValues,
                         spirv::Opcode::OpConstantCompositeReplicateEXT,
                         operands);
 
+  constCompositeReplicateIDMap[valueTypePair] = resultID;
   return resultID;
 }
 
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 183af41172be7..7047869bca4cd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -197,6 +197,11 @@ class Serializer {
     return constIDMap.lookup(value);
   }
 
+  uint32_t getConstantCompositeReplicateID(
+      std::pair<Attribute, Type> valueTypePair) const {
+    return constCompositeReplicateIDMap.lookup(valueTypePair);
+  }
+
   /// Main dispatch method for processing a constant with the given `constType`
   /// and `valueAttr`. `constType` is needed here because we can interpret the
   /// `valueAttr` as a different type than the type of `valueAttr` itself; for
@@ -239,8 +244,8 @@ class Serializer {
   /// Prepares `spirv.EXTConstantCompositeReplicateOp` serialization. This
   /// method emits OpConstantCompositeReplicateEXT and returns the result <id>
   /// associated with it.
-  uint32_t
-  prepareConstantCompositeReplicate(spirv::EXTConstantCompositeReplicateOp op);
+  uint32_t prepareConstantCompositeReplicate(Location loc, Type resultType,
+                                             Attribute valueAttr);
 
   //===--------------------------------------------------------------------===//
   // Control flow
@@ -401,6 +406,9 @@ class Serializer {
   /// Map from constant values to their <id>s.
   DenseMap<Attribute, uint32_t> constIDMap;
 
+  /// Map from a replicated composite constant's value and type to their <id>s.
+  DenseMap<std::pair<Attribute, Type>, uint32_t> constCompositeReplicateIDMap;
+
   /// Map from specialization constant names to their <id>s.
   llvm::StringMap<uint32_t> specConstIDMap;
 
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 329f5a8eed9ae..1640862d07d5c 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -288,7 +288,6 @@ func.func @constant_composite_replicate() -> () {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ [SPV_EXT_replicated_composites] ]
   // CHECK: capabilities: [ [ReplicatedCompositesEXT] ]
-  %0 = spirv.Constant 1 : i32
-  %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+  %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
   spirv.Return
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 39f7df10be634..5da26111e9917 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -169,37 +169,26 @@ func.func @coop_matrix_const_wrong_type() -> () {
 // spirv.EXT.ConstantCompositeReplicate
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical GLSL450 {
-  spirv.func @ccr() -> i32 "None" {
-    %0 = spirv.Constant 1 : i32
-    // expected-error @+2 {{result is not a composite type}}
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : i32
-    spirv.ReturnValue %1: i32
-  }
+func.func @ccr_result_not_composite() -> () {
+  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type, but got 'i32'}}
+  %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
+  return
 }
 
 // -----
 
-spirv.module Logical GLSL450 {
-  spirv.func @ccr() -> vector<2xf32> "None" {
-    %0 = spirv.Constant 1 : i32
-    // expected-note@-1 {{prior use here}}
-    // expected-error @+1 {{use of value '%0' expects different type than prior uses: 'f32' vs 'i32'}}
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
-    spirv.ReturnValue %1: vector<2xf32>
-  }
+func.func @ccr_wrong_splat_type() -> () {
+  // expected-error @+1 {{expected splat element type'f32', but got: 'i32'}}
+  %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xf32>
+  return
 }
 
 // -----
 
-spirv.module Logical GLSL450 {
-  spirv.func @ccr() -> vector<2xi32> "None" {
-    %c1 = spirv.Constant 1 : i32
-    %0 = spirv.IAdd %c1, %c1 : i32
-    // expected-error @+1 {{op defining the splat constant is not a spirv.Constant or a spirv.EXT.ConstantCompositeReplicate}}
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
-    spirv.ReturnValue %1: vector<2xi32>
-  }
+func.func @ccr_wrong_splat_type() -> () {
+  // expected-error @+1 {{expected splat element type'i32', but got: 'vector<2xi32>'}}
+  %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x !spirv.array<3 x i32>>
+  return
 }
 
 // -----
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f521deebe0bb8..809065fdc651d 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -313,77 +313,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
 
   // CHECK-LABEL: @splat_vector_i32
   spirv.func @splat_vector_i32() -> (vector<3xi32>) "None" {
-    // CHECK: spirv.Constant 1 : i32
-    %0 = spirv.Constant 1 : i32
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xi32>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xi32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
     spirv.ReturnValue %1 : vector<3xi32>
   }
 
   // CHECK-LABEL: @splat_array_of_i32
   spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
-    // CHECK: spirv.Constant 1 : i32
-    %0 = spirv.Constant 1 : i32
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x i32>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x i32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
     spirv.ReturnValue %1 : !spirv.array<3 x i32>
   }
 
   // CHECK-LABEL: @splat_array_of_vectors_of_i32
-  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
-    // CHECK: spirv.Constant dense<[1, 2]> : vector<2xi32>
-    %0 = spirv.Constant dense<[1, 2]> : vector<2xi32>
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xi32>>
-    spirv.ReturnValue %1 : !spirv.array<2 x vector<2xi32>>
+  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<3 x vector<2xi32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<3 x vector<2xi32>>
   }
 
-  // CHECK-LABEL: @splat_array_of_splat_vector_i32
-  spirv.func @splat_array_of_splat_vector_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
-    // CHECK: spirv.Constant 2 : i32
-    %0 = spirv.Constant 2 : i32
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xi32>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
-    %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xi32>>
-    spirv.ReturnValue %2 : !spirv.array<2 x vector<2xi32>>
+  // CHECK-LABEL: @splat_array_of_splat_vectors_of_i32
+  spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  // CHECK-LABEL: @splat_tensor_of_i32
+  spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
 
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
-    // CHECK: spirv.Constant 1.000000e+00 : f32
-    %0 = spirv.Constant 1.0 : f32
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xf32>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xf32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1.0 : f32] : vector<3xf32>
     spirv.ReturnValue %1 : vector<3xf32>
   }
 
   // CHECK-LABEL: @splat_array_of_f32
   spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
-    // CHECK: spirv.Constant 1.000000e+00 : f32
-    %0 = spirv.Constant 1.0 : f32
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x f32>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x f32>
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1.0 : f32] : !spirv.array<3 x f32>
     spirv.ReturnValue %1 : !spirv.array<3 x f32>
   }
 
   // CHECK-LABEL: @splat_array_of_vectors_of_f32
-  spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
-    // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>
-    %0 = spirv.Constant dense<[1.0, 2.0]> : vector<2xf32>
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xf32>>
-    spirv.ReturnValue %1 : !spirv.array<2 x vector<2xf32>>
-  }
-
-  // CHECK-LABEL: @splat_array_of_splat_vector_f32
-  spirv.func @splat_array_of_splat_vector_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
-    // CHECK: spirv.Constant 2.000000e+00 : f32
-    %0 = spirv.Constant 2.0 : f32
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xf32>
-    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
-    %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xf32>>
-    spirv.ReturnValue %2 : !spirv.array<2 x vector<2xf32>>
+  spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<3 x vector<2xf32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<3 x vector<2xf32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<3 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<3 x vector<2xf32>>
+  }
+
+  // CHECK-LABEL: @splat_array_of_splat_vectors_of_f32
+  spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+
+  // CHECK-LABEL: @splat_tensor_of_f32
+  spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [3.0 : f32] : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
 }

>From 1f07b83335381df9994e06d011cb287b6194a018 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Wed, 9 Jul 2025 20:53:28 +0100
Subject: [PATCH 5/8] Addressed further code review comments

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     |  6 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 10 ++-----
 .../SPIRV/Deserialization/Deserializer.cpp    | 13 ++++-----
 .../SPIRV/Serialization/SerializeOps.cpp      |  2 +-
 .../Target/SPIRV/Serialization/Serializer.cpp |  3 +-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  2 +-
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 28 +++++++++++++++++--
 mlir/test/Target/SPIRV/constant.mlir          | 14 ++++++++++
 8 files changed, 52 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index bfe2eac3b89fc..db3a609e993c8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -151,10 +151,8 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
 
     ```mlir
     %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
-
-    %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<2 x vector<2xi32>>
-
-    %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<2 x vector<2xi32>>
+    %2 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index b6f1e1411ddfb..3b900f8dda81f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -787,11 +787,6 @@ spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
     if (parser.parseColonType(resultType))
       return failure();
 
-  if (isa<TensorArmType>(resultType))
-    if (parser.parseOptionalColon().succeeded())
-      if (parser.parseType(resultType))
-        return failure();
-
   return parser.addTypeToList(resultType, result.types);
 }
 
@@ -806,11 +801,11 @@ LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
   while (compositeElementType != valueType &&
          isa<spirv::CompositeType>(compositeElementType)) {
     compositeElementType =
-        cast<spirv::CompositeType>(compositeElementType).getElementType(0);
+        dyn_cast<spirv::CompositeType>(compositeElementType).getElementType(0);
   }
 
   if (compositeElementType != valueType)
-    return emitError("expected splat element type")
+    return emitError("expected value attribute type ")
            << compositeElementType << ", but got: " << valueType;
 
   return success();
@@ -1924,7 +1919,6 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
 ParseResult
 spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
                                                   OperationState &result) {
-
   StringAttr compositeName;
   FlatSymbolRefAttr specConstRef;
   const char *attrName = "spec_const";
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e4dab971dcc14..d133d0332e271 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1566,9 +1566,9 @@ LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
     ArrayRef<uint32_t> operands) {
   if (operands.size() != 3) {
     return emitError(
-        unknownLoc,
-        "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
-        "and only one parameter which is <id> of splat constant");
+               unknownLoc,
+               "OpConstantCompositeReplicateEXT expects 3 operands but found ")
+           << operands.size();
   }
 
   Type resultType = getType(operands[0]);
@@ -1648,10 +1648,9 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
 LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
     ArrayRef<uint32_t> operands) {
   if (operands.size() != 3) {
-    return emitError(unknownLoc,
-                     "OpSpecConstantCompositeReplicateEXT must have "
-                     "type <id> and result <id> and only one parameter which "
-                     "is <id> of splat constant");
+    return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
+                                 "3 operands but found ")
+           << operands.size();
   }
 
   Type resultType = getType(operands[0]);
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index b722933ecdade..ee29f787a73d0 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -144,7 +144,7 @@ LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
   }
 
   uint32_t resultID = getNextID();
-  SmallVector<uint32_t> operands = {typeID, resultID, constituentID};
+  uint32_t operands[] = {typeID, resultID, constituentID};
 
   encodeInstructionInto(typesGlobalValues,
                         spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index e5ea01f6ce2a9..d7fb2408fd724 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1112,7 +1112,6 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
 uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
                                                        Type resultType,
                                                        Attribute valueAttr) {
-
   std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
   if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
     return id;
@@ -1134,7 +1133,7 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
   }
 
   uint32_t resultID = getNextID();
-  SmallVector<uint32_t> operands = {typeID, resultID, constandID};
+  uint32_t operands[] = {typeID, resultID, constandID};
 
   encodeInstructionInto(typesGlobalValues,
                         spirv::Opcode::OpConstantCompositeReplicateEXT,
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 1640862d07d5c..7c99b6b11b625 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -290,4 +290,4 @@ func.func @constant_composite_replicate() -> () {
   // CHECK: capabilities: [ [ReplicatedCompositesEXT] ]
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
   spirv.Return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5da26111e9917..5f7e94c5f9439 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
 //===----------------------------------------------------------------------===//
 
 func.func @ccr_result_not_composite() -> () {
-  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type, but got 'i32'}}
+  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
   return
 }
@@ -178,7 +178,7 @@ func.func @ccr_result_not_composite() -> () {
 // -----
 
 func.func @ccr_wrong_splat_type() -> () {
-  // expected-error @+1 {{expected splat element type'f32', but got: 'i32'}}
+  // expected-error @+1 {{expected value attribute type 'f32', but got: 'i32'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xf32>
   return
 }
@@ -186,11 +186,17 @@ func.func @ccr_wrong_splat_type() -> () {
 // -----
 
 func.func @ccr_wrong_splat_type() -> () {
-  // expected-error @+1 {{expected splat element type'i32', but got: 'vector<2xi32>'}}
+  // expected-error @+1 {{expected value attribute type 'i32', but got: 'vector<2xi32>'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x !spirv.array<3 x i32>>
   return
 }
 
+func.func @ccr_wrong_splat_type() -> () {
+  // expected-error @+1 {{expected value attribute type 'f32', but got: 'i32'}}
+  %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.arm.tensor<2x3xf32>
+  return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -908,6 +914,22 @@ spirv.module Logical GLSL450 {
   spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.array<3 x i32>
 }
 
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+  // expected-error @+1 {{constituent has incorrect type: expected 'i32', but provided 'f32'}}
+  spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+  // expected-error @+1 {{constituent has incorrect type: expected 'i32', but provided 'f32'}}
+  spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.arm.tensor<2x3xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 809065fdc651d..6f66d61c01df5 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -346,6 +346,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
 
+  // CHECK-LABEL: @splat_arm_tensor_of_i32
+  spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -380,4 +387,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.EXT.ConstantCompositeReplicate [3.0 : f32] : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
+
+  // CHECK-LABEL: @splat_arm_tensor_of_f32
+  spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
 }

>From a2829b18bfcb78b68d8674cb2d189c92a83e651a Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 10 Jul 2025 14:17:00 +0100
Subject: [PATCH 6/8] Address review comments: check that the SpecConstantCompositeReplicate constituent is a flat symbol reference during serialization

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp | 5 +++++
 mlir/test/Dialect/SPIRV/IR/availability.mlir         | 1 +
 2 files changed, 6 insertions(+)

diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ee29f787a73d0..d62529b85b3aa 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -136,6 +136,11 @@ LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
   }
 
   auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
+  if (!constituent)
+    return op.emitError(
+               "expected flat symbol reference for constituent instead of ")
+           << op.getConstituent();
+
   StringRef constituentName = constituent.getValue();
   uint32_t constituentID = getSpecConstID(constituentName);
   if (!constituentID) {
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 7c99b6b11b625..0e7a1afaa77ae 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -282,6 +282,7 @@ func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () {
 //===----------------------------------------------------------------------===//
 // Replicated Composite Constant op
 //===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: constant_composite_replicate
 func.func @constant_composite_replicate() -> () {
   // CHECK: min version: v1.0

>From c352896ac2a6a7f7a81129b81125876b64e3478e Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 10 Jul 2025 17:58:20 +0100
Subject: [PATCH 7/8] Address review comments: list every accepted value type in the ConstantCompositeReplicate verifier error

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 22 ++++++++++++-------
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 13 ++++++++++-
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3b900f8dda81f..d9154471d957f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -798,15 +798,21 @@ LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
   Type valueType = dyn_cast<TypedAttr>(getValue()).getType();
   Type compositeElementType =
       dyn_cast<spirv::CompositeType>(getType()).getElementType(0);
-  while (compositeElementType != valueType &&
-         isa<spirv::CompositeType>(compositeElementType)) {
-    compositeElementType =
-        dyn_cast<spirv::CompositeType>(compositeElementType).getElementType(0);
-  }
-
-  if (compositeElementType != valueType)
+  SmallVector<Type, 3> possibleTypes = {compositeElementType};
+  while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
+    compositeElementType = type.getElementType(0);
+    possibleTypes.push_back(compositeElementType);
+  }
+
+  if (!is_contained(possibleTypes, valueType)) {
+    std::string strTypes;
+    llvm::raw_string_ostream os(strTypes);
+    interleave(
+        possibleTypes, os, [&](Type type) { os << "'" << type << "'"; },
+        " or ");
     return emitError("expected value attribute type ")
-           << compositeElementType << ", but got: " << valueType;
+           << strTypes << ", but got: " << valueType;
+  }
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5f7e94c5f9439..99ad2a8a2e64b 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -186,17 +186,28 @@ func.func @ccr_wrong_splat_type() -> () {
 // -----
 
 func.func @ccr_wrong_splat_type() -> () {
-  // expected-error @+1 {{expected value attribute type 'i32', but got: 'vector<2xi32>'}}
+  // expected-error @+1 {{expected value attribute type '!spirv.array<3 x i32>' or 'i32', but got: 'vector<2xi32>'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x !spirv.array<3 x i32>>
   return
 }
 
+// -----
+
 func.func @ccr_wrong_splat_type() -> () {
   // expected-error @+1 {{expected value attribute type 'f32', but got: 'i32'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.arm.tensor<2x3xf32>
   return
 }
 
+// -----
+
+func.func @ccr_wrong_splat_type() -> () {
+  // expected-error @+1 {{expected value attribute type 'vector<3xi32>' or 'i32', but got: 'vector<2xi32>'}}
+  %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<3xi32>>
+  return
+}
+
+
 // -----
 
 //===----------------------------------------------------------------------===//

>From 4ac5ad4fa81e00cc8ca76611bf300fe9932031f7 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Fri, 11 Jul 2025 09:34:50 +0100
Subject: [PATCH 8/8] Address review comments: use a declarative assembly format for ConstantCompositeReplicate and accept array attribute values

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     |  2 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 43 +++++--------------
 .../Target/SPIRV/Serialization/Serializer.cpp | 10 ++++-
 mlir/test/Target/SPIRV/constant.mlir          | 42 ++++++++++++------
 4 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index db3a609e993c8..7986025d6ca31 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -172,6 +172,8 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
   );
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = "` ` `[` $value `]` `:` type($replicated_constant) attr-dict";
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index d9154471d957f..8e56f9e5c5b21 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -769,33 +769,17 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
 // spirv.EXTConstantCompositeReplicate
 //===----------------------------------------------------------------------===//
 
-ParseResult
-spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
-                                              OperationState &result) {
-
-  Attribute value;
-  StringRef valueAttrName =
-      spirv::EXTConstantCompositeReplicateOp::getValueAttrName(result.name);
-  Type resultType;
-
-  if (parser.parseLSquare() ||
-      parser.parseAttribute(value, valueAttrName, result.attributes) ||
-      parser.parseRSquare() || parser.parseColonType(resultType))
-    return failure();
-
-  if (isa<NoneType, TensorType>(resultType))
-    if (parser.parseColonType(resultType))
-      return failure();
-
-  return parser.addTypeToList(resultType, result.types);
-}
-
-void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
-  printer << " [" << getValue() << "] : " << getType();
-}
-
 LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
-  Type valueType = dyn_cast<TypedAttr>(getValue()).getType();
+  Type valueType;
+  if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
+    valueType = typedAttr.getType();
+  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
+    auto elementType = dyn_cast<TypedAttr>(arrayAttr[0]).getType();
+    valueType = spirv::ArrayType::get(elementType, arrayAttr.size());
+  } else {
+    return emitError("unknown value attribute type");
+  }
+
   Type compositeElementType =
       dyn_cast<spirv::CompositeType>(getType()).getElementType(0);
   SmallVector<Type, 3> possibleTypes = {compositeElementType};
@@ -805,13 +789,8 @@ LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
   }
 
   if (!is_contained(possibleTypes, valueType)) {
-    std::string strTypes;
-    llvm::raw_string_ostream os(strTypes);
-    interleave(
-        possibleTypes, os, [&](Type type) { os << "'" << type << "'"; },
-        " or ");
     return emitError("expected value attribute type ")
-           << strTypes << ", but got: " << valueType;
+           << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
   }
 
   return success();
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d7fb2408fd724..4ab9d98447601 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1123,7 +1123,15 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
   }
 
   auto elementType = dyn_cast<CompositeType>(resultType).getElementType(0);
-  Type valueType = dyn_cast<TypedAttr>(valueAttr).getType();
+  Type valueType;
+  if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
+    valueType = typedAttr.getType();
+  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+    auto elementType = dyn_cast<TypedAttr>(arrayAttr[0]).getType();
+    valueType = spirv::ArrayType::get(elementType, arrayAttr.size());
+  } else {
+    return 0;
+  }
 
   uint32_t constandID;
   if (elementType == valueType) {
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6f66d61c01df5..76d34c2a96e67 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -332,6 +332,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.array<3 x vector<2xi32>>
   }
 
+  // CHECK-LABEL: @splat_array_of_splat_array_of_i32
+  spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+
+  // CHECK-LABEL: @splat_array_of_non_splat_array_of_i32
+  spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: %0 = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+
   // CHECK-LABEL: @splat_array_of_splat_vectors_of_i32
   spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
@@ -339,13 +353,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
   }
 
-  // CHECK-LABEL: @splat_tensor_of_i32
-  spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
-    %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
-    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
-  }
-
   // CHECK-LABEL: @splat_arm_tensor_of_i32
   spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
@@ -366,6 +373,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %1 = spirv.EXT.ConstantCompositeReplicate [1.0 : f32] : !spirv.array<3 x f32>
     spirv.ReturnValue %1 : !spirv.array<3 x f32>
   }
+ 
+  // CHECK-LABEL: @splat_array_of_splat_array_of_f32
+  spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: %0 = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [3.0 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+
+  // CHECK-LABEL: @splat_array_of_non_splat_array_of_f32
+  spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: %0 = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [[1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
 
   // CHECK-LABEL: @splat_array_of_vectors_of_f32
   spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<3 x vector<2xf32>>) "None" {
@@ -381,13 +402,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
   }
 
-  // CHECK-LABEL: @splat_tensor_of_f32
-  spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
-    %0 = spirv.EXT.ConstantCompositeReplicate [3.0 : f32] : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
-    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
-  }
-
   // CHECK-LABEL: @splat_arm_tensor_of_f32
   spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>



More information about the Mlir-commits mailing list