[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)
Mohammadreza Ameri Mahabadian
llvmlistbot at llvm.org
Wed Jul 9 08:09:00 PDT 2025
https://github.com/mahabadm updated https://github.com/llvm/llvm-project/pull/147067
>From 955f819da1389d7e038fcea643bd9af287368f3d Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Sun, 29 Jun 2025 07:25:48 +0100
Subject: [PATCH 1/4] [mlir][spirv] Add basic support for
SPV_EXT_replicated_composites
This patch introduces two new ops to the SPIR-V dialect:
- `spirv.EXT.ConstantCompositeReplicate`
- `spirv.EXT.SpecConstantCompositeReplicate`
These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to `SPV_EXT_replicated_composites` extension instructions:
- `OpConstantCompositeReplicateEXT`
- `OpSpecConstantCompositeReplicateEXT`
No transformation to these new ops has been introduced in this patch.
This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 15 ++-
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 86 +++++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 119 ++++++++++++++++++
.../SPIRV/Deserialization/DeserializeOps.cpp | 23 +++-
.../SPIRV/Deserialization/Deserializer.cpp | 93 +++++++++++++-
.../SPIRV/Deserialization/Deserializer.h | 39 ++++++
.../SPIRV/Serialization/SerializeOps.cpp | 36 ++++++
.../Target/SPIRV/Serialization/Serializer.cpp | 46 +++++++
.../Target/SPIRV/Serialization/Serializer.h | 12 ++
mlir/test/Target/SPIRV/constant.mlir | 83 +++++++++++-
mlir/test/Target/SPIRV/spec-constant.mlir | 27 ++++
11 files changed, 570 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomi
def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
- SPV_EXT_mesh_shader,
+ SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"Coope
MinVersion<SPIRV_V_1_6>
];
}
+// Capability listed in the availability of the replicated composite constant
+// ops added by this patch; it is provided by the
+// SPV_EXT_replicated_composites extension.
+def SPIRV_C_ReplicatedCompositesEXT : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+ list<Availability> availability = [
+ Extension<[SPV_EXT_replicated_composites]>,
+ MinVersion<SPIRV_V_1_0>
+ ];
+}
def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> {
list<Availability> availability = [
Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
- SPIRV_C_CooperativeMatrixKHR,
+ SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMa
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpConstantCompositeReplicateEXT,
+ SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
let autogenSerialization = 0;
}
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+  let summary = [{
+    Declare a new replicated composite constant op.
+  }];
+
+  let description = [{
+    This op declares a `spirv.EXT.ConstantCompositeReplicate` which represents a
+    splat composite constant, i.e. all elements of the composite constant have
+    the same value. This op will be serialized to SPIR-V
+    `OpConstantCompositeReplicateEXT`. The splat value must come from a
+    non-specialization constant instruction, i.e. a `spirv.Constant` or another
+    `spirv.EXT.ConstantCompositeReplicate`.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.Constant 1 : i32
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+
+    %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+    %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+    %4 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xi32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    SPIRV_Type:$constant
+  );
+
+  let results = (outs
+    SPIRV_Composite:$replicated_constant
+  );
+
+  let autogenSerialization = 0;
+}
+
// -----
def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
// -----
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+ let summary = "Declare a new replicated composite specialization constant op.";
+
+ let description = [{
+ This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
+ a splat specialization composite constant, i.e. all elements of the
+ specialization composite constant have the same value. This op will be
+ serialized to SPIR-V `OpSpecConstantCompositeReplicateEXT`. The splat value
+ must come from a symbol reference to a specialization constant instruction.
+
+ #### Example:
+
+ ```mlir
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_EXT_replicated_composites]>,
+ Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$type,
+ StrAttr:$sym_name,
+ SymbolRefAttr:$constituent
+ );
+
+ let results = (outs);
+
+ let autogenSerialization = 0;
+
+}
+
+// -----
+
def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
Pure, InFunctionScope,
SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,67 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
setNameFn(getResult(), specialName.str());
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form `<splat-operand> : <composite-type>`.
+///
+/// The operand type is not spelled in the assembly; it is derived from the
+/// result type (first element type, with nested array types peeled off).
+/// Returns failure, with a diagnostic, when the assembly is malformed or the
+/// result type is not a SPIR-V composite type.
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand constOperand;
+  Type compositeType;
+  if (parser.parseOperand(constOperand))
+    return failure();
+
+  // Remember where the type starts so diagnostics can point at it.
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (parser.parseColonType(compositeType))
+    return failure();
+
+  // NOTE(review): when the first parsed type is a builtin tensor type, a
+  // second `: type` is expected and replaces it. Kept as in the original;
+  // confirm this form is intended for this op.
+  if (llvm::isa<TensorType>(compositeType)) {
+    typeLoc = parser.getCurrentLocation();
+    if (parser.parseColonType(compositeType))
+      return failure();
+  }
+
+  // Emit a parse error for non-composite types instead of tripping the
+  // assertion inside cast<> on user-provided assembly.
+  auto resultCompositeType =
+      llvm::dyn_cast<spirv::CompositeType>(compositeType);
+  if (!resultCompositeType)
+    return parser.emitError(
+               typeLoc, "result type must be a composite type, but provided ")
+           << compositeType;
+
+  // Derive the operand type: first element type, skipping array nesting.
+  Type constType = resultCompositeType.getElementType(0);
+  while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType))
+    constType = type.getElementType();
+
+  if (parser.resolveOperand(constOperand, constType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(compositeType, result.types);
+}
+
+/// Prints the custom assembly form ` <splat-operand> : <result-type>`.
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  printer << getConstant();
+  printer << " : ";
+  printer << getType();
+}
+
+/// Verifies that the result is a composite type and that the splat operand
+/// is produced by a `spirv.Constant` or by another
+/// `spirv.EXT.ConstantCompositeReplicate`.
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+  if (!isa<spirv::CompositeType>(getType()))
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  // Block arguments have no defining op and are rejected here.
+  Operation *splatDef = getConstant().getDefiningOp();
+  if (!splatDef)
+    return this->emitOpError("op defining the splat constant not found");
+
+  // A plain constant is additionally checked with verifyConstantType.
+  if (auto splatConstant = dyn_cast<spirv::ConstantOp>(splatDef))
+    return verifyConstantType(splatConstant, splatConstant.getValueAttr(),
+                              splatConstant.getType());
+
+  // A nested replicate op is accepted as-is.
+  if (isa<spirv::EXTConstantCompositeReplicateOp>(splatDef))
+    return success();
+
+  return this->emitOpError(
+      "op defining the splat constant is not a spirv.Constant or a "
+      "spirv.EXT.ConstantCompositeReplicate");
+}
+
//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//
@@ -1866,6 +1927,64 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTSpecConstantCompositeReplicateOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form
+/// `@<sym-name> (@<constituent>) : <composite-type>`.
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                                  OperationState &result) {
+  // The symbol name is stored directly on the operation state.
+  StringAttr symbol;
+  if (parser.parseSymbolName(symbol, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // The constituent reference is parsed into a scratch attribute list and
+  // re-added below under the op's own attribute name.
+  NamedAttrList scratchAttrs;
+  FlatSymbolRefAttr constituentRef;
+  if (parser.parseLParen() ||
+      parser.parseAttribute(constituentRef, Type(), "spec_const",
+                            scratchAttrs) ||
+      parser.parseRParen())
+    return failure();
+
+  Type compositeType;
+  if (parser.parseColonType(compositeType))
+    return failure();
+
+  result.addAttribute(
+      spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+          result.name),
+      constituentRef);
+  result.addAttribute(
+      spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name),
+      TypeAttr::get(compositeType));
+
+  return success();
+}
+
+/// Prints the custom assembly form
+/// ` @<sym-name> (<constituent>) : <composite-type>`.
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  printer << " (";
+  printer << this->getConstituent();
+  printer << ") : ";
+  printer << getType();
+}
+
+/// Verifies that the declared type is a composite type, that the constituent
+/// symbol resolves to a `spirv.SpecConstant`, and that the spec constant's
+/// default value type matches the first element type of the composite.
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  // The lookup returns null for an unknown symbol; report that instead of
+  // feeding a null operation into dyn_cast.
+  Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
+      (*this)->getParentOp(), this->getConstituent());
+  if (!constituentOp)
+    return emitError("constituent ")
+           << this->getConstituent() << " does not reference a known symbol";
+
+  // The symbol may resolve to some other kind of op; only a spec constant
+  // is a valid splat value.
+  auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
+  if (!constituentSpecConstOp)
+    return emitError("constituent ")
+           << this->getConstituent() << " is not a spirv.SpecConstant";
+
+  auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
+  auto compositeElemType = compositeType.getElementType(0);
+  if (constituentType != compositeElemType)
+    return emitError("constituent has incorrect type: expected ")
+           << compositeElemType << ", but provided " << constituentType;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
constInfo->first);
}
+ if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
+ auto constantId = constCompositeReplicateInfo->first;
+ auto element = getValue(constantId);
+ return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+ unknownLoc, constCompositeReplicateInfo->second, element);
+ }
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
SymbolRefAttr::get(constOp.getOperation()));
return referenceOfOp.getReference();
}
- if (auto constCompositeOp = getSpecConstantComposite(id)) {
+ if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+ auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+ unknownLoc, specConstCompositeOp.getType(),
+ SymbolRefAttr::get(specConstCompositeOp.getOperation()));
+ return referenceOfOp.getReference();
+ }
+ if (auto specConstCompositeReplicateOp =
+ getSpecConstantCompositeReplicate(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constCompositeOp.getType(),
- SymbolRefAttr::get(constCompositeOp.getOperation()));
+ unknownLoc, specConstCompositeReplicateOp.getType(),
+ SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
return referenceOfOp.getReference();
}
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstant(operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantComposite:
return processConstantComposite(operands);
+ case spirv::Opcode::OpConstantCompositeReplicateEXT:
+ return processConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantComposite:
return processSpecConstantComposite(operands);
+ case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+ return processSpecConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantOp:
return processSpecConstantOperation(operands);
case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
return constIt->getSecond();
}
+/// Returns the (splat operand <id>, result type) pair recorded for the
+/// replicated composite constant with result <id> `id`, or std::nullopt when
+/// nothing was recorded for it.
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+  if (auto entry = constantCompositeReplicateMap.find(id);
+      entry != constantCompositeReplicateMap.end())
+    return entry->getSecond();
+  return std::nullopt;
+}
+
std::optional<spirv::SpecConstOperationMaterializationInfo>
spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,58 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
return success();
}
+/// Processes an OpConstantCompositeReplicateEXT instruction.
+///
+/// `operands` must be exactly {result type <id>, result <id>, splat constant
+/// <id>}. No op is created here; the (splat <id>, result type) pair is
+/// recorded in `constantCompositeReplicateMap` under the result <id>.
+/// Emits an error and returns failure on a wrong operand count, an unknown
+/// or non-composite result type, or a splat operand that is neither a
+/// recorded constant nor a recorded replicated composite constant.
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+
+  if (operands.size() != 3) {
+    return emitError(
+        unknownLoc,
+        "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+        "and only one parameter which is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  // Keep the <id> separated from the surrounding text so the diagnostic
+  // does not read "...composite type12".
+  if (!isa<CompositeType>(resultType)) {
+    return emitError(unknownLoc, "result type from <id> ")
+           << operands[0] << " is not a composite type";
+  }
+
+  auto resultID = operands[1];
+  auto constantID = operands[2];
+
+  // The splat value must already have been seen, either as a normal
+  // constant or as another replicated composite constant.
+  auto constantInfo = getConstant(constantID);
+  auto replicatedConstantCompositeInfo =
+      getConstantCompositeReplicate(constantID);
+  if (!constantInfo && !replicatedConstantCompositeInfo) {
+    return emitError(unknownLoc,
+                     "OpConstantCompositeReplicateEXT operand <id> ")
+           << constantID
+           << " must come from a normal constant or a "
+              "OpConstantCompositeReplicateEXT";
+  }
+
+  constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+  return success();
+}
+
LogicalResult
spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
- return emitError(unknownLoc,
- "OpConstantComposite must have type <id> and result <id>");
+ return emitError(
+ unknownLoc,
+ "OpSpecConstantComposite must have type <id> and result <id>");
}
if (operands.size() < 3) {
return emitError(unknownLoc,
- "OpConstantComposite must have at least 1 parameter");
+ "OpSpecConstantComposite must have at least 1 parameter");
}
Type resultType = getType(operands[0]);
@@ -1589,6 +1640,42 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
return success();
}
+/// Processes an OpSpecConstantCompositeReplicateEXT instruction.
+///
+/// `operands` must be exactly {result type <id>, result <id>, splat spec
+/// constant <id>}. Creates a `spirv.EXT.SpecConstantCompositeReplicate` op
+/// and records it in `specConstCompositeReplicateMap` under the result <id>.
+/// Emits an error and returns failure on a wrong operand count, an unknown
+/// or non-composite result type, or a constituent <id> with no recorded
+/// spec constant.
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+
+  if (operands.size() != 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT must have "
+                     "type <id> and result <id> and only one parameter which "
+                     "is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  // Keep the <id> separated from the surrounding text so the diagnostic
+  // does not read "...composite type12".
+  if (!isa<CompositeType>(resultType)) {
+    return emitError(unknownLoc, "result type from <id> ")
+           << operands[0] << " is not a composite type";
+  }
+
+  auto resultID = operands[1];
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+
+  // A malformed binary may reference an <id> that is not a spec constant;
+  // diagnose it rather than building a symbol reference from a null op.
+  auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+  if (!constituentSpecConstantOp) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT operand <id> ")
+           << operands[2] << " must come from a specialization constant";
+  }
+
+  auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      SymbolRefAttr::get(constituentSpecConstantOp));
+
+  specConstCompositeReplicateMap[resultID] = op;
+
+  return success();
+}
+
LogicalResult
spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deserializer {
/// Gets the constant's attribute and type associated with the given <id>.
std::optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
+ /// Looks up the replicated composite constant recorded for result <id>
+ /// `id`. Returns the pair (<id> of the splat operand, type of the resulting
+ /// `spirv.EXT.ConstantCompositeReplicate`), or std::nullopt if nothing was
+ /// recorded for `id`. The splat operand is either a normal constant or
+ /// another replicated composite constant.
+ std::optional<std::pair<uint32_t, Type>>
+ getConstantCompositeReplicate(uint32_t id);
+
/// Gets the info needed to materialize the spec constant operation op
/// associated with the given <id>.
std::optional<SpecConstOperationMaterializationInfo>
@@ -220,6 +226,13 @@ class Deserializer {
return specConstCompositeMap.lookup(id);
}
+ /// Gets the replicated composite specialization constant with the given
+ /// result <id>.
+ spirv::EXTSpecConstantCompositeReplicateOp
+ getSpecConstantCompositeReplicate(uint32_t id) {
+ return specConstCompositeReplicateMap.lookup(id);
+ }
+
/// Creates a spirv::SpecConstantOp.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
TypedAttr defaultValue);
@@ -313,10 +326,20 @@ class Deserializer {
/// `operands`.
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with
+ /// the given `operands`.
+ LogicalResult
+ processConstantCompositeReplicateEXT(ArrayRef<uint32_t> operands);
+
/// Processes a SPIR-V OpSpecConstantComposite instruction with the given
/// `operands`.
LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with
+ /// the given `operands`.
+ LogicalResult
+ processSpecConstantCompositeReplicateEXT(ArrayRef<uint32_t> operands);
+
/// Processes a SPIR-V OpSpecConstantOp instruction with the given
/// `operands`.
LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
@@ -549,12 +572,28 @@ class Deserializer {
/// (and type) here. Later when it's used, we materialize the constant.
DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
+ /// Result <id> to (splat constant <id>, result type) mapping.
+ ///
+ /// In the SPIR-V binary format, OpConstantCompositeReplicateEXT is placed in
+ /// the module and shared by instructions at module level and in subsequent
+ /// functions. But in the SPIR-V dialect, this is materialized to where
+ /// it's used in the function. So when seeing an
+ /// OpConstantCompositeReplicateEXT in the binary format, we don't immediately
+ /// emit a `spirv.EXT.ConstantCompositeReplicate` op into the module; we keep
+ /// the <id> of its operand (the splat constant) and its result type here.
+ /// Later when it's used, we materialize the
+ /// `spirv.EXT.ConstantCompositeReplicate`.
+ DenseMap<uint32_t, std::pair<uint32_t, Type>> constantCompositeReplicateMap;
+
// Result <id> to spec constant mapping.
DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
// Result <id> to composite spec constant mapping.
DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
+ // Result <id> to replicated composite spec constant mapping.
+ DenseMap<uint32_t, spirv::EXTSpecConstantCompositeReplicateOp>
+ specConstCompositeReplicateMap;
+
/// Result <id> to info needed to materialize an OpSpecConstantOp
/// mapping.
DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..89edf6c82b090 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -66,6 +66,15 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
return failure();
}
+/// Serializes a `spirv.EXT.ConstantCompositeReplicate` op and associates its
+/// result value with the result <id>. Fails when no <id> could be prepared
+/// (a zero <id> signals failure).
+LogicalResult Serializer::processConstantCompositeReplicateOp(
+    spirv::EXTConstantCompositeReplicateOp op) {
+  auto resultID = prepareConstantCompositeReplicate(op);
+  if (!resultID)
+    return failure();
+
+  valueIDMap[op.getResult()] = resultID;
+  return success();
+}
+
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
/*isSpec=*/true)) {
@@ -118,6 +127,33 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
return processName(resultID, op.getSymName());
}
+/// Serializes a `spirv.EXT.SpecConstantCompositeReplicate` op as an
+/// OpSpecConstantCompositeReplicateEXT instruction in the types/global-values
+/// section, records its result <id> under the op's symbol name, and emits
+/// the name for it.
+///
+/// Fails when the type cannot be serialized, when the constituent is not a
+/// flat symbol reference, or when no <id> is known for the constituent.
+LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
+    spirv::EXTSpecConstantCompositeReplicateOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+    return failure();
+  }
+
+  // The attribute is declared as a general SymbolRefAttr, so a nested
+  // reference is possible; diagnose it instead of dereferencing a null cast.
+  auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
+  if (!constituent) {
+    return op.emitError("expected a flat symbol reference for the replicated "
+                        "spec constant, but got ")
+           << op.getConstituent();
+  }
+
+  auto constituentName = constituent.getValue();
+  auto constituentID = getSpecConstID(constituentName);
+  if (!constituentID) {
+    return op.emitError("unknown result <id> for replicated spec constant ")
+           << constituentName;
+  }
+
+  auto resultID = getNextID();
+  SmallVector<uint32_t> operands = {typeID, resultID, constituentID};
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
+                        operands);
+
+  specConstIDMap[op.getSymName()] = resultID;
+
+  return processName(resultID, op.getSymName());
+}
+
LogicalResult
Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
uint32_t typeID = 0;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index ebebd2d283afa..2e767cd822617 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1109,6 +1109,46 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
return resultID;
}
+/// Emits OpConstantCompositeReplicateEXT for `op` into the types/global-values
+/// section and returns its result <id>.
+///
+/// If an <id> is already associated with the op's result, that <id> is
+/// returned and nothing is emitted. Returns 0 on failure: the result type
+/// cannot be serialized, the splat operand has no defining op, the defining
+/// op is neither `spirv.Constant` nor `spirv.EXT.ConstantCompositeReplicate`,
+/// or no <id> is available for the splat operand.
+uint32_t Serializer::prepareConstantCompositeReplicate(
+    spirv::EXTConstantCompositeReplicateOp op) {
+  if (auto id = getValueID(op.getResult())) {
+    return id;
+  }
+
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+    return 0;
+  }
+
+  Operation *definingOp = op.getConstant().getDefiningOp();
+  if (!definingOp) {
+    emitError(op.getLoc(), "op defining splat value not found");
+    return 0;
+  }
+
+  uint32_t operandID = 0;
+  if (auto constantOp = dyn_cast<spirv::ConstantOp>(definingOp)) {
+    // NOTE(review): assumes getConstantID yields 0 for a constant that has
+    // not been serialized yet, matching the 0-means-failure convention used
+    // in this function -- confirm against its definition.
+    operandID = getConstantID(constantOp.getValue());
+    if (!operandID) {
+      emitError(op.getLoc(), "unknown <id> for splat constant operand");
+      return 0;
+    }
+  } else if (auto constantCompositeReplicateOp =
+                 dyn_cast<spirv::EXTConstantCompositeReplicateOp>(
+                     definingOp)) {
+    // Nested replicate: serialize (or look up) the inner composite first.
+    // On failure the nested call has already reported the error.
+    operandID = prepareConstantCompositeReplicate(constantCompositeReplicateOp);
+    if (!operandID)
+      return 0;
+  } else {
+    emitError(op.getLoc(), "operand op type not supported");
+    return 0;
+  }
+
+  uint32_t resultID = getNextID();
+  SmallVector<uint32_t> operands = {typeID, resultID, operandID};
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpConstantCompositeReplicateEXT,
+                        operands);
+
+  return resultID;
+}
+
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
@@ -1328,6 +1368,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processBranchConditionalOp(op);
})
.Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
+ .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
+ return processConstantCompositeReplicateOp(op);
+ })
.Case([&](spirv::FuncOp op) { return processFuncOp(op); })
.Case([&](spirv::GlobalVariableOp op) {
return processGlobalVariableOp(op);
@@ -1339,6 +1382,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
.Case([&](spirv::SpecConstantCompositeOp op) {
return processSpecConstantCompositeOp(op);
})
+ .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
+ return processSpecConstantCompositeReplicateOp(op);
+ })
.Case([&](spirv::SpecConstantOperationOp op) {
return processSpecConstantOperationOp(op);
})
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..183af41172be7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -108,11 +108,17 @@ class Serializer {
LogicalResult processConstantOp(spirv::ConstantOp op);
+ LogicalResult processConstantCompositeReplicateOp(
+ spirv::EXTConstantCompositeReplicateOp op);
+
LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
LogicalResult
processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
+ LogicalResult processSpecConstantCompositeReplicateOp(
+ spirv::EXTSpecConstantCompositeReplicateOp op);
+
LogicalResult
processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
@@ -230,6 +236,12 @@ class Serializer {
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
bool isSpec = false);
+ /// Prepares `spirv.EXTConstantCompositeReplicateOp` serialization. This
+ /// method emits OpConstantCompositeReplicateEXT and returns the result <id>
+ /// associated with it.
+ uint32_t
+ prepareConstantCompositeReplicate(spirv::EXTConstantCompositeReplicateOp op);
+
//===--------------------------------------------------------------------===//
// Control flow
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 8d4e53418b70f..0b9b28ae536ea 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module -split-input-file --test-spirv-roundtrip %s | FileCheck %s
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @bool_const
@@ -306,3 +306,84 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
}
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+ // CHECK-LABEL: @splat_vector_i32
+ spirv.func @splat_vector_i32() -> (vector<3xi32>) "None" {
+ // CHECK: spirv.Constant 1 : i32
+ %0 = spirv.Constant 1 : i32
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xi32>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xi32>
+ spirv.ReturnValue %1 : vector<3xi32>
+ }
+
+ // CHECK-LABEL: @splat_array_of_i32
+ spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+ // CHECK: spirv.Constant 1 : i32
+ %0 = spirv.Constant 1 : i32
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x i32>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x i32>
+ spirv.ReturnValue %1 : !spirv.array<3 x i32>
+ }
+
+ // CHECK-LABEL: @splat_array_of_vectors_of_i32
+ spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+ // CHECK: spirv.Constant dense<[1, 2]> : vector<2xi32>
+ %0 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %1 : !spirv.array<2 x vector<2xi32>>
+ }
+
+ // CHECK-LABEL: @splat_array_of_splat_vector_i32
+ spirv.func @splat_array_of_splat_vector_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+ // CHECK: spirv.Constant 2 : i32
+ %0 = spirv.Constant 2 : i32
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xi32>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
+ %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %2 : !spirv.array<2 x vector<2xi32>>
+ }
+
+ // CHECK-LABEL: @splat_vector_f32
+ spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
+ // CHECK: spirv.Constant 1.000000e+00 : f32
+ %0 = spirv.Constant 1.0 : f32
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xf32>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xf32>
+ spirv.ReturnValue %1 : vector<3xf32>
+ }
+
+ // CHECK-LABEL: @splat_array_of_f32
+ spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+ // CHECK: spirv.Constant 1.000000e+00 : f32
+ %0 = spirv.Constant 1.0 : f32
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x f32>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x f32>
+ spirv.ReturnValue %1 : !spirv.array<3 x f32>
+ }
+
+ // CHECK-LABEL: @splat_array_of_vectors_of_f32
+ spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+ // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>
+ %0 = spirv.Constant dense<[1.0, 2.0]> : vector<2xf32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %1 : !spirv.array<2 x vector<2xf32>>
+ }
+
+ // CHECK-LABEL: @splat_array_of_splat_vector_f32
+ spirv.func @splat_array_of_splat_vector_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+ // CHECK: spirv.Constant 2.000000e+00 : f32
+ %0 = spirv.Constant 2.0 : f32
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xf32>
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
+ %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %2 : !spirv.array<2 x vector<2xf32>>
+ }
+}
diff --git a/mlir/test/Target/SPIRV/spec-constant.mlir b/mlir/test/Target/SPIRV/spec-constant.mlir
index 078d77125b3fd..f434956ab34a3 100644
--- a/mlir/test/Target/SPIRV/spec-constant.mlir
+++ b/mlir/test/Target/SPIRV/spec-constant.mlir
@@ -88,6 +88,33 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3 x i32>
+
+ spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3 x f32>
+}
+
+// -----
+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.SpecConstant @sc_i32_1 = 1 : i32
>From 63dd835a3e47ef4e8ce910cfe65c5b39d736a2c0 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Mon, 7 Jul 2025 08:50:06 +0100
Subject: [PATCH 2/4] Addressing code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 15 +++++-------
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 6 ++---
.../SPIRV/Deserialization/Deserializer.cpp | 24 +++++++++----------
.../SPIRV/Serialization/SerializeOps.cpp | 8 +++----
mlir/test/Target/SPIRV/constant.mlir | 2 +-
5 files changed, 26 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 0a5b01fe9e8d0..4f5c5e7e0dd48 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -144,10 +144,9 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
}];
let description = [{
- This op declares a `spiv.EXT.ConstantCompositeReplicate` which represents a
- splat composite constant i.e. all element of composite constant have the
- same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
- The splat value must come from a non-specialization constant instruction."
+ Represents a splat composite constant i.e., all element of composite constant
+ have the same value. The splat value must come from a non-specialization constant
+ instruction.
#### Example:
@@ -739,11 +738,9 @@ def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantC
let summary = "Declare a new replicated composite specialization constant op.";
let description = [{
- This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
- a splat specialization composite constant i.e. all element of specialization
- composite constant have the same value. This op will be serialized to SPIR-V
- `OpSpecConstantCompositeReplicateEXT`. The splat value must come from a
- symbol reference of specialization constant instruction.
+ Represents a splat spec composite constant i.e., all element of spec composite
+ constant have the same value. The splat value must come from a symbol reference
+ of spec constant instruction.
#### Example:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c42b2d45d53a9..b5c30051192f5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -779,13 +779,13 @@ spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
return failure();
}
- if (llvm::isa<TensorType>(compositeType)) {
+ if (isa<TensorType>(compositeType)) {
if (parser.parseColonType(compositeType))
return failure();
}
- auto constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
- while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType)) {
+ Type constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+ while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
constType = type.getElementType();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 2163ccff93c83..40cc8f90cfee5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -680,10 +680,10 @@ spirv::Deserializer::getConstant(uint32_t id) {
std::optional<std::pair<uint32_t, Type>>
spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
- auto constIt = constantCompositeReplicateMap.find(id);
- if (constIt == constantCompositeReplicateMap.end())
- return std::nullopt;
- return constIt->getSecond();
+ if (auto it = constantCompositeReplicateMap.find(id);
+ it != constantCompositeReplicateMap.end())
+ return it->second;
+ return std::nullopt;
}
std::optional<spirv::SpecConstOperationMaterializationInfo>
@@ -1564,7 +1564,6 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
ArrayRef<uint32_t> operands) {
-
if (operands.size() != 3) {
return emitError(
unknownLoc,
@@ -1585,11 +1584,12 @@ LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
<< operands[0];
}
- auto resultID = operands[1];
- auto constantID = operands[2];
+ uint32_t resultID = operands[1];
+ uint32_t constantID = operands[2];
- auto constantInfo = getConstant(constantID);
- auto replicatedConstantCompositeInfo =
+ std::optional<std::pair<Attribute, Type>> constantInfo =
+ getConstant(constantID);
+ std::optional<std::pair<uint32_t, Type>> replicatedConstantCompositeInfo =
getConstantCompositeReplicate(constantID);
if (!constantInfo && !replicatedConstantCompositeInfo) {
return emitError(unknownLoc,
@@ -1642,7 +1642,6 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
ArrayRef<uint32_t> operands) {
-
if (operands.size() != 3) {
return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT must have "
@@ -1663,10 +1662,11 @@ LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
<< operands[0];
}
- auto resultID = operands[1];
+ uint32_t resultID = operands[1];
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
- auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+ spirv::SpecConstantOp constituentSpecConstantOp =
+ getSpecConstant(operands[2]);
auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
unknownLoc, TypeAttr::get(resultType), symName,
SymbolRefAttr::get(constituentSpecConstantOp));
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 89edf6c82b090..02dc25ac92e8c 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -68,7 +68,7 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
LogicalResult Serializer::processConstantCompositeReplicateOp(
spirv::EXTConstantCompositeReplicateOp op) {
- if (auto resultID = prepareConstantCompositeReplicate(op)) {
+ if (uint32_t resultID = prepareConstantCompositeReplicate(op)) {
valueIDMap[op.getResult()] = resultID;
return success();
}
@@ -135,14 +135,14 @@ LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
}
auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
- auto constituentName = constituent.getValue();
- auto constituentID = getSpecConstID(constituentName);
+ StringRef constituentName = constituent.getValue();
+ uint32_t constituentID = getSpecConstID(constituentName);
if (!constituentID) {
return op.emitError("unknown result <id> for replicated spec constant ")
<< constituentName;
}
- auto resultID = getNextID();
+ uint32_t resultID = getNextID();
SmallVector<uint32_t> operands = {typeID, resultID, constituentID};
encodeInstructionInto(typesGlobalValues,
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 0b9b28ae536ea..f521deebe0bb8 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate --no-implicit-module -split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @bool_const
>From 8c0b6cb6917cb4af5eb483094c15296a8610e029 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Mon, 7 Jul 2025 16:47:44 +0100
Subject: [PATCH 3/4] Addressing further code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 1 -
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 32 +++++----
.../SPIRV/Deserialization/DeserializeOps.cpp | 12 ++--
.../Target/SPIRV/Serialization/Serializer.cpp | 4 +-
mlir/test/Dialect/SPIRV/IR/availability.mlir | 14 ++++
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 65 +++++++++++++++++++
6 files changed, 108 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 4f5c5e7e0dd48..e6801159931e9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -767,7 +767,6 @@ def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantC
let results = (outs);
let autogenSerialization = 0;
-
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index b5c30051192f5..c34003cf0ad7d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -773,18 +773,22 @@ ParseResult
spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand constOperand;
- Type compositeType;
- if (parser.parseOperand(constOperand) ||
- parser.parseColonType(compositeType)) {
+ Type resultType;
+ if (parser.parseOperand(constOperand) || parser.parseColonType(resultType)) {
return failure();
}
- if (isa<TensorType>(compositeType)) {
- if (parser.parseColonType(compositeType))
+ if (isa<TensorType>(resultType)) {
+ if (parser.parseColonType(resultType))
return failure();
}
- Type constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+ auto compositeType = dyn_cast_or_null<spirv::CompositeType>(resultType);
+ if (!compositeType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "result is not a composite type");
+
+ Type constType = compositeType.getElementType(0);
while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
constType = type.getElementType();
}
@@ -805,7 +809,7 @@ LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
return emitError("result type must be a composite type, but provided ")
<< getType();
- auto constantDefiningOp = getConstant().getDefiningOp();
+ Operation *constantDefiningOp = getConstant().getDefiningOp();
if (!constantDefiningOp)
return this->emitOpError("op defining the splat constant not found");
@@ -1972,12 +1976,16 @@ LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
return emitError("result type must be a composite type, but provided ")
<< getType();
- auto constituentSpecConstOp =
- dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
- (*this)->getParentOp(), this->getConstituent()));
+ Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
+ (*this)->getParentOp(), this->getConstituent());
+ if (!constituentOp)
+ return emitError(
+ "splat spec constant reference defining constituent not found");
+
+ auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
- auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
- auto compositeElemType = compositeType.getElementType(0);
+ Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
+ Type compositeElemType = compositeType.getElementType(0);
if (constituentType != compositeElemType)
return emitError("constituent has incorrect type: expected ")
<< compositeElemType << ", but provided " << constituentType;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5f52308b4be35..6c4d7b8cf3f37 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,11 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
constInfo->first);
}
- if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
- auto constantId = constCompositeReplicateInfo->first;
- auto element = getValue(constantId);
+ if (std::optional<std::pair<uint32_t, Type>> constCompositeReplicateInfo =
+ getConstantCompositeReplicate(id)) {
+ uint32_t constantId = constCompositeReplicateInfo->first;
+ Value constantValue = getValue(constantId);
return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
- unknownLoc, constCompositeReplicateInfo->second, element);
+ unknownLoc, constCompositeReplicateInfo->second, constantValue);
}
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
@@ -62,7 +63,8 @@ Value spirv::Deserializer::getValue(uint32_t id) {
SymbolRefAttr::get(constOp.getOperation()));
return referenceOfOp.getReference();
}
- if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+ if (SpecConstantCompositeOp specConstCompositeOp =
+ getSpecConstantComposite(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
unknownLoc, specConstCompositeOp.getType(),
SymbolRefAttr::get(specConstCompositeOp.getOperation()));
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 2e767cd822617..e9dd6da6530f7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1111,7 +1111,7 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
uint32_t Serializer::prepareConstantCompositeReplicate(
spirv::EXTConstantCompositeReplicateOp op) {
- if (auto id = getValueID(op.getResult())) {
+ if (uint32_t id = getValueID(op.getResult())) {
return id;
}
@@ -1120,7 +1120,7 @@ uint32_t Serializer::prepareConstantCompositeReplicate(
return 0;
}
- auto definingOp = op.getConstant().getDefiningOp();
+ Operation *definingOp = op.getConstant().getDefiningOp();
if (!definingOp) {
emitError(op.getLoc(), "op defining splat value not found");
return 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..329f5a8eed9ae 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -278,3 +278,17 @@ func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () {
spirv.EXT.SetMeshOutputs %0, %1 : i32, i32
spirv.Return
}
+
+//===----------------------------------------------------------------------===//
+// Replicated Composite Constant op
+//===----------------------------------------------------------------------===//
+// CHECK-LABEL: constant_composite_replicate
+func.func @constant_composite_replicate() -> () {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_EXT_replicated_composites] ]
+ // CHECK: capabilities: [ [ReplicatedCompositesEXT] ]
+ %0 = spirv.Constant 1 : i32
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+ spirv.Return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 207549afdda94..39f7df10be634 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -165,6 +165,45 @@ func.func @coop_matrix_const_wrong_type() -> () {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.EXT.ConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+ spirv.func @ccr() -> i32 "None" {
+ %0 = spirv.Constant 1 : i32
+ // expected-error @+2 {{result is not a composite type}}
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : i32
+ spirv.ReturnValue %1: i32
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @ccr() -> vector<2xf32> "None" {
+ %0 = spirv.Constant 1 : i32
+ // expected-note at -1 {{prior use here}}
+ // expected-error @+1 {{use of value '%0' expects different type than prior uses: 'f32' vs 'i32'}}
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+ spirv.ReturnValue %1: vector<2xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @ccr() -> vector<2xi32> "None" {
+ %c1 = spirv.Constant 1 : i32
+ %0 = spirv.IAdd %c1, %c1 : i32
+ // expected-error @+1 {{op defining the splat constant is not a spirv.Constant or a spirv.EXT.ConstantCompositeReplicate}}
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+ spirv.ReturnValue %1: vector<2xi32>
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//
@@ -854,6 +893,32 @@ spirv.module Logical GLSL450 {
spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device, MatrixA>
}
+//===----------------------------------------------------------------------===//
+// spirv.EXT.SpecConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+// -----
+
+spirv.module Logical GLSL450 {
+ // expected-error @+1 {{result type must be a composite type, but provided 'i32'}}
+ spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_i32_1) : i32
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ // expected-error @+1 {{splat spec constant reference defining constituent not found}}
+ spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.array<3 x i32>
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+ // expected-error @+1 {{constituent has incorrect type: expected 'i32', but provided 'f32'}}
+ spirv.EXT.SpecConstantCompositeReplicate @sccr (@sc_f32_1) : !spirv.array<3 x i32>
+}
+
//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//
>From 9a6b19125f95d140f8fd93fd9d5e49da1c0bdfba Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Wed, 9 Jul 2025 15:45:56 +0100
Subject: [PATCH 4/4] Addressing further code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 18 ++--
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 74 +++++++--------
.../SPIRV/Deserialization/DeserializeOps.cpp | 7 +-
.../SPIRV/Deserialization/Deserializer.cpp | 27 +++---
.../SPIRV/Deserialization/Deserializer.h | 15 ++--
.../SPIRV/Serialization/SerializeOps.cpp | 3 +-
.../Target/SPIRV/Serialization/Serializer.cpp | 35 ++++----
.../Target/SPIRV/Serialization/Serializer.h | 12 ++-
mlir/test/Dialect/SPIRV/IR/availability.mlir | 3 +-
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 35 +++-----
mlir/test/Target/SPIRV/constant.mlir | 90 +++++++++----------
11 files changed, 145 insertions(+), 174 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index e6801159931e9..bfe2eac3b89fc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -144,21 +144,17 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
}];
let description = [{
- Represents a splat composite constant i.e., all element of composite constant
- have the same value. The splat value must come from a non-specialization constant
- instruction.
+ Represents a splat composite constant i.e., all elements of composite constant
+ have the same value.
#### Example:
```mlir
- %0 = spirv.Constant 1 : i32
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
- %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
- %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<2 x vector<2xi32>>
- %5 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
- %6 = spirv.EXT.ConstantCompositeReplicate %5 : !spirv.array<2 x vector<2xi32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
```
}];
@@ -170,7 +166,7 @@ def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantComposite
];
let arguments = (ins
- SPIRV_Type:$constant
+ AnyAttr:$value
);
let results = (outs
@@ -738,7 +734,7 @@ def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantC
let summary = "Declare a new replicated composite specialization constant op.";
let description = [{
- Represents a splat spec composite constant i.e., all element of spec composite
+ Represents a splat spec composite constant i.e., all elements of spec composite
constant have the same value. The splat value must come from a symbol reference
of spec constant instruction.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c34003cf0ad7d..b6f1e1411ddfb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -772,60 +772,46 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
ParseResult
spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
OperationState &result) {
- OpAsmParser::UnresolvedOperand constOperand;
+
+ Attribute value;
+ StringRef valueAttrName =
+ spirv::EXTConstantCompositeReplicateOp::getValueAttrName(result.name);
Type resultType;
- if (parser.parseOperand(constOperand) || parser.parseColonType(resultType)) {
+
+ if (parser.parseLSquare() ||
+ parser.parseAttribute(value, valueAttrName, result.attributes) ||
+ parser.parseRSquare() || parser.parseColonType(resultType))
return failure();
- }
- if (isa<TensorType>(resultType)) {
+ if (isa<NoneType, TensorType>(resultType))
if (parser.parseColonType(resultType))
return failure();
- }
- auto compositeType = dyn_cast_or_null<spirv::CompositeType>(resultType);
- if (!compositeType)
- return parser.emitError(parser.getCurrentLocation(),
- "result is not a composite type");
-
- Type constType = compositeType.getElementType(0);
- while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
- constType = type.getElementType();
- }
-
- if (parser.resolveOperand(constOperand, constType, result.operands))
- return failure();
+ if (isa<TensorArmType>(resultType))
+ if (parser.parseOptionalColon().succeeded())
+ if (parser.parseType(resultType))
+ return failure();
- return parser.addTypeToList(compositeType, result.types);
+ return parser.addTypeToList(resultType, result.types);
}
void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
- printer << ' ' << getConstant() << " : " << getType();
+ printer << " [" << getValue() << "] : " << getType();
}
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
- auto compositeType = dyn_cast<spirv::CompositeType>(getType());
- if (!compositeType)
- return emitError("result type must be a composite type, but provided ")
- << getType();
-
- Operation *constantDefiningOp = getConstant().getDefiningOp();
- if (!constantDefiningOp)
- return this->emitOpError("op defining the splat constant not found");
-
- auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(constantDefiningOp);
- auto constantCompositeReplicateOp =
- dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
- constantDefiningOp);
-
- if (!constantOp && !constantCompositeReplicateOp)
- return this->emitOpError(
- "op defining the splat constant is not a spirv.Constant or a "
- "spirv.EXT.ConstantCompositeReplicate");
+ Type valueType = dyn_cast<TypedAttr>(getValue()).getType();
+ Type compositeElementType =
+ dyn_cast<spirv::CompositeType>(getType()).getElementType(0);
+ while (compositeElementType != valueType &&
+ isa<spirv::CompositeType>(compositeElementType)) {
+ compositeElementType =
+ cast<spirv::CompositeType>(compositeElementType).getElementType(0);
+ }
- if (constantOp)
- return verifyConstantType(constantOp, constantOp.getValueAttr(),
- constantOp.getType());
+ if (compositeElementType != valueType)
+ return emitError("expected splat element type")
+ << compositeElementType << ", but got: " << valueType;
return success();
}
@@ -1940,8 +1926,8 @@ spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr compositeName;
- const char *attrName = "spec_const";
FlatSymbolRefAttr specConstRef;
+ const char *attrName = "spec_const";
NamedAttrList attrs;
Type type;
@@ -1985,10 +1971,10 @@ LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
- Type compositeElemType = compositeType.getElementType(0);
- if (constituentType != compositeElemType)
+ Type compositeElementType = compositeType.getElementType(0);
+ if (constituentType != compositeElementType)
return emitError("constituent has incorrect type: expected ")
- << compositeElemType << ", but provided " << constituentType;
+ << compositeElementType << ", but provided " << constituentType;
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 6c4d7b8cf3f37..9fa03725d05ee 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,12 +45,11 @@ Value spirv::Deserializer::getValue(uint32_t id) {
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
constInfo->first);
}
- if (std::optional<std::pair<uint32_t, Type>> constCompositeReplicateInfo =
+ if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
getConstantCompositeReplicate(id)) {
- uint32_t constantId = constCompositeReplicateInfo->first;
- Value constantValue = getValue(constantId);
return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
- unknownLoc, constCompositeReplicateInfo->second, constantValue);
+ unknownLoc, constCompositeReplicateInfo->second,
+ constCompositeReplicateInfo->first);
}
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 40cc8f90cfee5..e4dab971dcc14 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,7 +678,7 @@ spirv::Deserializer::getConstant(uint32_t id) {
return constIt->getSecond();
}
-std::optional<std::pair<uint32_t, Type>>
+std::optional<std::pair<Attribute, Type>>
spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
if (auto it = constantCompositeReplicateMap.find(id);
it != constantCompositeReplicateMap.end())
@@ -1589,19 +1589,24 @@ LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
std::optional<std::pair<Attribute, Type>> constantInfo =
getConstant(constantID);
- std::optional<std::pair<uint32_t, Type>> replicatedConstantCompositeInfo =
- getConstantCompositeReplicate(constantID);
- if (!constantInfo && !replicatedConstantCompositeInfo) {
- return emitError(unknownLoc,
- "OpConstantCompositeReplicateEXT operand <id> ")
- << constantID
- << " must come from a normal constant or a "
- "OpConstantCompositeReplicateEXT";
+ if (constantInfo.has_value()) {
+ constantCompositeReplicateMap.try_emplace(
+ resultID, constantInfo.value().first, resultType);
+ return success();
}
- constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+ std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
+ getConstantCompositeReplicate(constantID);
+ if (replicatedConstantCompositeInfo.has_value()) {
+ constantCompositeReplicateMap.try_emplace(
+ resultID, replicatedConstantCompositeInfo.value().first, resultType);
+ return success();
+ }
- return success();
+ return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
+ << constantID
+ << " must come from a normal constant or a "
+ "OpConstantCompositeReplicateEXT";
}
LogicalResult
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1fdecc3e6fe0d..20482bd2bf501 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,10 +190,9 @@ class Deserializer {
/// Gets the constant's attribute and type associated with the given <id>.
std::optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
- /// Gets the pair of id of `spirv.Constant` generating a
- /// replicated composite and the type of resulting
- /// `spirv.EXT.ConstantCompositeReplicate` given its <id>.
- std::optional<std::pair<uint32_t, Type>>
+ /// Gets the replicated composite constant's attribute and type associated
+ /// with the given <id>.
+ std::optional<std::pair<Attribute, Type>>
getConstantCompositeReplicate(uint32_t id);
/// Gets the info needed to materialize the spec constant operation op
@@ -572,7 +571,7 @@ class Deserializer {
/// (and type) here. Later when it's used, we materialize the constant.
DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
- // Result <id> to replicated constant id and type mapping.
+ // Result <id> to replicated constant attribute and type mapping.
///
/// In the SPIR-V binary format, OpConstantCompositeReplicateEXT is placed in
/// the module and shared by instructions at module level and in subsequent
@@ -580,9 +579,9 @@ class Deserializer {
/// it's used in the function. So when seeing a
/// OpConstantCompositeReplicateEXT in the binary format, we don't immediately
/// emit a `spirv.EXT.ConstantCompositeReplicate` op into the module, we keep
- /// the id of its operand (the splat constant) and type) here. Later when it's
- /// used, we materialize the `spirv.EXT.ConstantCompositeReplicate`.
- DenseMap<uint32_t, std::pair<uint32_t, Type>> constantCompositeReplicateMap;
+ /// its value attribute and type here. Later when it's used, we materialize
+ /// the `spirv.EXT.ConstantCompositeReplicate`.
+ DenseMap<uint32_t, std::pair<Attribute, Type>> constantCompositeReplicateMap;
// Result <id> to spec constant mapping.
DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 02dc25ac92e8c..b722933ecdade 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -68,7 +68,8 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
LogicalResult Serializer::processConstantCompositeReplicateOp(
spirv::EXTConstantCompositeReplicateOp op) {
- if (uint32_t resultID = prepareConstantCompositeReplicate(op)) {
+ if (uint32_t resultID = prepareConstantCompositeReplicate(
+ op.getLoc(), op.getType(), op.getValue())) {
valueIDMap[op.getResult()] = resultID;
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index e9dd6da6530f7..e5ea01f6ce2a9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1109,43 +1109,38 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
return resultID;
}
-uint32_t Serializer::prepareConstantCompositeReplicate(
- spirv::EXTConstantCompositeReplicateOp op) {
- if (uint32_t id = getValueID(op.getResult())) {
+uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
+ Type resultType,
+ Attribute valueAttr) {
+
+ std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
+ if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
return id;
}
uint32_t typeID = 0;
- if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+ if (failed(processType(loc, resultType, typeID))) {
return 0;
}
- Operation *definingOp = op.getConstant().getDefiningOp();
- if (!definingOp) {
- emitError(op.getLoc(), "op defining splat value not found");
- return 0;
- }
+ auto elementType = dyn_cast<CompositeType>(resultType).getElementType(0);
+ Type valueType = dyn_cast<TypedAttr>(valueAttr).getType();
- uint32_t operandID;
- if (auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(definingOp)) {
- operandID = getConstantID(constantOp.getValue());
-
- } else if (auto constantCompositeReplicateOp =
- dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
- definingOp)) {
- operandID = prepareConstantCompositeReplicate(constantCompositeReplicateOp);
+ uint32_t constandID;
+ if (elementType == valueType) {
+ constandID = prepareConstant(loc, elementType, valueAttr);
} else {
- emitError(op.getLoc(), "operand op type not supported");
- return 0;
+ constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
}
uint32_t resultID = getNextID();
- SmallVector<uint32_t> operands = {typeID, resultID, operandID};
+ SmallVector<uint32_t> operands = {typeID, resultID, constandID};
encodeInstructionInto(typesGlobalValues,
spirv::Opcode::OpConstantCompositeReplicateEXT,
operands);
+ constCompositeReplicateIDMap[valueTypePair] = resultID;
return resultID;
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 183af41172be7..7047869bca4cd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -197,6 +197,11 @@ class Serializer {
return constIDMap.lookup(value);
}
+ uint32_t getConstantCompositeReplicateID(
+ std::pair<Attribute, Type> valueTypePair) const {
+ return constCompositeReplicateIDMap.lookup(valueTypePair);
+ }
+
/// Main dispatch method for processing a constant with the given `constType`
/// and `valueAttr`. `constType` is needed here because we can interpret the
/// `valueAttr` as a different type than the type of `valueAttr` itself; for
@@ -239,8 +244,8 @@ class Serializer {
/// Prepares `spirv.EXTConstantCompositeReplicateOp` serialization. This
/// method emits OpConstantCompositeReplicateEXT and returns the result <id>
/// associated with it.
- uint32_t
- prepareConstantCompositeReplicate(spirv::EXTConstantCompositeReplicateOp op);
+ uint32_t prepareConstantCompositeReplicate(Location loc, Type resultType,
+ Attribute valueAttr);
//===--------------------------------------------------------------------===//
// Control flow
@@ -401,6 +406,9 @@ class Serializer {
/// Map from constant values to their <id>s.
DenseMap<Attribute, uint32_t> constIDMap;
+ /// Map from a replicated composite constant's value and type to their <id>s.
+ DenseMap<std::pair<Attribute, Type>, uint32_t> constCompositeReplicateIDMap;
+
/// Map from specialization constant names to their <id>s.
llvm::StringMap<uint32_t> specConstIDMap;
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 329f5a8eed9ae..1640862d07d5c 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -288,7 +288,6 @@ func.func @constant_composite_replicate() -> () {
// CHECK: max version: v1.6
// CHECK: extensions: [ [SPV_EXT_replicated_composites] ]
// CHECK: capabilities: [ [ReplicatedCompositesEXT] ]
- %0 = spirv.Constant 1 : i32
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
spirv.Return
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 39f7df10be634..5da26111e9917 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -169,37 +169,26 @@ func.func @coop_matrix_const_wrong_type() -> () {
// spirv.EXT.ConstantCompositeReplicate
//===----------------------------------------------------------------------===//
-spirv.module Logical GLSL450 {
- spirv.func @ccr() -> i32 "None" {
- %0 = spirv.Constant 1 : i32
- // expected-error @+2 {{result is not a composite type}}
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : i32
- spirv.ReturnValue %1: i32
- }
+func.func @ccr_result_not_composite() -> () {
+ // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type, but got 'i32'}}
+ %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
+ return
}
// -----
-spirv.module Logical GLSL450 {
- spirv.func @ccr() -> vector<2xf32> "None" {
- %0 = spirv.Constant 1 : i32
- // expected-note at -1 {{prior use here}}
- // expected-error @+1 {{use of value '%0' expects different type than prior uses: 'f32' vs 'i32'}}
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
- spirv.ReturnValue %1: vector<2xf32>
- }
+func.func @ccr_wrong_splat_type() -> () {
+ // expected-error @+1 {{expected splat element type'f32', but got: 'i32'}}
+ %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xf32>
+ return
}
// -----
-spirv.module Logical GLSL450 {
- spirv.func @ccr() -> vector<2xi32> "None" {
- %c1 = spirv.Constant 1 : i32
- %0 = spirv.IAdd %c1, %c1 : i32
- // expected-error @+1 {{op defining the splat constant is not a spirv.Constant or a spirv.EXT.ConstantCompositeReplicate}}
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
- spirv.ReturnValue %1: vector<2xi32>
- }
+func.func @ccr_wrong_splat_type() -> () {
+ // expected-error @+1 {{expected splat element type'i32', but got: 'vector<2xi32>'}}
+ %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x !spirv.array<3 x i32>>
+ return
}
// -----
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f521deebe0bb8..809065fdc651d 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -313,77 +313,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
// CHECK-LABEL: @splat_vector_i32
spirv.func @splat_vector_i32() -> (vector<3xi32>) "None" {
- // CHECK: spirv.Constant 1 : i32
- %0 = spirv.Constant 1 : i32
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xi32>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xi32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
+ %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<3xi32>
spirv.ReturnValue %1 : vector<3xi32>
}
// CHECK-LABEL: @splat_array_of_i32
spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
- // CHECK: spirv.Constant 1 : i32
- %0 = spirv.Constant 1 : i32
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x i32>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x i32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+ %1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
spirv.ReturnValue %1 : !spirv.array<3 x i32>
}
// CHECK-LABEL: @splat_array_of_vectors_of_i32
- spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
- // CHECK: spirv.Constant dense<[1, 2]> : vector<2xi32>
- %0 = spirv.Constant dense<[1, 2]> : vector<2xi32>
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xi32>>
- spirv.ReturnValue %1 : !spirv.array<2 x vector<2xi32>>
+ spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<3 x vector<2xi32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<3 x vector<2xi32>>
}
- // CHECK-LABEL: @splat_array_of_splat_vector_i32
- spirv.func @splat_array_of_splat_vector_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
- // CHECK: spirv.Constant 2 : i32
- %0 = spirv.Constant 2 : i32
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xi32>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xi32>>
- %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xi32>>
- spirv.ReturnValue %2 : !spirv.array<2 x vector<2xi32>>
+ // CHECK-LABEL: @splat_array_of_splat_vectors_of_i32
+ spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+
+ // CHECK-LABEL: @splat_tensor_of_i32
+ spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [3 : i32] : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}
// CHECK-LABEL: @splat_vector_f32
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
- // CHECK: spirv.Constant 1.000000e+00 : f32
- %0 = spirv.Constant 1.0 : f32
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<3xf32>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<3xf32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
+ %1 = spirv.EXT.ConstantCompositeReplicate [1.0 : f32] : vector<3xf32>
spirv.ReturnValue %1 : vector<3xf32>
}
// CHECK-LABEL: @splat_array_of_f32
spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
- // CHECK: spirv.Constant 1.000000e+00 : f32
- %0 = spirv.Constant 1.0 : f32
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<3 x f32>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<3 x f32>
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+ %1 = spirv.EXT.ConstantCompositeReplicate [1.0 : f32] : !spirv.array<3 x f32>
spirv.ReturnValue %1 : !spirv.array<3 x f32>
}
// CHECK-LABEL: @splat_array_of_vectors_of_f32
- spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
- // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>
- %0 = spirv.Constant dense<[1.0, 2.0]> : vector<2xf32>
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : !spirv.array<2 x vector<2xf32>>
- spirv.ReturnValue %1 : !spirv.array<2 x vector<2xf32>>
- }
-
- // CHECK-LABEL: @splat_array_of_splat_vector_f32
- spirv.func @splat_array_of_splat_vector_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
- // CHECK: spirv.Constant 2.000000e+00 : f32
- %0 = spirv.Constant 2.0 : f32
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : vector<2xf32>
- %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
- // CHECK: spirv.EXT.ConstantCompositeReplicate {{.*}} : !spirv.array<2 x vector<2xf32>>
- %2 = spirv.EXT.ConstantCompositeReplicate %1 : !spirv.array<2 x vector<2xf32>>
- spirv.ReturnValue %2 : !spirv.array<2 x vector<2xf32>>
+ spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<3 x vector<2xf32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<3 x vector<2xf32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<3 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<3 x vector<2xf32>>
+ }
+
+ // CHECK-LABEL: @splat_array_of_splat_vectors_of_f32
+ spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+
+ // CHECK-LABEL: @splat_tensor_of_f32
+ spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [3.0 : f32] : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
}
}
More information about the Mlir-commits
mailing list