[Mlir-commits] [mlir] [mlir][python] loosen infertype reqs in bindings generation (PR #73620)

Maksim Levental llvmlistbot at llvm.org
Tue Nov 28 11:30:46 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/73620

>From cef587c68b5d2379b3bc4d5a9fb57412068741e7 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 28 Nov 2023 02:17:19 -0600
Subject: [PATCH] [mlir][python] loosen infertype

---
 mlir/lib/CAPI/IR/IR.cpp                       |  6 +-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 77 +++++++++++--------
 2 files changed, 45 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index d1ee1b774c34478..b1d87f0ad6a09e4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -18,7 +18,6 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -29,7 +28,6 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
-#include <cstddef>
 #include <memory>
 #include <optional>
 
@@ -471,9 +469,7 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) {
   free(state->attributes);
 
   // Infer result types.
-  if (state->enableResultTypeInference) {
-    assert(cppState.types.empty() &&
-           "result type inference enabled and result types provided");
+  if (state->enableResultTypeInference && cppState.types.empty()) {
     if (failed(inferOperationTypes(cppState)))
       return {nullptr};
   }
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0c0ad2cfeffdcc2..02259b8a140ea28 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -582,8 +582,7 @@ static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
 /// Returns true if the InferTypeOpInterface can be used to infer result types
 /// of the given operation.
 static bool hasInferTypeInterface(const Operator &op) {
-  return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
-         op.getNumRegions() == 0;
+  return op.getTrait("::mlir::InferTypeOpInterface::Trait");
 }
 
 /// Returns true if there is a trait or interface that can be used to infer
@@ -598,9 +597,6 @@ static bool canInferType(const Operator &op) {
 static void
 populateBuilderArgsResults(const Operator &op,
                            llvm::SmallVectorImpl<std::string> &builderArgs) {
-  if (canInferType(op))
-    return;
-
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     std::string name = op.getResultName(i).str();
     if (name.empty()) {
@@ -769,6 +765,36 @@ static void appendLineByLine(StringRef string,
   } while (!split.second.empty());
 }
 
+static void populateBuilderLinesResultList(
+    const Operator &op, llvm::ArrayRef<std::string> names,
+    llvm::SmallVectorImpl<std::string> &builderLines) {
+  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+  // For each element, find or generate a name.
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    const NamedTypeConstraint &element = op.getResult(i);
+    std::string name = names[i];
+
+    // Choose the formatting string based on the element kind.
+    llvm::StringRef formatString;
+    if (!element.isVariableLength()) {
+      formatString = singleResultAppendTemplate;
+    } else if (element.isOptional()) {
+      formatString = optionalAppendResultTemplate;
+    } else {
+      assert(element.isVariadic() && "unhandled element group type");
+      // If emitting with sizedSegments, then we add the actual list-typed
+      // element. Otherwise, we extend the actual operands.
+      if (sizedSegments) {
+        formatString = singleResultAppendTemplate;
+      } else {
+        formatString = multiResultAppendTemplate;
+      }
+    }
+
+    builderLines.push_back(llvm::formatv(formatString.data(), name));
+  }
+}
+
 /// Populates `builderLines` with additional lines that are required in the
 /// builder to set up op results.
 static void
@@ -798,31 +824,7 @@ populateBuilderLinesResult(const Operator &op,
 
   if (hasInferTypeInterface(op))
     return;
-
-  // For each element, find or generate a name.
-  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
-    const NamedTypeConstraint &element = op.getResult(i);
-    std::string name = names[i];
-
-    // Choose the formatting string based on the element kind.
-    llvm::StringRef formatString;
-    if (!element.isVariableLength()) {
-      formatString = singleResultAppendTemplate;
-    } else if (element.isOptional()) {
-      formatString = optionalAppendResultTemplate;
-    } else {
-      assert(element.isVariadic() && "unhandled element group type");
-      // If emitting with sizedSegments, then we add the actual list-typed
-      // element. Otherwise, we extend the actual operands.
-      if (sizedSegments) {
-        formatString = singleResultAppendTemplate;
-      } else {
-        formatString = multiResultAppendTemplate;
-      }
-    }
-
-    builderLines.push_back(llvm::formatv(formatString.data(), name));
-  }
+  populateBuilderLinesResultList(op, names, builderLines);
 }
 
 /// If the operation has variadic regions, adds a builder argument to specify
@@ -861,7 +863,8 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
   llvm::SmallVector<std::string> successorArgNames;
   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                       op.getNumNativeAttributes() + op.getNumSuccessors());
-  populateBuilderArgsResults(op, builderArgs);
+  if (!canInferType(op))
+    populateBuilderArgsResults(op, builderArgs);
   size_t numResultArgs = builderArgs.size();
   populateBuilderArgs(op, builderArgs, operandArgNames);
   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
@@ -918,13 +921,21 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
       functionArgs.push_back(builderArgs[i]);
     }
   }
+  llvm::SmallVector<std::string> builderArgs2;
+  if (canInferType(op)) {
+    populateBuilderArgsResults(op, builderArgs2);
+    populateBuilderLinesResultList(op, builderArgs2, builderLines);
+    for (size_t i = 0, cnt = builderArgs2.size(); i < cnt; ++i) {
+      builderArgs2[i].append("=None");
+      functionArgs.push_back(builderArgs2[i]);
+    }
+  }
   functionArgs.push_back("loc=None");
   functionArgs.push_back("ip=None");
 
   SmallVector<std::string> initArgs;
   initArgs.push_back("attributes=attributes");
-  if (!hasInferTypeInterface(op))
-    initArgs.push_back("results=results");
+  initArgs.push_back("results=results");
   initArgs.push_back("operands=operands");
   initArgs.push_back("successors=_ods_successors");
   initArgs.push_back("regions=regions");



More information about the Mlir-commits mailing list