[Mlir-commits] [mlir] [mlir][python] expose tryFold (PR #90148)
Maksim Levental
llvmlistbot at llvm.org
Thu Apr 25 16:27:16 PDT 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/90148
>From 8842be09d1f1d344548460bee0d553aee905833f Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 25 Apr 2024 17:15:22 -0500
Subject: [PATCH] [mlir][python] expose tryFold
---
mlir/include/mlir-c/IR.h | 8 +++
mlir/include/mlir/IR/Builders.h | 9 ++++
mlir/lib/Bindings/Python/IRCore.cpp | 27 +++++++++-
mlir/lib/CAPI/IR/IR.cpp | 57 +++++++++++++++++++++-
mlir/lib/IR/Builders.cpp | 17 +++++--
mlir/test/python/dialects/arith_dialect.py | 6 +++
6 files changed, 118 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 32abacf353133e..3c6297f26771de 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -730,6 +730,18 @@ MLIR_CAPI_EXPORTED
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder);
+/// Attempts to fold `mlirOp` without erasing it; constants are created in
+/// `mlirContext`. On success, `*mlirResults` receives `*numResults` values
+/// that replace the operation's results (none for an in-place fold) and
+/// `*mlirConstants` receives the `*numConstants` constant operations
+/// materialized while folding. Both arrays are allocated with malloc() and
+/// must be released by the caller with free(); either may be NULL when its
+/// count is zero. The generated constants are not inserted into any block.
+/// On failure, the out-parameters must not be read.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationTryFold(
+    MlirContext mlirContext, MlirOperation mlirOp, MlirValue **mlirResults,
+    size_t *numResults, MlirOperation **mlirConstants, size_t *numConstants);
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0d5fa719d0dee2..0adad2501d3737 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -204,6 +204,13 @@ class Builder {
MLIRContext *context;
};
+/// Attempts to fold `op`, placing new results in `results`. Constants created
+/// while folding are left uninserted in `generatedConstants` for the caller to
+/// insert or erase. Returns success if folded. Does not erase `op` on success.
+LogicalResult tryFold(MLIRContext *context, Operation *op,
+ SmallVectorImpl<Value> &results,
+ SmallVector<Operation *, 1> &generatedConstants);
+
/// This class helps build Operations. Operations that are created are
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
@@ -570,6 +577,8 @@ class OpBuilder : public Builder {
/// `results`. Returns success if the operation was folded, failure otherwise.
/// If the fold was in-place, `results` will not be filled.
/// Note: This function does not erase the operation on a successful fold.
+ /// Note: This function inserts generated constants at the current insertion
+ /// point.
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
/// Creates a deep copy of the specified operation, remapping any operands
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 01678a9719f90f..3ff7a0fc4abfb7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3081,7 +3081,49 @@ void mlir::python::populateIRCore(py::module &m) {
"Detaches the operation from its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
.def("walk", &PyOperationBase::walk, py::arg("callback"),
- py::arg("walk_order") = MlirWalkPostOrder);
+           py::arg("walk_order") = MlirWalkPostOrder)
+      .def(
+          "try_fold",
+          [](PyOperationBase &self, MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+
+            // Start from defined values so nothing uninitialized is read,
+            // whatever the C API does with its out-parameters on failure.
+            MlirValue *mlirResults = nullptr;
+            size_t numResults = 0;
+            MlirOperation *mlirConstants = nullptr;
+            size_t numConstants = 0;
+
+            MlirLogicalResult folded = mlirOperationTryFold(
+                context, self.getOperation(), &mlirResults, &numResults,
+                &mlirConstants, &numConstants);
+
+            // The C API returns malloc-allocated arrays; take ownership right
+            // away so they are freed even if an exception is thrown below.
+            std::unique_ptr<MlirValue, decltype(&free)> resultsOwner(
+                mlirResults, &free);
+            std::unique_ptr<MlirOperation, decltype(&free)> constantsOwner(
+                mlirConstants, &free);
+
+            if (mlirLogicalResultIsFailure(folded)) {
+              // If no diagnostic was captured, fall back to a generic message
+              // rather than raising a ValueError with an empty one.
+              std::string message = scope.takeMessage();
+              if (message.empty())
+                message = "operation could not be folded";
+              throw py::value_error(message);
+            }
+
+            std::vector<MlirValue> results(mlirResults,
+                                           mlirResults + numResults);
+            std::vector<MlirOperation> generatedConstants(
+                mlirConstants, mlirConstants + numConstants);
+            // NOTE(review): the generated constants are not inserted into any
+            // block; confirm who ends up owning (and erasing) them.
+            return std::make_pair(std::move(results),
+                                  std::move(generatedConstants));
+          },
+          "context"_a = py::none());
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a72cd247e73f60..1492b88f476131 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -15,6 +15,7 @@
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
@@ -30,7 +31,8 @@
#include "mlir/Parser/Parser.h"
#include "llvm/Support/ThreadPool.h"
#include <cstddef>
+#include <cstdlib>
#include <memory>
#include <optional>
@@ -748,6 +750,61 @@ void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
}
}
+/// Folds `mlirOp` via mlir::tryFold and copies the folded values and the
+/// uninserted constants into malloc-allocated arrays owned by the caller.
+MlirLogicalResult
+mlirOperationTryFold(MlirContext mlirContext, MlirOperation mlirOp,
+                     MlirValue **mlirResults, size_t *numResults,
+                     MlirOperation **mlirConstants, size_t *numConstants) {
+  // Put every out-parameter into a defined state first so callers never
+  // observe uninitialized pointers or counts on a failure path.
+  *mlirResults = nullptr;
+  *numResults = 0;
+  *mlirConstants = nullptr;
+  *numConstants = 0;
+
+  Operation *op = unwrap(mlirOp);
+  MLIRContext *context = unwrap(mlirContext);
+
+  SmallVector<Value, 1> results;
+  SmallVector<Operation *, 1> generatedConstants;
+  if (failed(tryFold(context, op, results, generatedConstants)))
+    return mlirLogicalResultFailure();
+
+  // Both arrays are handed to the caller, who releases them with free().
+  // Empty arrays are reported as NULL instead of relying on malloc(0).
+  MlirValue *resultsArray = nullptr;
+  if (!results.empty())
+    resultsArray = static_cast<MlirValue *>(
+        std::malloc(sizeof(MlirValue) * results.size()));
+  MlirOperation *constantsArray = nullptr;
+  if (!generatedConstants.empty())
+    constantsArray = static_cast<MlirOperation *>(
+        std::malloc(sizeof(MlirOperation) * generatedConstants.size()));
+
+  if ((!results.empty() && !resultsArray) ||
+      (!generatedConstants.empty() && !constantsArray)) {
+    std::free(resultsArray);
+    std::free(constantsArray);
+    // The constants were never inserted into a block and are not handed to
+    // the caller on failure, so release them here.
+    for (Operation *cst : generatedConstants)
+      cst->erase();
+    return mlirLogicalResultFailure();
+  }
+
+  for (size_t i = 0, e = results.size(); i < e; ++i)
+    resultsArray[i] = wrap(results[i]);
+  for (size_t i = 0, e = generatedConstants.size(); i < e; ++i)
+    constantsArray[i] = wrap(generatedConstants[i]);
+
+  *mlirResults = resultsArray;
+  *numResults = results.size();
+  *mlirConstants = constantsArray;
+  *numConstants = generatedConstants.size();
+  return mlirLogicalResultSuccess();
+}
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d49f69a7b7ae6b..4f753d89c0ae7c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -476,8 +476,12 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+/// Attempts to fold `op`, placing new results in `results`. Constants created
+/// while folding are left uninserted in `generatedConstants` for the caller to
+/// insert or erase. Returns success if folded. Does not erase `op` on success.
+LogicalResult mlir::tryFold(MLIRContext *context, Operation *op,
+ SmallVectorImpl<Value> &results,
+ SmallVector<Operation *, 1> &generatedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -502,7 +506,6 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// A temporary builder used for creating constants during folding.
OpBuilder cstBuilder(context);
- SmallVector<Operation *, 1> generatedConstants;
// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
@@ -535,6 +538,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
results.push_back(constOp->getResult(0));
}
+ return success();
+}
+
+LogicalResult OpBuilder::tryFold(Operation *op,
+ SmallVectorImpl<Value> &results) {
+ SmallVector<Operation *, 1> generatedConstants;
+ if (failed(mlir::tryFold(context, op, results, generatedConstants)))
+ return failure();
// If we were successful, insert any generated constants.
for (Operation *cst : generatedConstants)
insert(cst);
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db84..bccfa9dd384aa4 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -89,6 +89,16 @@ def __str__(self):
# CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
print(b)
+    # Both operands are constants, so folding should produce exactly one
+    # replacement value, backed by one newly materialized constant.
+    results, generated_constants = b.owner.try_fold()
+    assert len(results) == 1
+    assert len(generated_constants) == 1
+    for r in results:
+        print(r)
+    for g in generated_constants:
+        print(g)
+
a = arith.constant(f64_t, 42.42)
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
More information about the Mlir-commits
mailing list