[Mlir-commits] [mlir] f573bc2 - [mlir][py] Reuse more of CAPI build time inference.
Jacques Pienaar
llvmlistbot at llvm.org
Sun Jul 23 21:27:00 PDT 2023
Author: Jacques Pienaar
Date: 2023-07-23T21:26:52-07:00
New Revision: f573bc24d4ea61d68ea61e49eeefbb9a84fa4f34
URL: https://github.com/llvm/llvm-project/commit/f573bc24d4ea61d68ea61e49eeefbb9a84fa4f34
DIFF: https://github.com/llvm/llvm-project/commit/f573bc24d4ea61d68ea61e49eeefbb9a84fa4f34.diff
LOG: [mlir][py] Reuse more of CAPI build time inference.
This reduces the amount of code generated for type inference and instead
reuses the CAPI-side facilities that already perform the same role.
Differential Revision: https://reviews.llvm.org/D156041
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 39049f387eda70..971d2819ade44b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -78,6 +78,7 @@ static const char kOperationCreateDocstring[] =
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
+ infer_type: Whether to infer result types.
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
@@ -1288,7 +1289,7 @@ py::object PyOperation::create(const std::string &name,
std::optional<py::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
- const py::object &maybeIp) {
+ const py::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1367,6 +1368,7 @@ py::object PyOperation::create(const std::string &name,
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
+ state.enableResultTypeInference = inferType;
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
mlirResults.data());
@@ -1398,6 +1400,8 @@ py::object PyOperation::create(const std::string &name,
// Construct the operation.
MlirOperation operation = mlirOperationCreate(&state);
+ if (!operation.ptr)
+ throw py::value_error("Operation creation failed");
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
maybeInsertOperation(created, maybeIp);
@@ -1441,51 +1445,10 @@ void PyOperation::erase() {
// PyOpView
//------------------------------------------------------------------------------
-py::object
-PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
- py::list operandList, std::optional<py::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions,
- DefaultingPyLocation location,
- const py::object &maybeIp) {
- PyMlirContextRef context = location->getContext();
- // Class level operation construction metadata.
- std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
- // Operand and result segment specs are either none, which does no
- // variadic unpacking, or a list of ints with segment sizes, where each
- // element is either a positive number (typically 1 for a scalar) or -1 to
- // indicate that it is derived from the length of the same-indexed operand
- // or result (implying that it is a list at that position).
- py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
- py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
-
- std::vector<int32_t> operandSegmentLengths;
- std::vector<int32_t> resultSegmentLengths;
-
- // Validate/determine region count.
- auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
- int opMinRegionCount = std::get<0>(opRegionSpec);
- bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
- if (!regions) {
- regions = opMinRegionCount;
- }
- if (*regions < opMinRegionCount) {
- throw py::value_error(
- (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
- llvm::Twine(opMinRegionCount) +
- " regions but was built with regions=" + llvm::Twine(*regions))
- .str());
- }
- if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
- throw py::value_error(
- (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
- llvm::Twine(opMinRegionCount) +
- " regions but was built with regions=" + llvm::Twine(*regions))
- .str());
- }
-
- // Unpack results.
- std::vector<PyType *> resultTypes;
+static void populateResultTypes(StringRef name, py::list resultTypeList,
+ const py::object &resultSegmentSpecObj,
+ std::vector<int32_t> &resultSegmentLengths,
+ std::vector<PyType *> &resultTypes) {
resultTypes.reserve(resultTypeList.size());
if (resultSegmentSpecObj.is_none()) {
// Non-variadic result unpacking.
@@ -1568,6 +1531,56 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
}
}
}
+}
+
+py::object PyOpView::buildGeneric(
+ const py::object &cls, std::optional<py::list> resultTypeList,
+ py::list operandList, std::optional<py::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions, DefaultingPyLocation location,
+ const py::object &maybeIp) {
+ PyMlirContextRef context = location->getContext();
+ // Class level operation construction metadata.
+ std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
+ // Operand and result segment specs are either none, which does no
+ // variadic unpacking, or a list of ints with segment sizes, where each
+ // element is either a positive number (typically 1 for a scalar) or -1 to
+ // indicate that it is derived from the length of the same-indexed operand
+ // or result (implying that it is a list at that position).
+ py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
+ py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
+
+ std::vector<int32_t> operandSegmentLengths;
+ std::vector<int32_t> resultSegmentLengths;
+
+ // Validate/determine region count.
+ auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ int opMinRegionCount = std::get<0>(opRegionSpec);
+ bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
+ if (!regions) {
+ regions = opMinRegionCount;
+ }
+ if (*regions < opMinRegionCount) {
+ throw py::value_error(
+ (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
+ llvm::Twine(opMinRegionCount) +
+ " regions but was built with regions=" + llvm::Twine(*regions))
+ .str());
+ }
+ if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
+ throw py::value_error(
+ (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
+ llvm::Twine(opMinRegionCount) +
+ " regions but was built with regions=" + llvm::Twine(*regions))
+ .str());
+ }
+
+ // Unpack results.
+ std::vector<PyType *> resultTypes;
+ if (resultTypeList.has_value()) {
+ populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
+ resultSegmentLengths, resultTypes);
+ }
// Unpack operands.
std::vector<PyValue *> operands;
@@ -1694,7 +1707,8 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
/*operands=*/std::move(operands),
/*attributes=*/std::move(attributes),
/*successors=*/std::move(successors),
- /*regions=*/*regions, location, maybeIp);
+ /*regions=*/*regions, location, maybeIp,
+ !resultTypeList);
}
pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
@@ -2854,7 +2868,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("attributes") = py::none(),
py::arg("successors") = py::none(), py::arg("regions") = 0,
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
- kOperationCreateDocstring)
+ py::arg("infer_type") = false, kOperationCreateDocstring)
.def_static(
"parse",
[](const std::string &sourceStr, const std::string &sourceName,
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 5da1d7d25ebe9c..d1911730c1ede0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -655,7 +655,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
std::optional<std::vector<PyValue *>> operands,
std::optional<pybind11::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const pybind11::object &ip);
+ DefaultingPyLocation location, const pybind11::object &ip,
+ bool inferType);
/// Creates an OpView suitable for this operation.
pybind11::object createOpView();
@@ -704,13 +705,12 @@ class PyOpView : public PyOperationBase {
pybind11::object getOperationObject() { return operationObject; }
- static pybind11::object
- buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
- pybind11::list operandList,
- std::optional<pybind11::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
- const pybind11::object &maybeIp);
+ static pybind11::object buildGeneric(
+ const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
+ pybind11::list operandList, std::optional<pybind11::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions, DefaultingPyLocation location,
+ const pybind11::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
/// `__init__` method. The derived class will typically define a constructor
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 036ea5b4a852e2..de979f7e8f43e6 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -245,14 +245,10 @@ def EmptyOp : TestOp<"empty">;
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
// CHECK: def __init__(self, *, loc=None, ip=None):
- // CHECK: operands = []
- // CHECK: results = []
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
- // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
- // CHECK: operands=operands,
- // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
- // CHECK: context=_ods_context,
- // CHECK: loc=loc)
+ // CHECK: super().__init__(self.build_generic(
+ // CHECK: attributes=attributes, operands=operands,
+ // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
let results = (outs I32:$i32, F32:$f32);
}
@@ -260,13 +256,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
// CHECK: def __init__(self, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
- // CHECK: _ods_context = _ods_get_default_loc_context(loc)
- // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
- // CHECK: operands=operands,
- // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
- // CHECK: context=_ods_context,
- // CHECK: loc=loc)
+ // CHECK: super().__init__(self.build_generic(
+ // CHECK: attributes=attributes, operands=operands,
+ // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
let results = (outs AnyType, AnyType, AnyType);
}
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index f17e6e9eda4031..dd6e52d300efe1 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -493,9 +493,7 @@ constexpr const char *initTemplate = R"Py(
attributes = {{}
regions = None
{1}
- super().__init__(self.build_generic(
- attributes=attributes, results=results, operands=operands,
- successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+ super().__init__(self.build_generic({2}))
)Py";
/// Template for appending a single element to the operand/result list.
@@ -755,17 +753,6 @@ _ods_derived_result_type = (
/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
-/// Python code template for inferring the operation results using the
-/// corresponding interface:
-/// - {0} is the name of the class for which the types are inferred.
-constexpr const char *inferTypeInterfaceTemplate =
- R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
- operands=operands,
- attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
- context=_ods_context,
- loc=loc)
-)PY";
-
/// Appends the given multiline string as individual strings into
/// `builderLines`.
static void appendLineByLine(StringRef string,
@@ -805,12 +792,8 @@ populateBuilderLinesResult(const Operator &op,
return;
}
- if (hasInferTypeInterface(op)) {
- appendLineByLine(
- llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
- builderLines);
+ if (hasInferTypeInterface(op))
return;
- }
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
@@ -934,8 +917,20 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
}
functionArgs.push_back("loc=None");
functionArgs.push_back("ip=None");
+
+ SmallVector<std::string> initArgs;
+ initArgs.push_back("attributes=attributes");
+ if (!hasInferTypeInterface(op))
+ initArgs.push_back("results=results");
+ initArgs.push_back("operands=operands");
+ initArgs.push_back("successors=_ods_successors");
+ initArgs.push_back("regions=regions");
+ initArgs.push_back("loc=loc");
+ initArgs.push_back("ip=ip");
+
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
- llvm::join(builderLines, "\n "));
+ llvm::join(builderLines, "\n "),
+ llvm::join(initArgs, ", "));
}
static void emitSegmentSpec(
More information about the Mlir-commits
mailing list