[Mlir-commits] [mlir] [mlir] Add RewriterBase to the C API (PR #98962)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 15 14:04:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Fehr Mathieu (math-fehr)

<details>
<summary>Changes</summary>

This exposes most of the `RewriterBase` methods to the C API.
This makes it possible to manipulate both an `IRRewriter` and a
`PatternRewriter` from C. An `IRRewriter` can be created through the C API,
whereas a `PatternRewriter` cannot.

The methods that are not exposed are those taking a `Block::iterator` or a
`Region::iterator` as a parameter, since (as far as I know) these iterator
types are not exposed by the C API yet.

The Python bindings for these methods and classes are not implemented.

---

Patch is 49.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98962.diff


7 Files Affected:

- (modified) mlir/include/mlir-c/Rewrite.h (+257) 
- (added) mlir/include/mlir/CAPI/Rewrite.h (+23) 
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+249) 
- (modified) mlir/test/CAPI/CMakeLists.txt (+9) 
- (added) mlir/test/CAPI/rewrite.c (+551) 
- (modified) mlir/test/CMakeLists.txt (+1) 
- (modified) mlir/test/lit.cfg.py (+1) 


``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index bed93045f4b50..09f8a72a0c599 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -33,10 +33,263 @@ extern "C" {
   };                                                                           \
   typedef struct name name
 
+/// Opaque handle to an `mlir::RewriterBase`, such as an `IRRewriter` or a
+/// `PatternRewriter`.
+DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+/// Get the MLIR context referenced by the rewriter.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Insertion point methods
+
+// Methods taking a `Block::iterator` or a `Region::iterator` are not exposed,
+// as these iterator types are not available in the C API yet. For the same
+// reason, methods working on `InsertPoint` directly are not exposed either.
+
+/// Reset the insertion point to no location. Creating an operation without a
+/// set insertion point is an error, but this can still be useful when the
+/// current insertion point a builder refers to is being removed.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter);
+
+/// Sets the insertion point to the specified operation, which will cause
+/// subsequent insertions to go right before it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+                                        MlirOperation op);
+
+/// Sets the insertion point to the node after the specified operation, which
+/// will cause subsequent insertions to go right after it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+                                       MlirOperation op);
+
+/// Sets the insertion point to the node after the specified value. If `value`
+/// has a defining operation, sets the insertion point to the node after that
+/// defining operation, so subsequent insertions go right after it. Otherwise,
+/// `value` is a block argument, and the insertion point is set to the start
+/// of its block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+                                            MlirValue value);
+
+/// Sets the insertion point to the start of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+                                         MlirBlock block);
+
+/// Sets the insertion point to the end of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+                                       MlirBlock block);
+
+/// Return the block the current insertion point belongs to. Note that the
+/// insertion point is not necessarily the end of the block.
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
+
+/// Returns the current block of the rewriter (forwards to
+/// `OpBuilder::getBlock`).
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+/// Add a new block with `nArgTypes` arguments of types `argTypes`, place it
+/// before `insertBefore`, set the insertion point to the end of it, and
+/// return it. `locations` contains the locations of the inserted arguments
+/// and should have `nArgTypes` elements, like `argTypes`.
+MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore(
+    MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes,
+    MlirType const *argTypes, MlirLocation const *locations);
+
+/// Insert the given operation at the current insertion point and return it.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Clone `op` using the rewriter and return the clone (forwards to
+/// `OpBuilder::clone`). The overload taking an `IRMapping` is not exposed, as
+/// `IRMapping` is not yet available in the C API.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Clone `op` without the contents of its regions and return the clone
+/// (forwards to `OpBuilder::cloneWithoutRegions`). The overload taking an
+/// `IRMapping` is not exposed, as `IRMapping` is not yet available in the
+/// C API.
+MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions(
+    MlirRewriterBase rewriter, MlirOperation op);
+
+/// Clone the blocks that belong to `region` and insert them before the block
+/// `before` (forwards to `OpBuilder::cloneRegionBefore`). The overloads taking
+/// an `IRMapping` or a `Region::iterator` are not exposed, as neither is
+/// available in the C API yet.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+                                  MlirBlock before);
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+/// Move the blocks that belong to `region` before the block `before`, which
+/// must belong to a different region. The caller is responsible for creating
+/// or updating the operation transferring flow of control to the region and
+/// passing it the correct block arguments.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+                                   MlirBlock before);
+
+/// Replace the results of the given (original) operation with the specified
+/// list of `nValues` values (replacements). The result types of the given op
+/// and the replacements must match. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op,
+                                    intptr_t nValues, MlirValue const *values);
+
+/// Replace the results of the given (original) operation with the results of
+/// the specified new op (replacement). The result types of the two ops must
+/// match. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                       MlirOperation op, MlirOperation newOp);
+
+/// Erases an operation that is known to have no uses.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter,
+                                                MlirOperation op);
+
+/// Erases a block along with all operations inside it.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter,
+                                                   MlirBlock block);
+
+/// Inline the operations of block `source` before the operation `op`. The
+/// source block will be deleted and must have no uses. The `nArgValues`
+/// values in `argValues` replace the block arguments of `source`.
+///
+/// The source block must have no successors. Otherwise, the resulting IR
+/// would have unreachable operations.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source,
+                                  MlirOperation op, intptr_t nArgValues,
+                                  MlirValue const *argValues);
+
+/// Inline the operations of block `source` into the end of block `dest`. The
+/// source block will be deleted and must have no uses. The `nArgValues`
+/// values in `argValues` replace the block arguments of `source`.
+///
+/// The dest block must have no successors. Otherwise, the resulting IR would
+/// have unreachable operations.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter,
+                                                    MlirBlock source,
+                                                    MlirBlock dest,
+                                                    intptr_t nArgValues,
+                                                    MlirValue const *argValues);
+
+// splitBlock is not exposed, as Block::iterator is not available in the C API
+// yet.
+
+/// Unlink `op` from its current block and insert it right before
+/// `existingOp`, which may be in the same or another block in the same
+/// function.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter,
+                                                     MlirOperation op,
+                                                     MlirOperation existingOp);
+
+/// Unlink `op` from its current block and insert it right after
+/// `existingOp`, which may be in the same or another block in the same
+/// function.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter,
+                                                    MlirOperation op,
+                                                    MlirOperation existingOp);
+
+/// Unlink `block` and insert it right before `existingBlock`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+                                MlirBlock existingBlock);
+
+/// This method is used to notify the rewriter that an in-place operation
+/// modification is about to happen. A call to this function *must* be
+/// followed by a call to either `mlirRewriterBaseFinalizeOpModification` or
+/// `mlirRewriterBaseCancelOpModification`. This is a minor efficiency win (it
+/// avoids creating a new operation and removing the old one) but also often
+/// allows simpler code in the client.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+                                    MlirOperation op);
+
+/// This method is used to signal the end of an in-place modification of the
+/// given operation. This can only be called on operations that were provided
+/// to a call to `mlirRewriterBaseStartOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+                                       MlirOperation op);
+
+/// This method cancels a pending in-place modification. This can only be
+/// called on operations that were provided to a call to
+/// `mlirRewriterBaseStartOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+                                     MlirOperation op);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from,
+                                   MlirValue to);
+
+/// Find uses of each value in `from` and replace them with the value at the
+/// same position in `to`; both arrays hold `nValues` elements. Also notify the
+/// listener about every in-place op modification (for every use that was
+/// replaced).
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith(
+    MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from,
+    MlirValue const *to);
+
+/// Find uses of the results of `from` and replace them with the `nTo` values
+/// in `to`. Also notify the listener about every in-place op modification
+/// (for every use that was replaced) and that the `from` operation is about
+/// to be replaced.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+                                               MlirOperation from, intptr_t nTo,
+                                               MlirValue const *to);
+
+/// Find uses of the results of `from` and replace them with the results of
+/// `to`. Also notify the listener about every in-place op modification (for
+/// every use that was replaced) and that the `from` operation is about to be
+/// replaced.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation(
+    MlirRewriterBase rewriter, MlirOperation from, MlirOperation to);
+
+/// Find uses of the results of `op` within `block` and replace them with the
+/// `nNewValues` values in `newValues`. Also notify the listener about every
+/// in-place op modification (for every use that was replaced). Unlike the C++
+/// API, this function does not report whether all uses were replaced.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock(
+    MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues,
+    MlirValue const *newValues, MlirBlock block);
+
+/// Find uses of `from` and replace them with `to` except if the user is
+/// `exceptedUser`. Also notify the listener about every in-place op
+/// modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from,
+                                     MlirValue to, MlirOperation exceptedUser);
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Create an IRRewriter and transfer ownership to the caller, who must
+/// release it with `mlirIRRewriterDestroy`.
+MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context);
+
+/// Create an IRRewriter and transfer ownership to the caller, who must
+/// release it with `mlirIRRewriterDestroy`. Additionally, set the insertion
+/// point before the operation `op`.
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirIRRewriterCreateFromOp(MlirOperation op);
+
+/// Takes an IRRewriter owned by the caller and destroys it. Only pass a
+/// rewriter obtained from `mlirIRRewriterCreate` or
+/// `mlirIRRewriterCreateFromOp`; other rewriters (e.g. a `PatternRewriter`)
+/// must not be passed to this function.
+MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
 MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
 mlirFreezeRewritePattern(MlirRewritePatternSet op);
 
@@ -47,6 +300,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
 
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
new file mode 100644
index 0000000000000..0e6dcb2477626
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -0,0 +1,23 @@
+//===- Rewrite.h - C API Utils for rewrite patterns -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// rewrite patterns. This file should not be included from C++ code other than
+// C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_REWRITE_H
+#define MLIR_CAPI_REWRITE_H
+
+// MlirRewriterBase is declared in the public C header; include it so this
+// header does not depend on being included after it.
+#include "mlir-c/Rewrite.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+
+DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
+
+#endif // MLIR_CAPI_REWRITE_H
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 0de1958398f63..7f3c833df0910 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -7,15 +7,260 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/Rewrite.h"
+
 #include "mlir-c/Transforms.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Rewrite.h"
 #include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+// Returns the context the wrapped rewriter was built for.
+MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
+  RewriterBase *rw = unwrap(rewriter);
+  return wrap(rw->getContext());
+}
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+// Leaves the rewriter without an insertion point.
+void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
+  RewriterBase *rw = unwrap(rewriter);
+  rw->clearInsertionPoint();
+}
+
+// The C++ overload taking an Operation * inserts right before that operation.
+void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+                                             MlirOperation op) {
+  Operation *anchor = unwrap(op);
+  unwrap(rewriter)->setInsertionPoint(anchor);
+}
+
+// Subsequent insertions go right after `op`.
+void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+                                            MlirOperation op) {
+  Operation *anchor = unwrap(op);
+  unwrap(rewriter)->setInsertionPointAfter(anchor);
+}
+
+// Positions the rewriter after the producer of `value` (or at the start of the
+// owning block when `value` is a block argument).
+void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+                                                 MlirValue value) {
+  Value anchor = unwrap(value);
+  unwrap(rewriter)->setInsertionPointAfterValue(anchor);
+}
+
+// Subsequent insertions go at the front of `block`.
+void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+                                              MlirBlock block) {
+  Block *target = unwrap(block);
+  unwrap(rewriter)->setInsertionPointToStart(target);
+}
+
+// Subsequent insertions go at the back of `block`.
+void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+                                            MlirBlock block) {
+  Block *target = unwrap(block);
+  unwrap(rewriter)->setInsertionPointToEnd(target);
+}
+
+// Reports the block that currently receives inserted operations.
+MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
+  Block *current = unwrap(rewriter)->getInsertionBlock();
+  return wrap(current);
+}
+
+// Thin wrapper over OpBuilder::getBlock.
+MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
+  Block *current = unwrap(rewriter)->getBlock();
+  return wrap(current);
+}
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+// Both C arrays are read with `nArgTypes` elements: one type and one location
+// per new block argument.
+MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
+                                            MlirBlock insertBefore,
+                                            intptr_t nArgTypes,
+                                            MlirType const *argTypes,
+                                            MlirLocation const *locations) {
+  SmallVector<Type, 4> typeStorage;
+  SmallVector<Location, 4> locStorage;
+  ArrayRef<Type> types = unwrapList(nArgTypes, argTypes, typeStorage);
+  ArrayRef<Location> locs = unwrapList(nArgTypes, locations, locStorage);
+  Block *created =
+      unwrap(rewriter)->createBlock(unwrap(insertBefore), types, locs);
+  return wrap(created);
+}
+
+// Hands `op` to OpBuilder::insert and returns what it yields.
+MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
+                                     MlirOperation op) {
+  Operation *inserted = unwrap(rewriter)->insert(unwrap(op));
+  return wrap(inserted);
+}
+
+// Other methods of OpBuilder
+
+// Clones `op` through OpBuilder::clone. The overload taking an IRMapping is
+// not exposed, as IRMapping is not yet available in the C API.
+MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
+                                    MlirOperation op) {
+  return wrap(unwrap(rewriter)->clone(*unwrap(op)));
+}
+
+// Clones `op` without its regions' contents through
+// OpBuilder::cloneWithoutRegions. The overload taking an IRMapping is not
+// exposed, as IRMapping is not yet available in the C API.
+MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
+                                                  MlirOperation op) {
+  return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
+}
+
+// Clones the blocks of `region` before the block `before` through
+// OpBuilder::cloneRegionBefore. The overloads taking an IRMapping or a
+// Region::iterator are not exposed, as neither type is available in the C API
+// yet.
+void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
+                                       MlirRegion region, MlirBlock before) {
+  unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+// Only the Block * overload is wrapped: Region::iterator has no C API
+// counterpart yet.
+void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
+                                        MlirRegion region, MlirBlock before) {
+  Region &source = *unwrap(region);
+  unwrap(rewriter)->inlineRegionBefore(source, unwrap(before));
+}
+
+// Converts the `nValues` C values into a ValueRange and forwards to
+// RewriterBase::replaceOp.
+void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
+                                         MlirOperation op, intptr_t nValues,
+                                         MlirValue const *values) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nValues, values, storage);
+  RewriterBase *rw = unwrap(rewriter);
+  rw->replaceOp(unwrap(op), replacements);
+}
+
+void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                            MlirOperation op,
+                                            MlirOperation ne...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/98962


More information about the Mlir-commits mailing list