[Mlir-commits] [mlir] [mlir][python] loosen infertype reqs in bindings generation (PR #73620)
Maksim Levental
llvmlistbot at llvm.org
Tue Nov 28 11:30:46 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/73620
From cef587c68b5d2379b3bc4d5a9fb57412068741e7 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 28 Nov 2023 02:17:19 -0600
Subject: [PATCH] [mlir][python] loosen infertype reqs in bindings generation
---
mlir/lib/CAPI/IR/IR.cpp | 6 +-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 77 +++++++++++--------
2 files changed, 45 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index d1ee1b774c34478..b1d87f0ad6a09e4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -18,7 +18,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
@@ -29,7 +28,6 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
-#include <cstddef>
#include <memory>
#include <optional>
@@ -471,9 +469,7 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) {
free(state->attributes);
// Infer result types.
- if (state->enableResultTypeInference) {
- assert(cppState.types.empty() &&
- "result type inference enabled and result types provided");
+ if (state->enableResultTypeInference && cppState.types.empty()) {
if (failed(inferOperationTypes(cppState)))
return {nullptr};
}
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0c0ad2cfeffdcc2..02259b8a140ea28 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -582,8 +582,7 @@ static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
/// Returns true if the InferTypeOpInterface can be used to infer result types
/// of the given operation.
static bool hasInferTypeInterface(const Operator &op) {
- return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
- op.getNumRegions() == 0;
+ return op.getTrait("::mlir::InferTypeOpInterface::Trait");
}
/// Returns true if there is a trait or interface that can be used to infer
@@ -598,9 +597,6 @@ static bool canInferType(const Operator &op) {
static void
populateBuilderArgsResults(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs) {
- if (canInferType(op))
- return;
-
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
std::string name = op.getResultName(i).str();
if (name.empty()) {
@@ -769,6 +765,36 @@ static void appendLineByLine(StringRef string,
} while (!split.second.empty());
}
+static void populateBuilderLinesResultList(
+ const Operator &op, llvm::ArrayRef<std::string> names,
+ llvm::SmallVectorImpl<std::string> &builderLines) {
+ bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+ // For each element, find or generate a name.
+ for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+ const NamedTypeConstraint &element = op.getResult(i);
+ std::string name = names[i];
+
+ // Choose the formatting string based on the element kind.
+ llvm::StringRef formatString;
+ if (!element.isVariableLength()) {
+ formatString = singleResultAppendTemplate;
+ } else if (element.isOptional()) {
+ formatString = optionalAppendResultTemplate;
+ } else {
+ assert(element.isVariadic() && "unhandled element group type");
+ // If emitting with sizedSegments, then we add the actual list-typed
+ // element. Otherwise, we extend the actual operands.
+ if (sizedSegments) {
+ formatString = singleResultAppendTemplate;
+ } else {
+ formatString = multiResultAppendTemplate;
+ }
+ }
+
+ builderLines.push_back(llvm::formatv(formatString.data(), name));
+ }
+}
+
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
static void
@@ -798,31 +824,7 @@ populateBuilderLinesResult(const Operator &op,
if (hasInferTypeInterface(op))
return;
-
- // For each element, find or generate a name.
- for (int i = 0, e = op.getNumResults(); i < e; ++i) {
- const NamedTypeConstraint &element = op.getResult(i);
- std::string name = names[i];
-
- // Choose the formatting string based on the element kind.
- llvm::StringRef formatString;
- if (!element.isVariableLength()) {
- formatString = singleResultAppendTemplate;
- } else if (element.isOptional()) {
- formatString = optionalAppendResultTemplate;
- } else {
- assert(element.isVariadic() && "unhandled element group type");
- // If emitting with sizedSegments, then we add the actual list-typed
- // element. Otherwise, we extend the actual operands.
- if (sizedSegments) {
- formatString = singleResultAppendTemplate;
- } else {
- formatString = multiResultAppendTemplate;
- }
- }
-
- builderLines.push_back(llvm::formatv(formatString.data(), name));
- }
+ populateBuilderLinesResultList(op, names, builderLines);
}
/// If the operation has variadic regions, adds a builder argument to specify
@@ -861,7 +863,8 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
llvm::SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
op.getNumNativeAttributes() + op.getNumSuccessors());
- populateBuilderArgsResults(op, builderArgs);
+ if (!canInferType(op))
+ populateBuilderArgsResults(op, builderArgs);
size_t numResultArgs = builderArgs.size();
populateBuilderArgs(op, builderArgs, operandArgNames);
size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
@@ -918,13 +921,21 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
functionArgs.push_back(builderArgs[i]);
}
}
+ llvm::SmallVector<std::string> builderArgs2;
+ if (canInferType(op)) {
+ populateBuilderArgsResults(op, builderArgs2);
+ populateBuilderLinesResultList(op, builderArgs2, builderLines);
+ for (size_t i = 0, cnt = builderArgs2.size(); i < cnt; ++i) {
+ builderArgs2[i].append("=None");
+ functionArgs.push_back(builderArgs2[i]);
+ }
+ }
functionArgs.push_back("loc=None");
functionArgs.push_back("ip=None");
SmallVector<std::string> initArgs;
initArgs.push_back("attributes=attributes");
- if (!hasInferTypeInterface(op))
- initArgs.push_back("results=results");
+ initArgs.push_back("results=results");
initArgs.push_back("operands=operands");
initArgs.push_back("successors=_ods_successors");
initArgs.push_back("regions=regions");
More information about the Mlir-commits mailing list