[Mlir-commits] [mlir] [mlir][ODS] Fix default inferReturnTypes generation for variadic operands (PR #131483)

Kunwar Grover llvmlistbot at llvm.org
Wed Mar 19 04:49:25 PDT 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/131483

>From 6e92fe43f5db87b8eec944c9fb6fb27643cc1c3f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 16 Mar 2025 00:17:37 +0000
Subject: [PATCH 1/3] [mlir][ODS] Fix default inferReturnTypes generation for
 variadic operands

---
 mlir/test/mlir-tblgen/op-result.td          | 18 +++-----
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 51 ++++++++++-----------
 2 files changed, 30 insertions(+), 39 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index a4f7af6dbcf1c..334ca118e31c0 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
 // CHECK-NOT: }
-// CHECK: if (operands.size() <= 0)
-// CHECK-NEXT: return ::mlir::failure();
-// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
+// CHECK: OpL1::Adaptor adaptor
+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType();
 // CHECK: inferredReturnTypes[0] = odsInferredType0;
 
 def OpL2 : NS_Op<"op_with_all_types_constraint",
@@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
 // CHECK-NOT: }
-// CHECK: if (operands.size() <= 2)
-// CHECK-NEXT: return ::mlir::failure();
-// CHECK-NOT: if (operands.size() <= 0)
-// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
-// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
+// CHECK: OpL2::Adaptor adaptor
+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType();
+// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType();
 // CHECK: inferredReturnTypes[0] = odsInferredType0;
 // CHECK: inferredReturnTypes[1] = odsInferredType1;
 
@@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [
 }
 
 // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
-// CHECK: if (operands.size() <= 0)
-// CHECK-NEXT: return ::mlir::failure();
-// CHECK: odsInferredType0 = fromInput(operands[0].getType())
+// CHECK: OpL4::Adaptor adaptor
+// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType())
 // CHECK: odsInferredType1 = infer0(odsInferredType0)
 // CHECK: odsInferredType2 = infer1(odsInferredType1)
 // CHECK: inferredReturnTypes[0] = odsInferredType0
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b957c8ee9f8ab..8288e77b8f653 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
 
       // Avoid emitting "resultTypes.size() >= 0u" which is always true.
       if (!hasVariadicResult || numNonVariadicResults != 0)
-        body << "  "
-             << "assert(resultTypes.size() "
+        body << "  " << "assert(resultTypes.size() "
              << (hasVariadicResult ? ">=" : "==") << " "
              << numNonVariadicResults
              << "u && \"mismatched number of results\");\n";
@@ -3751,29 +3750,24 @@ void OpEmitter::genTypeInterfaceMethods() {
   fctx.addSubst("_ctxt", "context");
   body << "  ::mlir::Builder odsBuilder(context);\n";
 
-  // Preprocessing stage to verify all accesses to operands are valid.
-  int maxAccessedIndex = -1;
-  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
-    const InferredResultType &infer = op.getInferredResultType(i);
-    if (!infer.isArg())
-      continue;
-    Operator::OperandOrAttribute arg =
-        op.getArgToOperandOrAttribute(infer.getIndex());
-    if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
-      maxAccessedIndex =
-          std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
-    }
-  }
-  if (maxAccessedIndex != -1) {
-    body << "  if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
-    body << "    return ::mlir::failure();\n";
-  }
+  // Emit an adaptor to access right ranges for ods operands.
+  body << "  " << op.getCppClassName()
+       << "::Adaptor adaptor(operands, attributes, properties, regions);\n";
 
-  // Process the type inference graph in topological order, starting from types
-  // that are always fully-inferred: operands and results with constructible
-  // types. The type inference graph here will always be a DAG, so this gives
-  // us the correct order for generating the types. -1 is a placeholder to
-  // indicate the type for a result has not been generated.
+  // TODO: Ideally, we should be doing some sort of verification here. This
+  // is however problemetic due to 2 reasons:
+  //
+  // 1. Adaptor::verify only verifies attributes. It really should verify
+  //    if the number of given attributes is right too.
+  // 2. PDL passes empty properties to inferReturnTypes, which does not verify.
+  //    Without properties, it's not really possible to verify the number of
+  //    operands as we do not know the variadic operand segment sizes.
+
+  // Process the type inference graph in topological order, starting from
+  // types that are always fully-inferred: operands and results with
+  // constructible types. The type inference graph here will always be a
+  // DAG, so this gives us the correct order for generating the types. -1 is
+  // a placeholder to indicate the type for a result has not been generated.
   SmallVector<int> constructedIndices(op.getNumResults(), -1);
   int inferredTypeIdx = 0;
   for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
@@ -3788,10 +3782,11 @@ void OpEmitter::genTypeInterfaceMethods() {
         Operator::OperandOrAttribute arg =
             op.getArgToOperandOrAttribute(infer.getIndex());
         if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
-          typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
-                     "].getType()")
-                        .str();
-
+          std::string getter =
+              "adaptor." +
+              op.getGetterName(
+                  op.getOperand(arg.operandOrAttributeIndex()).name);
+          typeStr = (getter + "().getType()");
           // If this is an attribute, index into the attribute dictionary.
         } else {
           auto *attr =

>From dab553574367755d8c411430cf583fae8c773a92 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 19 Mar 2025 11:33:02 +0000
Subject: [PATCH 2/3] Remove TODO, already mentioned in docs

---
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8288e77b8f653..f61129c234ddf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3754,15 +3754,6 @@ void OpEmitter::genTypeInterfaceMethods() {
   body << "  " << op.getCppClassName()
        << "::Adaptor adaptor(operands, attributes, properties, regions);\n";
 
-  // TODO: Ideally, we should be doing some sort of verification here. This
-  // is however problemetic due to 2 reasons:
-  //
-  // 1. Adaptor::verify only verifies attributes. It really should verify
-  //    if the number of given attributes is right too.
-  // 2. PDL passes empty properties to inferReturnTypes, which does not verify.
-  //    Without properties, it's not really possible to verify the number of
-  //    operands as we do not know the variadic operand segment sizes.
-
   // Process the type inference graph in topological order, starting from
   // types that are always fully-inferred: operands and results with
   // constructible types. The type inference graph here will always be a

>From 1beaad382c744bfd23588d0ca504aee0ab5742e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 19 Mar 2025 11:48:39 +0000
Subject: [PATCH 3/3] Add test for variandic + one norm input

---
 mlir/test/mlir-tblgen/op-result.td | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 334ca118e31c0..7f882ce0dfce4 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -203,6 +203,18 @@ def OpL6 : NS_Op<"op_with_same_and_constraint_results",
 // CHECK: inferredReturnTypes[1] = odsInferredType1;
 // CHECK: inferredReturnTypes[2] = odsInferredType2;
 
+def OpL7 : NS_Op<"one_variadic_and_one_normal_operand_with_infer_result_op",
+    [TypesMatchWith<"", "input2", "output1", "infer0($_self)">]> {
+  let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2);
+  let results = (outs AnyTensor:$output1);
+}
+
+// CHECK-LABEL: LogicalResult OpL7::inferReturnTypes
+// CHECK-NOT: }
+// CHECK: OpL7::Adaptor adaptor
+// CHECK: odsInferredType0 = infer0(adaptor.getInput2().getType())
+// CHECK: inferredReturnTypes[0] = odsInferredType0
+
 def OpM : NS_Op<"mix_diff_size_variadic_and_normal_results_op", [AttrSizedResultSegments]> {
   let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Optional<AnyTensor>:$output3);
 }



More information about the Mlir-commits mailing list