[Mlir-commits] [mlir] [mlir][ODS] Fix default inferReturnTypes generation for variadic operands (PR #131483)
Kunwar Grover
llvmlistbot at llvm.org
Sat Mar 15 17:21:24 PDT 2025
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/131483
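When an op has variadic operands, indexing the flat operand list with `operands[odsOperandIndex]` is incorrect. A variadic operand can expand to any number of values, which shifts the flat positions of every operand that follows it. Instead, construct the op's adaptor and use its named getters to retrieve the correct operand.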
>From 3cdf6e781e997c7483840ad1dec932d1ac5a27de Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 16 Mar 2025 00:17:37 +0000
Subject: [PATCH] [mlir][ODS] Fix default inferReturnTypes generation for
variadic operands
---
mlir/test/mlir-tblgen/op-result.td | 18 +++-----
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 51 ++++++++++-----------
2 files changed, 30 insertions(+), 39 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index a4f7af6dbcf1c..334ca118e31c0 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
// CHECK-NOT: }
-// CHECK: if (operands.size() <= 0)
-// CHECK-NEXT: return ::mlir::failure();
-// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
+// CHECK: OpL1::Adaptor adaptor
+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
def OpL2 : NS_Op<"op_with_all_types_constraint",
@@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
// CHECK-NOT: }
-// CHECK: if (operands.size() <= 2)
-// CHECK-NEXT: return ::mlir::failure();
-// CHECK-NOT: if (operands.size() <= 0)
-// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
-// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
+// CHECK: OpL2::Adaptor adaptor
+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType();
+// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
// CHECK: inferredReturnTypes[1] = odsInferredType1;
@@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [
}
// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
-// CHECK: if (operands.size() <= 0)
-// CHECK-NEXT: return ::mlir::failure();
-// CHECK: odsInferredType0 = fromInput(operands[0].getType())
+// CHECK: OpL4::Adaptor adaptor
+// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType())
// CHECK: odsInferredType1 = infer0(odsInferredType0)
// CHECK: odsInferredType2 = infer1(odsInferredType1)
// CHECK: inferredReturnTypes[0] = odsInferredType0
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b957c8ee9f8ab..8288e77b8f653 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
// Avoid emitting "resultTypes.size() >= 0u" which is always true.
if (!hasVariadicResult || numNonVariadicResults != 0)
- body << " "
- << "assert(resultTypes.size() "
+ body << " " << "assert(resultTypes.size() "
<< (hasVariadicResult ? ">=" : "==") << " "
<< numNonVariadicResults
<< "u && \"mismatched number of results\");\n";
@@ -3751,29 +3750,24 @@ void OpEmitter::genTypeInterfaceMethods() {
fctx.addSubst("_ctxt", "context");
body << " ::mlir::Builder odsBuilder(context);\n";
- // Preprocessing stage to verify all accesses to operands are valid.
- int maxAccessedIndex = -1;
- for (int i = 0, e = op.getNumResults(); i != e; ++i) {
- const InferredResultType &infer = op.getInferredResultType(i);
- if (!infer.isArg())
- continue;
- Operator::OperandOrAttribute arg =
- op.getArgToOperandOrAttribute(infer.getIndex());
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
- maxAccessedIndex =
- std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
- }
- }
- if (maxAccessedIndex != -1) {
- body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
- body << " return ::mlir::failure();\n";
- }
+ // Emit an adaptor to access right ranges for ods operands.
+ body << " " << op.getCppClassName()
+ << "::Adaptor adaptor(operands, attributes, properties, regions);\n";
- // Process the type inference graph in topological order, starting from types
- // that are always fully-inferred: operands and results with constructible
- // types. The type inference graph here will always be a DAG, so this gives
- // us the correct order for generating the types. -1 is a placeholder to
- // indicate the type for a result has not been generated.
+ // TODO: Ideally, we should be doing some sort of verification here. This
+ // is however problematic due to 2 reasons:
+ //
+ // 1. Adaptor::verify only verifies attributes. It really should verify
+ // if the number of given attributes is right too.
+ // 2. PDL passes empty properties to inferReturnTypes, which does not verify.
+ // Without properties, it's not really possible to verify the number of
+ // operands as we do not know the variadic operand segment sizes.
+
+ // Process the type inference graph in topological order, starting from
+ // types that are always fully-inferred: operands and results with
+ // constructible types. The type inference graph here will always be a
+ // DAG, so this gives us the correct order for generating the types. -1 is
+ // a placeholder to indicate the type for a result has not been generated.
SmallVector<int> constructedIndices(op.getNumResults(), -1);
int inferredTypeIdx = 0;
for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
@@ -3788,10 +3782,11 @@ void OpEmitter::genTypeInterfaceMethods() {
Operator::OperandOrAttribute arg =
op.getArgToOperandOrAttribute(infer.getIndex());
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
- typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
- "].getType()")
- .str();
-
+ std::string getter =
+ "adaptor." +
+ op.getGetterName(
+ op.getOperand(arg.operandOrAttributeIndex()).name);
+ typeStr = (getter + "().getType()");
// If this is an attribute, index into the attribute dictionary.
} else {
auto *attr =
More information about the Mlir-commits
mailing list