[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)
Mohammadreza Ameri Mahabadian
llvmlistbot at llvm.org
Tue Jul 8 05:56:51 PDT 2025
================
@@ -765,6 +765,71 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
setNameFn(getResult(), specialName.str());
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::UnresolvedOperand constOperand;
+ Type resultType;
+ if (parser.parseOperand(constOperand) || parser.parseColonType(resultType)) {
+ return failure();
+ }
+
+ if (isa<TensorType>(resultType)) {
+ if (parser.parseColonType(resultType))
+ return failure();
+ }
+
+ auto compositeType = dyn_cast_or_null<spirv::CompositeType>(resultType);
+ if (!compositeType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "result is not a composite type");
+
+ Type constType = compositeType.getElementType(0);
+ while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
+ constType = type.getElementType();
+ }
+
+ if (parser.resolveOperand(constOperand, constType, result.operands))
+ return failure();
+
+ return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+ printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+ auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+ if (!compositeType)
+ return emitError("result type must be a composite type, but provided ")
+ << getType();
+
+ Operation *constantDefiningOp = getConstant().getDefiningOp();
----------------
mahabadm wrote:
Thanks for your comment, @kuhar.
I suppose you mean that this needs to be re-implemented using the second method that I mentioned in the RFC, that is, using an attribute like the one below. Can you please confirm that?
`%0 = spirv.EXT.ConstantCompositeReplicate 1 : vector<4xi32>`
As mentioned, though, the problem with this approach is that it won't work for a less trivial constant type like the one below, because one would need to specify the type of the splat value:
`spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2xvector<2xi32>>`
In this example, the `spirv.EXT.ConstantCompositeReplicate` form can be either of the following:
1. Splat value is of type `i32`:
`%0 = spirv.EXT.ConstantCompositeReplicate 2 : !spirv.array<2xvector<2xi32>>`
or
2. Splat value is of type `vector<2xi32>`:
`%0 = spirv.EXT.ConstantCompositeReplicate dense<2> : !spirv.array<2xvector<2xi32>>`
To resolve that, one can include the attribute type in the representation, as in the two forms below, respectively:
`%0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>`
`%0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>`
I would appreciate it if you could confirm whether you would be happy with this approach.
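To make the ambiguity above concrete, here is a minimal standalone sketch. It is not MLIR code: `Ty`, `candidateSplatTypes`, and `exampleArrayType` are hypothetical names used only for illustration, and the toy type model is an assumption. It lists every nesting level of a composite type that an untyped splat value could be read as.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy model of a nested type; this is NOT the MLIR type system.
struct Ty {
  std::string name;            // printed form, e.g. "vector<2xi32>"
  std::shared_ptr<Ty> element; // element type, or nullptr for a scalar
};

// Collect every type an untyped splat value could be interpreted as for
// `composite`: its element type, that element's element type, and so on.
std::vector<std::string> candidateSplatTypes(const Ty &composite) {
  std::vector<std::string> out;
  for (const Ty *t = composite.element.get(); t; t = t->element.get())
    out.push_back(t->name);
  return out;
}

// Builds the example from this comment: !spirv.array<2xvector<2xi32>>.
Ty exampleArrayType() {
  auto i32 = std::make_shared<Ty>(Ty{"i32", nullptr});
  auto vec = std::make_shared<Ty>(Ty{"vector<2xi32>", i32});
  return Ty{"!spirv.array<2xvector<2xi32>>", vec};
}
```

For the example type this yields two candidates, `vector<2xi32>` and `i32`, which is why the result type alone does not determine the splat value's type and the attribute has to carry its own type.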
https://github.com/llvm/llvm-project/pull/147067