[Mlir-commits] [mlir] [mlir][spirv] Add basic support for SPV_EXT_replicated_composites (PR #147067)

Mohammadreza Ameri Mahabadian llvmlistbot at llvm.org
Tue Jul 8 05:56:51 PDT 2025


================
@@ -765,6 +765,71 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand constOperand;
+  Type resultType;
+  if (parser.parseOperand(constOperand) || parser.parseColonType(resultType)) {
+    return failure();
+  }
+
+  if (isa<TensorType>(resultType)) {
+    if (parser.parseColonType(resultType))
+      return failure();
+  }
+
+  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(resultType);
+  if (!compositeType)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "result is not a composite type");
+
+  Type constType = compositeType.getElementType(0);
+  while (auto type = dyn_cast<spirv::ArrayType>(constType)) {
+    constType = type.getElementType();
+  }
+
+  if (parser.resolveOperand(constOperand, constType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  Operation *constantDefiningOp = getConstant().getDefiningOp();
----------------
mahabadm wrote:

Thanks for your comment, @kuhar.

I assume you mean that this needs to be re-implemented using the second approach I mentioned in the RFC, that is, using an attribute as shown below. Could you please confirm?

`%0 = spirv.EXT.ConstantCompositeReplicate 1 : vector<4xi32>`
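As a rough illustration of the untyped-attribute form, the toy C++ sketch below infers the splat type from the result type alone, following the same rule as the loop in the `parse()` hunk above: take the composite's element type, then strip any nested array layers. It is not MLIR code; `Ty`, `TyPtr`, `scalar`, `vec`, `arr`, `str` and `inferSplatType` are hypothetical names, and types are modelled as plain structs instead of `mlir::Type`.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy model of SPIR-V types (NOT the MLIR Type API): either a
// scalar such as i32, or a vector/array holding `count` elements of `elem`.
struct Ty;
using TyPtr = std::shared_ptr<Ty>;
struct Ty {
  enum Kind { Scalar, Vector, Array } kind;
  std::string name;  // scalar name, e.g. "i32"; empty for composites
  int count;         // element count; 0 for scalars
  TyPtr elem;        // element type; null for scalars
};

TyPtr scalar(const std::string &name) {
  return std::make_shared<Ty>(Ty{Ty::Scalar, name, 0, nullptr});
}
TyPtr vec(int n, TyPtr e) {
  assert(n > 0 && e);
  return std::make_shared<Ty>(Ty{Ty::Vector, "", n, std::move(e)});
}
TyPtr arr(int n, TyPtr e) {
  assert(n > 0 && e);
  return std::make_shared<Ty>(Ty{Ty::Array, "", n, std::move(e)});
}

// Renders a type in (approximately) MLIR's textual syntax.
std::string str(const TyPtr &t) {
  switch (t->kind) {
  case Ty::Scalar:
    return t->name;
  case Ty::Vector:
    return "vector<" + std::to_string(t->count) + "x" + str(t->elem) + ">";
  case Ty::Array:
    return "!spirv.array<" + std::to_string(t->count) + " x " + str(t->elem) + ">";
  }
  return "";
}

// Same rule as the parser loop in the diff: start from the composite's
// element type and keep stripping array layers (vectors are kept).
// Returns null if `result` is not a composite.
TyPtr inferSplatType(const TyPtr &result) {
  if (result->kind == Ty::Scalar)
    return nullptr;
  TyPtr t = result->elem;
  while (t->kind == Ty::Array)
    t = t->elem;
  return t;
}
```

Under this rule `vector<4xi32>` yields `i32`, while `!spirv.array<2xvector<2xi32>>` yields `vector<2xi32>`, so an untyped splat value is always interpreted at one fixed nesting level.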

As mentioned, though, the problem with this approach is that it won't work for less trivial constant types such as the one below, because the type of the splat value would need to be specified:

`spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2xvector<2xi32>>`

In this example, `spirv.EXT.ConstantCompositeReplicate` could be written in either of two ways:

1. Splat value is of type `i32`:
`%0 = spirv.EXT.ConstantCompositeReplicate 2 : !spirv.array<2xvector<2xi32>>`
or
2. Splat value is of type `vector<2xi32>`:
`%0 = spirv.EXT.ConstantCompositeReplicate dense<2> : !spirv.array<2xvector<2xi32>>`
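The ambiguity can be made concrete with a small C++ sketch that lists every nested element type of a result type, outermost first. It assumes, as the two options above do, that the splat value may come from any nesting level; `Ty`, `TyPtr`, `scalar`, `vec`, `arr`, `str` and `candidateSplatTypes` are hypothetical helpers over a toy type model, not MLIR APIs.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of SPIR-V types (NOT the MLIR Type API).
struct Ty;
using TyPtr = std::shared_ptr<Ty>;
struct Ty {
  enum Kind { Scalar, Vector, Array } kind;
  std::string name;  // scalar name, e.g. "i32"; empty for composites
  int count;         // element count; 0 for scalars
  TyPtr elem;        // element type; null for scalars
};

TyPtr scalar(const std::string &name) {
  return std::make_shared<Ty>(Ty{Ty::Scalar, name, 0, nullptr});
}
TyPtr vec(int n, TyPtr e) {
  assert(n > 0 && e);
  return std::make_shared<Ty>(Ty{Ty::Vector, "", n, std::move(e)});
}
TyPtr arr(int n, TyPtr e) {
  assert(n > 0 && e);
  return std::make_shared<Ty>(Ty{Ty::Array, "", n, std::move(e)});
}

std::string str(const TyPtr &t) {
  switch (t->kind) {
  case Ty::Scalar:
    return t->name;
  case Ty::Vector:
    return "vector<" + std::to_string(t->count) + "x" + str(t->elem) + ">";
  case Ty::Array:
    return "!spirv.array<" + std::to_string(t->count) + " x " + str(t->elem) + ">";
  }
  return "";
}

// Collects every nested element type of `result`, outermost first.
// Assumption: any of them may serve as the splat value, so more than one
// entry means an untyped splat attribute is ambiguous.
std::vector<std::string> candidateSplatTypes(const TyPtr &result) {
  std::vector<std::string> out;
  for (TyPtr t = result->elem; t; t = t->elem)
    out.push_back(str(t));
  return out;
}
```

For `vector<4xi32>` the only candidate is `i32`, so an untyped value is unambiguous; for `!spirv.array<2xvector<2xi32>>` there are two candidates, `vector<2xi32>` and `i32`.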

To resolve this, the attribute type can be included in the representation, as in the following examples for the two cases:

`%0 = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>`

`%0 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<3 x vector<2xi32>>`
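With the type carried by the attribute, the parser no longer has to pick a nesting level; it only needs to check that the given type occurs somewhere in the result type. A toy C++ sketch of that check follows, under the same assumption that any nesting level is acceptable; `Ty`, `TyPtr`, `scalar`, `vec`, `arr`, `sameType` and `isValidSplatType` are hypothetical and stand in for the real MLIR types and verifier.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy model of SPIR-V types (NOT the MLIR Type API).
struct Ty;
using TyPtr = std::shared_ptr<Ty>;
struct Ty {
  enum Kind { Scalar, Vector, Array } kind;
  std::string name;  // scalar name, e.g. "i32"; empty for composites
  int count;         // element count; 0 for scalars
  TyPtr elem;        // element type; null for scalars
};

TyPtr scalar(const std::string &name) {
  return std::make_shared<Ty>(Ty{Ty::Scalar, name, 0, nullptr});
}
TyPtr vec(int n, TyPtr e) {
  assert(n > 0 && e);
  return std::make_shared<Ty>(Ty{Ty::Vector, "", n, std::move(e)});
}
TyPtr arr(int n, TyPtr e) {
  assert(n > 0 && e);
  return std::make_shared<Ty>(Ty{Ty::Array, "", n, std::move(e)});
}

// Structural type equality for the toy model.
bool sameType(const TyPtr &a, const TyPtr &b) {
  if (a->kind != b->kind || a->count != b->count)
    return false;
  if (a->kind == Ty::Scalar)
    return a->name == b->name;
  return sameType(a->elem, b->elem);
}

// With an explicitly typed attribute such as `[2 : i32]`, only check that
// the given type appears at some nesting level below the result type.
bool isValidSplatType(const TyPtr &result, const TyPtr &splat) {
  for (TyPtr t = result->elem; t; t = t->elem)
    if (sameType(t, splat))
      return true;
  return false;
}
```

For example, `i32` and `vector<2xi32>` are both accepted for `!spirv.array<2 x vector<2xi32>>`, whereas `vector<3xi32>` or the result type itself is rejected.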

I would appreciate it if you could confirm whether you are happy with this approach.

https://github.com/llvm/llvm-project/pull/147067

