[Mlir-commits] [mlir] [mlir][python] expose tryFold (PR #90148)

Maksim Levental llvmlistbot at llvm.org
Thu Apr 25 16:27:16 PDT 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/90148

>From 8842be09d1f1d344548460bee0d553aee905833f Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 25 Apr 2024 17:15:22 -0500
Subject: [PATCH] [mlir][python] expose tryFold

---
 mlir/include/mlir-c/IR.h                   |  8 +++
 mlir/include/mlir/IR/Builders.h            |  9 ++++
 mlir/lib/Bindings/Python/IRCore.cpp        | 27 +++++++++-
 mlir/lib/CAPI/IR/IR.cpp                    | 57 +++++++++++++++++++++-
 mlir/lib/IR/Builders.cpp                   | 17 +++++--
 mlir/test/python/dialects/arith_dialect.py |  6 +++
 6 files changed, 118 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 32abacf353133e..3c6297f26771de 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -730,6 +730,20 @@ MLIR_CAPI_EXPORTED
 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
                        void *userData, MlirWalkOrder walkOrder);
 
+/// Attempts to fold `mlirOp`; the operation itself is never erased.
+/// Constants materialized by the fold are created in `mlirContext`.
+/// On success, `*mlirResults` receives a `malloc`-allocated array of
+/// `*numResults` values that replace the operation's results (zero values if
+/// the fold happened in-place), and `*mlirConstants` receives a
+/// `malloc`-allocated array of `*numConstants` newly created constant
+/// operations. These constants are not inserted into any block; the caller is
+/// responsible for inserting or destroying them. The caller owns both arrays
+/// and must release them with `free` (either may be null when its count is
+/// zero). On failure, the out-parameters are left unmodified.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationTryFold(
+    MlirContext mlirContext, MlirOperation mlirOp, MlirValue **mlirResults,
+    size_t *numResults, MlirOperation **mlirConstants, size_t *numConstants);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0d5fa719d0dee2..0adad2501d3737 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -204,6 +204,18 @@ class Builder {
   MLIRContext *context;
 };
 
+/// Attempts to fold the given operation and places new results within
+/// `results`, which must be empty on entry. If the fold was in-place,
+/// `results` will not be filled. Returns success if the operation was folded,
+/// failure otherwise.
+/// Constants materialized by the fold are created in `context` and
+/// returned in `generatedConstants`. They are not inserted into any block:
+/// the caller must insert them (as `OpBuilder::tryFold` does) or destroy them.
+/// Note: This function does not erase the operation on a successful fold.
+LogicalResult tryFold(MLIRContext *context, Operation *op,
+                      SmallVectorImpl<Value> &results,
+                      SmallVector<Operation *, 1> &generatedConstants);
+
 /// This class helps build Operations. Operations that are created are
 /// automatically inserted at an insertion point. The builder is copyable.
 class OpBuilder : public Builder {
@@ -570,6 +582,8 @@ class OpBuilder : public Builder {
   /// `results`. Returns success if the operation was folded, failure otherwise.
   /// If the fold was in-place, `results` will not be filled.
   /// Note: This function does not erase the operation on a successful fold.
+  /// Note: This function inserts generated constants at the current insertion
+  /// point.
   LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
 
   /// Creates a deep copy of the specified operation, remapping any operands
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 01678a9719f90f..3ff7a0fc4abfb7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3081,7 +3081,40 @@ void mlir::python::populateIRCore(py::module &m) {
           "Detaches the operation from its parent block.")
       .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
       .def("walk", &PyOperationBase::walk, py::arg("callback"),
-           py::arg("walk_order") = MlirWalkPostOrder);
+           py::arg("walk_order") = MlirWalkPostOrder)
+      .def(
+          "try_fold",
+          [](PyOperationBase &self, MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+
+            size_t numResults = 0;
+            size_t numConstants = 0;
+            MlirValue *mlirResults = nullptr;
+            MlirOperation *mlirConstants = nullptr;
+
+            if (mlirLogicalResultIsFailure(mlirOperationTryFold(
+                    context, self.getOperation(), &mlirResults, &numResults,
+                    &mlirConstants, &numConstants)))
+              throw py::value_error(scope.takeMessage());
+
+            // The C API hands back malloc-allocated arrays; take ownership
+            // immediately so they are released even if copying them throws.
+            std::unique_ptr<MlirValue, decltype(&free)> resultsOwner(
+                mlirResults, &free);
+            std::unique_ptr<MlirOperation, decltype(&free)> constantsOwner(
+                mlirConstants, &free);
+
+            std::vector<MlirValue> results(mlirResults,
+                                           mlirResults + numResults);
+            std::vector<MlirOperation> generatedConstants(
+                mlirConstants, mlirConstants + numConstants);
+            // NOTE(review): the generated constants are not inserted into any
+            // block and nothing here takes ownership of them; confirm who is
+            // expected to insert or destroy them.
+            return std::make_pair(std::move(results),
+                                  std::move(generatedConstants));
+          },
+          "context"_a = py::none());
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a72cd247e73f60..1492b88f476131 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -15,10 +15,10 @@
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -30,7 +30,6 @@
 #include "mlir/Parser/Parser.h"
 #include "llvm/Support/ThreadPool.h"
 
-#include <cstddef>
 #include <memory>
 #include <optional>
 
@@ -748,6 +747,55 @@ void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
   }
 }
 
+/// Attempts to fold `mlirOp` via `mlir::tryFold`. On success, returns the
+/// folded values and the constants materialized by the fold in
+/// `malloc`-allocated arrays that the caller must release with `free`; the
+/// constants are not inserted into any block. The out-parameters are written
+/// only on success.
+MlirLogicalResult
+mlirOperationTryFold(MlirContext mlirContext, MlirOperation mlirOp,
+                     MlirValue **mlirResults, size_t *numResults,
+                     MlirOperation **mlirConstants, size_t *numConstants) {
+  Operation *op = unwrap(mlirOp);
+  MLIRContext *context = unwrap(mlirContext);
+
+  SmallVector<Value, 1> results;
+  SmallVector<Operation *, 1> generatedConstants;
+
+  if (failed(tryFold(context, op, results, generatedConstants)))
+    return mlirLogicalResultFailure();
+
+  // Allocate both arrays before writing any out-parameter so that an
+  // allocation failure leaves the caller's state untouched. `malloc(0)` may
+  // legitimately return null, so null only signals failure for a non-empty
+  // array.
+  auto *resultsArray =
+      static_cast<MlirValue *>(malloc(sizeof(MlirValue) * results.size()));
+  auto *constantsArray = static_cast<MlirOperation *>(
+      malloc(sizeof(MlirOperation) * generatedConstants.size()));
+  if ((!resultsArray && !results.empty()) ||
+      (!constantsArray && !generatedConstants.empty())) {
+    free(resultsArray);
+    free(constantsArray);
+    // The materialized constants are detached; destroy them so they do not
+    // leak when the fold is reported as failed.
+    for (Operation *cst : generatedConstants)
+      cst->erase();
+    return mlirLogicalResultFailure();
+  }
+
+  for (size_t i = 0, e = results.size(); i < e; ++i)
+    resultsArray[i] = wrap(results[i]);
+  for (size_t i = 0, e = generatedConstants.size(); i < e; ++i)
+    constantsArray[i] = wrap(generatedConstants[i]);
+
+  *mlirResults = resultsArray;
+  *numResults = results.size();
+  *mlirConstants = constantsArray;
+  *numConstants = generatedConstants.size();
+  return mlirLogicalResultSuccess();
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d49f69a7b7ae6b..4f753d89c0ae7c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -476,8 +476,12 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
   return create(state);
 }
 
-LogicalResult OpBuilder::tryFold(Operation *op,
-                                 SmallVectorImpl<Value> &results) {
+/// Attempts to fold the given operation and places new results within
+/// 'results'. Returns success if the operation was folded, failure otherwise.
+/// Note: This function does not erase the operation on a successful fold.
+LogicalResult mlir::tryFold(MLIRContext *context, Operation *op,
+                            SmallVectorImpl<Value> &results,
+                            SmallVector<Operation *, 1> &generatedConstants) {
   assert(results.empty() && "expected empty results");
   ResultRange opResults = op->getResults();
 
@@ -502,7 +506,6 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
   // A temporary builder used for creating constants during folding.
   OpBuilder cstBuilder(context);
-  SmallVector<Operation *, 1> generatedConstants;
 
   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
@@ -535,6 +538,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
     results.push_back(constOp->getResult(0));
   }
 
+  return success();
+}
+
+LogicalResult OpBuilder::tryFold(Operation *op,
+                                 SmallVectorImpl<Value> &results) {
+  SmallVector<Operation *, 1> generatedConstants;
+  if (failed(mlir::tryFold(context, op, results, generatedConstants)))
+    return failure();
   // If we were successful, insert any generated constants.
   for (Operation *cst : generatedConstants)
     insert(cst);
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db84..bccfa9dd384aa4 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -89,6 +89,12 @@ def __str__(self):
             # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
             print(b)
 
+            results, generated_constants = b.owner.try_fold()
+            for r in results:
+                print(r)
+            for g in generated_constants:
+                print(g)
+
             a = arith.constant(f64_t, 42.42)
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)



More information about the Mlir-commits mailing list