[Mlir-commits] [mlir] [mlir][python] expose tryFold (PR #90148)

Maksim Levental llvmlistbot at llvm.org
Thu Apr 25 16:23:24 PDT 2024


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/90148

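Expose `tryFold` through the C API (`mlirOperationTryFold`) and the Python bindings (`Operation.try_fold`).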

>From 7bd08944e55f28d3e03699aa5945170b59370172 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 25 Apr 2024 17:15:22 -0500
Subject: [PATCH] [mlir][python] expose tryFold

---
 mlir/include/mlir-c/IR.h                   |  8 +++
 mlir/include/mlir/IR/Builders.h            |  9 ++++
 mlir/lib/Bindings/Python/IRCore.cpp        | 30 ++++++++++--
 mlir/lib/CAPI/IR/IR.cpp                    | 57 +++++++++++++++++++++-
 mlir/lib/IR/Builders.cpp                   | 14 ++++--
 mlir/test/python/dialects/arith_dialect.py |  6 +++
 6 files changed, 116 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..eae1e7252e2887 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -722,6 +722,17 @@ MLIR_CAPI_EXPORTED
 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
                        void *userData, MlirWalkOrder walkOrder);
 
+/// Attempts to fold `mlirOp` without erasing it. On success, sets
+/// `*mlirResults` to a malloc-allocated array of `*numResults` new values for
+/// the operation's results and `*mlirConstants` to a malloc-allocated array of
+/// `*numConstants` constant operations materialized by the fold. The constants
+/// are not inserted into any block; the caller must insert or destroy them and
+/// must release both arrays with `free`. On failure, the output parameters are
+/// not written.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationTryFold(
+    MlirContext mlirContext, MlirOperation mlirOp, MlirValue **mlirResults,
+    size_t *numResults, MlirOperation **mlirConstants, size_t *numConstants);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3beade017d1ab9..1ae1d045d4cf75 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -204,6 +204,13 @@ class Builder {
   MLIRContext *context;
 };
 
+/// Attempts to fold `op` (without erasing it). On success, fills `results`
+/// with the new values and `generatedConstants` with any constant ops created
+/// while folding, which are not inserted into any block. Fails otherwise.
+LogicalResult tryFold(MLIRContext *context, Operation *op,
+                      SmallVectorImpl<Value> &results,
+                      SmallVector<Operation *, 1> &generatedConstants);
+
 /// This class helps build Operations. Operations that are created are
 /// automatically inserted at an insertion point. The builder is copyable.
 class OpBuilder : public Builder {
@@ -562,6 +569,8 @@ class OpBuilder : public Builder {
   /// Attempts to fold the given operation and places new results within
   /// 'results'. Returns success if the operation was folded, failure otherwise.
   /// Note: This function does not erase the operation on a successful fold.
+  /// Note: This function inserts generated constants at the current insertion
+  /// point.
   LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
 
   /// Creates a deep copy of the specified operation, remapping any operands
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..c6fc36fcb6d759 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2989,8 +2989,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
@@ -3038,7 +3037,32 @@ void mlir::python::populateIRCore(py::module &m) {
             return operation.createOpView();
           },
           "Detaches the operation from its parent block.")
-      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+      .def(
+          "try_fold",
+          [](PyOperationBase &self, MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+            // Written by mlirOperationTryFold only when folding succeeds.
+            size_t numResults;
+            size_t numConstants;
+            MlirValue *mlirResults;
+            MlirOperation *mlirConstants;
+            // Not folding raises ValueError with the collected diagnostics.
+            if (mlirLogicalResultIsFailure(mlirOperationTryFold(
+                    context, self.getOperation(), &mlirResults, &numResults,
+                    &mlirConstants, &numConstants)))
+              throw py::value_error(scope.takeMessage());
+            // NOTE(review): returned constants are detached; who erases them?
+            std::vector<MlirValue> result(mlirResults,
+                                          mlirResults + numResults);
+            std::vector<MlirOperation> generatedConstant(
+                mlirConstants, mlirConstants + numConstants);
+            // Both arrays were allocated with malloc by the C API.
+            free(mlirResults);
+            free(mlirConstants);
+            return std::pair(result, generatedConstant);
+          },
+          "context"_a = py::none());
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..bc94c919853d74 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -15,10 +15,10 @@
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -30,7 +30,6 @@
 #include "mlir/Parser/Parser.h"
 #include "llvm/Support/ThreadPool.h"
 
-#include <cstddef>
 #include <memory>
 #include <optional>
 
@@ -731,6 +730,54 @@ void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
   }
 }
 
+/// Attempts to fold `mlirOp` without erasing it. On success, the new values
+/// for the operation's results and any constant operations materialized by the
+/// fold are returned in malloc-allocated arrays that the caller must release
+/// with `free`. The constants are not inserted into any block. On failure, the
+/// output parameters are not written.
+MlirLogicalResult
+mlirOperationTryFold(MlirContext mlirContext, MlirOperation mlirOp,
+                     MlirValue **mlirResults, size_t *numResults,
+                     MlirOperation **mlirConstants, size_t *numConstants) {
+  Operation *op = unwrap(mlirOp);
+  MLIRContext *context = unwrap(mlirContext);
+
+  SmallVector<Value, 1> results;
+  SmallVector<Operation *, 1> generatedConstants;
+  if (failed(tryFold(context, op, results, generatedConstants)))
+    return mlirLogicalResultFailure();
+
+  // Use malloc so that C callers can release the arrays with free.
+  auto *wrappedResults =
+      static_cast<MlirValue *>(malloc(sizeof(MlirValue) * results.size()));
+  auto *wrappedConstants = static_cast<MlirOperation *>(
+      malloc(sizeof(MlirOperation) * generatedConstants.size()));
+
+  // malloc(0) may legitimately return null, so a null pointer is only an
+  // allocation failure when elements were requested.
+  if ((!wrappedResults && !results.empty()) ||
+      (!wrappedConstants && !generatedConstants.empty())) {
+    free(wrappedResults);
+    free(wrappedConstants);
+    // The constants are detached and would otherwise be unreachable.
+    for (Operation *constant : generatedConstants)
+      constant->destroy();
+    return mlirLogicalResultFailure();
+  }
+
+  for (size_t i = 0, e = results.size(); i < e; ++i)
+    wrappedResults[i] = wrap(results[i]);
+  for (size_t i = 0, e = generatedConstants.size(); i < e; ++i)
+    wrappedConstants[i] = wrap(generatedConstants[i]);
+
+  // Publish the outputs only once everything has succeeded.
+  *mlirResults = wrappedResults;
+  *numResults = results.size();
+  *mlirConstants = wrappedConstants;
+  *numConstants = generatedConstants.size();
+  return mlirLogicalResultSuccess();
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 18ca3c332e0204..6e995edf5ebe7d 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -479,8 +479,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
 /// Attempts to fold the given operation and places new results within
 /// 'results'. Returns success if the operation was folded, failure otherwise.
 /// Note: This function does not erase the operation on a successful fold.
-LogicalResult OpBuilder::tryFold(Operation *op,
-                                 SmallVectorImpl<Value> &results) {
+LogicalResult mlir::tryFold(MLIRContext *context, Operation *op,
+                            SmallVectorImpl<Value> &results,
+                            SmallVector<Operation *, 1> &generatedConstants) {
   ResultRange opResults = op->getResults();
 
   results.reserve(opResults.size());
@@ -500,7 +501,6 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
   // A temporary builder used for creating constants during folding.
   OpBuilder cstBuilder(context);
-  SmallVector<Operation *, 1> generatedConstants;
 
   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
@@ -533,6 +533,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
     results.push_back(constOp->getResult(0));
   }
 
+  return success();
+}
+
+LogicalResult OpBuilder::tryFold(Operation *op,
+                                 SmallVectorImpl<Value> &results) {
+  SmallVector<Operation *, 1> generatedConstants;
+  if (failed(mlir::tryFold(context, op, results, generatedConstants)))
+    return failure();
   // If we were successful, insert any generated constants.
   for (Operation *cst : generatedConstants)
     insert(cst);
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db84..bccfa9dd384aa4 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -89,6 +89,12 @@ def __str__(self):
             # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
             print(b)
 
+            results, generated_constants = b.owner.try_fold()
+            for r in results:
+                print(r)
+            for g in generated_constants:
+                print(g)
+
             a = arith.constant(f64_t, 42.42)
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)



More information about the Mlir-commits mailing list