[Mlir-commits] [mlir] 5f8c46b - [mlir] Add RewriterBase to the C API (#98962)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 12:37:16 PDT 2024


Author: Fehr Mathieu
Date: 2024-07-16T20:37:11+01:00
New Revision: 5f8c46b88799a710f98c00d377d7edc34096f85d

URL: https://github.com/llvm/llvm-project/commit/5f8c46b88799a710f98c00d377d7edc34096f85d
DIFF: https://github.com/llvm/llvm-project/commit/5f8c46b88799a710f98c00d377d7edc34096f85d.diff

LOG: [mlir] Add RewriterBase to the C API (#98962)

This exposes most of the `RewriterBase` methods to the C API, which allows
manipulating both the `IRRewriter` and the `PatternRewriter`. An `IRRewriter`
can be created from the C API, while a `PatternRewriter` cannot.

The methods left out are those taking `Block::iterator` or `Region::iterator`
parameters since, as far as I know, these iterators are not yet exposed by the
C API.

The Python bindings for these methods and classes are not implemented.

Added: 
    mlir/include/mlir/CAPI/Rewrite.h
    mlir/test/CAPI/rewrite.c

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/lib/CAPI/Transforms/Rewrite.cpp
    mlir/test/CAPI/CMakeLists.txt
    mlir/test/CMakeLists.txt
    mlir/test/lit.cfg.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index bed93045f4b50..d8f2275b61532 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -33,10 +33,266 @@ extern "C" {
   };                                                                           \
   typedef struct name name
 
+/// Opaque handle to an mlir::RewriterBase. It may wrap an IRRewriter created
+/// through this API or a PatternRewriter; a PatternRewriter cannot be created
+/// from the C API.
+DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+/// Get the MLIR context referenced by the rewriter.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+// These do not include functions using Block::iterator or Region::iterator, as
+// they are not exposed by the C API yet. Similarly for methods using
+// `InsertPoint` directly.
+
+/// Reset the insertion point to no location. Creating an operation without a
+/// set insertion point is an error, but this can still be useful when the
+/// current insertion point a builder refers to is being removed.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter);
+
+/// Sets the insertion point to the specified operation, which will cause
+/// subsequent insertions to go right before it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+                                        MlirOperation op);
+
+/// Sets the insertion point to the node after the specified operation, which
+/// will cause subsequent insertions to go right after it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+                                       MlirOperation op);
+
+/// Sets the insertion point to the node after the specified value. If value
+/// has a defining operation, sets the insertion point to the node after such
+/// defining operation. This will cause subsequent insertions to go right
+/// after it. Otherwise, value is a BlockArgument. Sets the insertion point to
+/// the start of its block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+                                            MlirValue value);
+
+/// Sets the insertion point to the start of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+                                         MlirBlock block);
+
+/// Sets the insertion point to the end of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+                                       MlirBlock block);
+
+/// Return the block the current insertion point belongs to. Note that the
+/// insertion point is not necessarily the end of the block.
+/// NOTE(review): presumably a null block after
+/// mlirRewriterBaseClearInsertionPoint — confirm before relying on it.
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
+
+/// Returns the current block of the rewriter.
+/// NOTE(review): this looks like it returns the same block as
+/// mlirRewriterBaseGetInsertionBlock — confirm whether the two can differ.
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+// These functions do not take an IRMapping, as it is not yet exposed by the
+// C API.
+
+/// Add a new block with `nArgTypes` arguments, whose types are given by
+/// `argTypes`, and set the insertion point to the end of it. The block is
+/// placed before `insertBefore`. `locations` contains the locations of the
+/// inserted arguments and must also have `nArgTypes` elements.
+MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore(
+    MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes,
+    MlirType const *argTypes, MlirLocation const *locations);
+
+/// Insert the given operation at the current insertion point and return it.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Creates a deep copy of the specified operation.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Creates a deep copy of the specified operation but keeps the regions of
+/// the copy empty.
+MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions(
+    MlirRewriterBase rewriter, MlirOperation op);
+
+/// Clone the blocks that belong to `region` and insert the clones before the
+/// block `before`, in the region that contains `before`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+                                  MlirBlock before);
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+/// Move the blocks that belong to `region` before the block `before`, which
+/// lives in another region. The two regions must be different. The caller
+/// is responsible for creating or updating the operation transferring flow
+/// of control to the region and passing it the correct block arguments.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+                                   MlirBlock before);
+
+/// Replace the results of the given (original) operation with the `nValues`
+/// values in `values` (replacements). The result types of the given op and
+/// the replacements must match. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op,
+                                    intptr_t nValues, MlirValue const *values);
+
+/// Replace the results of the given (original) operation with the results of
+/// the new op `newOp` (replacement). The result types of the two ops must
+/// match. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                       MlirOperation op, MlirOperation newOp);
+
+/// Erases an operation that is known to have no uses.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter,
+                                                MlirOperation op);
+
+/// Erases a block along with all operations inside it.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter,
+                                                   MlirBlock block);
+
+/// Inline the operations of block `source` before the operation `op`. The
+/// source block will be deleted and must have no uses. The `nArgValues`
+/// values in `argValues` replace the block arguments of `source`.
+///
+/// The source block must have no successors. Otherwise, the resulting IR
+/// would have unreachable operations.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source,
+                                  MlirOperation op, intptr_t nArgValues,
+                                  MlirValue const *argValues);
+
+/// Inline the operations of block `source` into the end of block `dest`. The
+/// source block will be deleted and must have no uses. The `nArgValues`
+/// values in `argValues` replace the block arguments of `source`.
+///
+/// The dest block must have no successors. Otherwise, the resulting IR would
+/// have unreachable operations.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter,
+                                                    MlirBlock source,
+                                                    MlirBlock dest,
+                                                    intptr_t nArgValues,
+                                                    MlirValue const *argValues);
+
+/// Unlink this operation from its current block and insert it right before
+/// `existingOp` which may be in the same or another block in the same
+/// function.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter,
+                                                     MlirOperation op,
+                                                     MlirOperation existingOp);
+
+/// Unlink this operation from its current block and insert it right after
+/// `existingOp` which may be in the same or another block in the same
+/// function.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter,
+                                                    MlirOperation op,
+                                                    MlirOperation existingOp);
+
+/// Unlink this block and insert it right before `existingBlock`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+                                MlirBlock existingBlock);
+
+/// This method is used to notify the rewriter that an in-place operation
+/// modification is about to happen. A call to this function *must* be
+/// followed by a call to either `mlirRewriterBaseFinalizeOpModification` or
+/// `mlirRewriterBaseCancelOpModification`. This is a minor efficiency win (it
+/// avoids creating a new operation and removing the old one) but also often
+/// allows simpler code in the client.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+                                    MlirOperation op);
+
+/// This method is used to signal the end of an in-place modification of the
+/// given operation. This can only be called on operations that were provided
+/// to a call to `mlirRewriterBaseStartOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+                                       MlirOperation op);
+
+/// This method cancels a pending in-place modification. This can only be
+/// called on operations that were provided to a call to
+/// `mlirRewriterBaseStartOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+                                     MlirOperation op);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from,
+                                   MlirValue to);
+
+/// Find uses of each value in `from` and replace them with the corresponding
+/// value in `to`. Both arrays must contain `nValues` elements. Also notify
+/// the listener about every in-place op modification (for every use that was
+/// replaced).
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith(
+    MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from,
+    MlirValue const *to);
+
+/// Find uses of the results of `from` and replace them with the `nTo` values
+/// in `to`. Also notify the listener about every in-place op modification
+/// (for every use that was replaced) and that the `from` operation is about
+/// to be replaced.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+                                               MlirOperation from, intptr_t nTo,
+                                               MlirValue const *to);
+
+/// Find uses of the results of `from` and replace them with the results of
+/// `to`. Also notify the listener about every in-place op modification (for
+/// every use that was replaced) and that the `from` operation is about to be
+/// replaced.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation(
+    MlirRewriterBase rewriter, MlirOperation from, MlirOperation to);
+
+/// Find uses of the results of `op` within `block` and replace them with the
+/// `nNewValues` values in `newValues`. Also notify the listener about every
+/// in-place op modification (for every use that was replaced). Unlike the
+/// C++ API, this function does not report whether all uses were replaced.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock(
+    MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues,
+    MlirValue const *newValues, MlirBlock block);
+
+/// Find uses of `from` and replace them with `to` except if the user is
+/// `exceptedUser`. Also notify the listener about every in-place op
+/// modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from,
+                                     MlirValue to, MlirOperation exceptedUser);
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Create an IRRewriter and transfer ownership to the caller. The returned
+/// handle must be released with `mlirIRRewriterDestroy`.
+MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context);
+
+/// Create an IRRewriter and transfer ownership to the caller. Additionally
+/// set the insertion point before the operation. The returned handle must be
+/// released with `mlirIRRewriterDestroy`.
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirIRRewriterCreateFromOp(MlirOperation op);
+
+/// Takes an IRRewriter owned by the caller and destroys it. It is the
+/// responsibility of the user to only pass an IRRewriter class, i.e. a handle
+/// obtained from `mlirIRRewriterCreate` or `mlirIRRewriterCreateFromOp`: the
+/// handle is downcast to IRRewriter before deletion, so passing any other
+/// kind of rewriter (such as a PatternRewriter) is undefined behavior.
+MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
 MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
 mlirFreezeRewritePattern(MlirRewritePatternSet op);
 
@@ -47,6 +303,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
 

diff  --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
new file mode 100644
index 0000000000000..0e6dcb2477626
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -0,0 +1,23 @@
+//===- Rewrite.h - C API Utils for Core MLIR classes ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// rewrite patterns. This file should not be included from C++ code other than
+// C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_REWRITE_H
+#define MLIR_CAPI_REWRITE_H
+
+// Included so that this header is self-contained: MlirRewriterBase is
+// declared by the public C API header.
+#include "mlir-c/Rewrite.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+
+// Generates the wrap()/unwrap() conversions between the opaque
+// MlirRewriterBase handle and mlir::RewriterBase *.
+DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase);
+
+#endif // MLIR_CAPI_REWRITE_H

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 0de1958398f63..379f09cf5cc26 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -7,15 +7,254 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/Rewrite.h"
+
 #include "mlir-c/Transforms.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Rewrite.h"
 #include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
+  MLIRContext *context = unwrap(rewriter)->getContext();
+  return wrap(context);
+}
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.clearInsertionPoint();
+}
+
+void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+                                             MlirOperation op) {
+  // The single-operation overload of setInsertionPoint inserts before `op`.
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.setInsertionPoint(unwrap(op));
+}
+
+void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+                                            MlirOperation op) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.setInsertionPointAfter(unwrap(op));
+}
+
+void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+                                                 MlirValue value) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.setInsertionPointAfterValue(unwrap(value));
+}
+
+void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+                                              MlirBlock block) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.setInsertionPointToStart(unwrap(block));
+}
+
+void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+                                            MlirBlock block) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.setInsertionPointToEnd(unwrap(block));
+}
+
+MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
+  Block *insertionBlock = unwrap(rewriter)->getInsertionBlock();
+  return wrap(insertionBlock);
+}
+
+MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
+  Block *current = unwrap(rewriter)->getBlock();
+  return wrap(current);
+}
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
+                                            MlirBlock insertBefore,
+                                            intptr_t nArgTypes,
+                                            MlirType const *argTypes,
+                                            MlirLocation const *locations) {
+  // Both C arrays are read with the same length, `nArgTypes`.
+  SmallVector<Type, 4> typeStorage;
+  SmallVector<Location, 4> locStorage;
+  ArrayRef<Type> types = unwrapList(nArgTypes, argTypes, typeStorage);
+  ArrayRef<Location> locs = unwrapList(nArgTypes, locations, locStorage);
+  Block *created =
+      unwrap(rewriter)->createBlock(unwrap(insertBefore), types, locs);
+  return wrap(created);
+}
+
+MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
+                                     MlirOperation op) {
+  Operation *inserted = unwrap(rewriter)->insert(unwrap(op));
+  return wrap(inserted);
+}
+
+// Cloning helpers inherited from OpBuilder.
+
+MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
+                                    MlirOperation op) {
+  Operation &original = *unwrap(op);
+  return wrap(unwrap(rewriter)->clone(original));
+}
+
+MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
+                                                  MlirOperation op) {
+  Operation &original = *unwrap(op);
+  return wrap(unwrap(rewriter)->cloneWithoutRegions(original));
+}
+
+void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
+                                       MlirRegion region, MlirBlock before) {
+  Region &source = *unwrap(region);
+  unwrap(rewriter)->cloneRegionBefore(source, unwrap(before));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
+                                        MlirRegion region, MlirBlock before) {
+  Region &moved = *unwrap(region);
+  unwrap(rewriter)->inlineRegionBefore(moved, unwrap(before));
+}
+
+void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
+                                         MlirOperation op, intptr_t nValues,
+                                         MlirValue const *values) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nValues, values, storage);
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.replaceOp(unwrap(op), replacements);
+}
+
+void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                            MlirOperation op,
+                                            MlirOperation newOp) {
+  Operation *replacement = unwrap(newOp);
+  unwrap(rewriter)->replaceOp(unwrap(op), replacement);
+}
+
+void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.eraseOp(unwrap(op));
+}
+
+void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.eraseBlock(unwrap(block));
+}
+
+void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
+                                       MlirBlock source, MlirOperation op,
+                                       intptr_t nArgValues,
+                                       MlirValue const *argValues) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nArgValues, argValues, storage);
+  unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
+                                      replacements);
+}
+
+void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
+                                 MlirBlock dest, intptr_t nArgValues,
+                                 MlirValue const *argValues) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nArgValues, argValues, storage);
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.mergeBlocks(unwrap(source), unwrap(dest), replacements);
+}
+
+void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
+                                  MlirOperation existingOp) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.moveOpBefore(unwrap(op), unwrap(existingOp));
+}
+
+void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
+                                 MlirOperation existingOp) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.moveOpAfter(unwrap(op), unwrap(existingOp));
+}
+
+void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+                                     MlirBlock existingBlock) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.moveBlockBefore(unwrap(block), unwrap(existingBlock));
+}
+
+void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+                                         MlirOperation op) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.startOpModification(unwrap(op));
+}
+
+void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+                                            MlirOperation op) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.finalizeOpModification(unwrap(op));
+}
+
+void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+                                          MlirOperation op) {
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.cancelOpModification(unwrap(op));
+}
+
+void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
+                                        MlirValue from, MlirValue to) {
+  Value oldValue = unwrap(from);
+  Value newValue = unwrap(to);
+  unwrap(rewriter)->replaceAllUsesWith(oldValue, newValue);
+}
+
+void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
+                                                  intptr_t nValues,
+                                                  MlirValue const *from,
+                                                  MlirValue const *to) {
+  // Both C arrays are read with the same length, `nValues`.
+  SmallVector<Value, 4> fromStorage;
+  SmallVector<Value, 4> toStorage;
+  ArrayRef<Value> oldValues = unwrapList(nValues, from, fromStorage);
+  ArrayRef<Value> newValues = unwrapList(nValues, to, toStorage);
+  unwrap(rewriter)->replaceAllUsesWith(oldValues, newValues);
+}
+
+void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+                                                    MlirOperation from,
+                                                    intptr_t nTo,
+                                                    MlirValue const *to) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> newValues = unwrapList(nTo, to, storage);
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.replaceAllOpUsesWith(unwrap(from), newValues);
+}
+
+void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
+                                                   MlirOperation from,
+                                                   MlirOperation to) {
+  Operation *replacement = unwrap(to);
+  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), replacement);
+}
+
+void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
+                                              MlirOperation op,
+                                              intptr_t nNewValues,
+                                              MlirValue const *newValues,
+                                              MlirBlock block) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nNewValues, newValues, storage);
+  // The optional "all uses replaced" out-parameter of the C++ method is not
+  // requested here.
+  RewriterBase &rw = *unwrap(rewriter);
+  rw.replaceOpUsesWithinBlock(unwrap(op), replacements, unwrap(block));
+}
+
+void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
+                                          MlirValue from, MlirValue to,
+                                          MlirOperation exceptedUser) {
+  Operation *skipped = unwrap(exceptedUser);
+  unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), skipped);
+}
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
+  RewriterBase *created = new IRRewriter(unwrap(context));
+  return wrap(created);
+}
+
+MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
+  RewriterBase *created = new IRRewriter(unwrap(op));
+  return wrap(created);
+}
+
+void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
+  // mlirIRRewriterCreate and mlirIRRewriterCreateFromOp allocate IRRewriter
+  // objects; callers must only pass such handles here, so downcasting before
+  // the delete is valid.
+  IRRewriter *owned = static_cast<IRRewriter *>(unwrap(rewriter));
+  delete owned;
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet and FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
 inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
   assert(module.ptr && "unexpected null module");
   return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
@@ -54,6 +293,10 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
       mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
   assert(module.ptr && "unexpected null module");

diff  --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index ad312764b3e06..e795672bce5d1 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -89,6 +89,15 @@ _add_capi_test_executable(mlir-capi-quant-test
     MLIRCAPIQuant
 )
 
+# Test driver for the rewriter C API declared in mlir-c/Rewrite.h.
+_add_capi_test_executable(mlir-capi-rewrite-test
+  rewrite.c
+  LINK_LIBS PRIVATE
+    MLIRCAPIIR
+    MLIRCAPIRegisterEverything
+    MLIRCAPITransforms
+)
+
 _add_capi_test_executable(mlir-capi-transform-test
   transform.c
   LINK_LIBS PRIVATE

diff  --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c
new file mode 100644
index 0000000000000..a8b593eabb781
--- /dev/null
+++ b/mlir/test/CAPI/rewrite.c
@@ -0,0 +1,551 @@
+//===- rewrite.c - Test of the rewriting C API ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// RUN: mlir-capi-rewrite-test 2>&1 | FileCheck %s
+
+#include "mlir-c/Rewrite.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
+
+#include <assert.h>
+#include <stdio.h>
+
+// Builds a standalone operation called `name` with no operands and a single
+// `index` result at an unknown location. The tests then place it in the IR
+// with mlirRewriterBaseInsert.
+MlirOperation createOperationWithName(MlirContext ctx, const char *name) {
+  MlirType resultTypes[] = {mlirIndexTypeGet(ctx)};
+  MlirOperationState state =
+      mlirOperationStateGet(mlirStringRefCreateFromCString(name),
+                            mlirLocationUnknownGet(ctx));
+  mlirOperationStateAddResults(&state, 1, resultTypes);
+  return mlirOperationCreate(&state);
+}
+
+// Exercises the rewriter's insertion-point setters and getters: ops are
+// inserted before/after an op, after a value, and at the start/end of the
+// module body, and the dumped module shows where each one landed.
+void testInsertionPoint(MlirContext ctx) {
+  // CHECK-LABEL: @testInsertionPoint
+  fprintf(stderr, "@testInsertionPoint\n");
+
+  const char *moduleString = "\"dialect.op1\"() : () -> ()\n";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+  MlirOperation op1 = mlirBlockGetFirstOperation(body);
+
+  // IRRewriter create
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+  // Insert before op
+  mlirRewriterBaseSetInsertionPointBefore(rewriter, op1);
+  MlirOperation op2 = createOperationWithName(ctx, "dialect.op2");
+  mlirRewriterBaseInsert(rewriter, op2);
+
+  // Insert after op
+  mlirRewriterBaseSetInsertionPointAfter(rewriter, op2);
+  MlirOperation op3 = createOperationWithName(ctx, "dialect.op3");
+  mlirRewriterBaseInsert(rewriter, op3);
+  MlirValue op3Res = mlirOperationGetResult(op3, 0);
+
+  // Insert after value
+  mlirRewriterBaseSetInsertionPointAfterValue(rewriter, op3Res);
+  MlirOperation op4 = createOperationWithName(ctx, "dialect.op4");
+  mlirRewriterBaseInsert(rewriter, op4);
+
+  // Insert at beginning of block
+  mlirRewriterBaseSetInsertionPointToStart(rewriter, body);
+  MlirOperation op5 = createOperationWithName(ctx, "dialect.op5");
+  mlirRewriterBaseInsert(rewriter, op5);
+
+  // Insert at end of block
+  mlirRewriterBaseSetInsertionPointToEnd(rewriter, body);
+  MlirOperation op6 = createOperationWithName(ctx, "dialect.op6");
+  mlirRewriterBaseInsert(rewriter, op6);
+
+  // Get insertion blocks. This is checked explicitly rather than with assert()
+  // so the check still runs (and block1/block2 stay used) in NDEBUG builds.
+  // A mismatch message makes the FileCheck match below fail.
+  MlirBlock block1 = mlirRewriterBaseGetBlock(rewriter);
+  MlirBlock block2 = mlirRewriterBaseGetInsertionBlock(rewriter);
+  if (!mlirBlockEqual(body, block1) || !mlirBlockEqual(body, block2)) {
+    fprintf(stderr, "error: unexpected insertion block\n");
+  }
+
+  // clang-format off
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   %{{.*}} = "dialect.op5"() : () -> index
+  // CHECK-NEXT:   %{{.*}} = "dialect.op2"() : () -> index
+  // CHECK-NEXT:   %{{.*}} = "dialect.op3"() : () -> index
+  // CHECK-NEXT:   %{{.*}} = "dialect.op4"() : () -> index
+  // CHECK-NEXT:   "dialect.op1"() : () -> ()
+  // CHECK-NEXT:   %{{.*}} = "dialect.op6"() : () -> index
+  // CHECK-NEXT: }
+  // clang-format on
+  mlirOperationDump(op);
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests block creation and cloning through the rewriter. A block with one
+// `index` argument is created in op1's region. op1 is then cloned (with and
+// without regions) at the end of the module body, and op1's region is cloned
+// into op2's region.
+void testCreateBlock(MlirContext ctx) {
+  // CHECK-LABEL: @testCreateBlock
+  fprintf(stderr, "@testCreateBlock\n");
+
+  const char *moduleString = "\"dialect.op1\"() ({^bb0:}) : () -> ()\n"
+                             "\"dialect.op2\"() ({^bb0:}) : () -> ()\n";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  MlirOperation op1 = mlirBlockGetFirstOperation(body);
+  MlirRegion region1 = mlirOperationGetRegion(op1, 0);
+  MlirBlock block1 = mlirRegionGetFirstBlock(region1);
+
+  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+  MlirRegion region2 = mlirOperationGetRegion(op2, 0);
+  MlirBlock block2 = mlirRegionGetFirstBlock(region2);
+
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+  // Create block before
+  MlirType indexType = mlirIndexTypeGet(ctx);
+  MlirLocation unknown = mlirLocationUnknownGet(ctx);
+  mlirRewriterBaseCreateBlockBefore(rewriter, block1, 1, &indexType, &unknown);
+
+  mlirRewriterBaseSetInsertionPointToEnd(rewriter, body);
+
+  // Clone operation
+  mlirRewriterBaseClone(rewriter, op1);
+
+  // Clone without regions
+  mlirRewriterBaseCloneWithoutRegions(rewriter, op1);
+
+  // Clone region before
+  mlirRewriterBaseCloneRegionBefore(rewriter, region1, block2);
+
+  mlirOperationDump(op);
+  // clang-format off
+  // CHECK-NEXT: "builtin.module"() ({
+  // CHECK-NEXT:   "dialect.op1"() ({
+  // CHECK-NEXT:   ^{{.*}}(%{{.*}}: index):
+  // CHECK-NEXT:   ^{{.*}}:
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op2"() ({
+  // CHECK-NEXT:   ^{{.*}}(%{{.*}}: index):
+  // CHECK-NEXT:   ^{{.*}}:
+  // CHECK-NEXT:   ^{{.*}}:
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op1"() ({
+  // CHECK-NEXT:   ^{{.*}}(%{{.*}}: index):
+  // CHECK-NEXT:   ^{{.*}}:
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op1"() ({
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT: }) : () -> ()
+  // clang-format on
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests moving IR between regions and blocks:
+//   - op1's region is inlined before op2's block;
+//   - op3's first block is inlined before op3_in3, with its argument
+//     replaced by op3_in2's result;
+//   - op4's second block is merged into its first, with its argument
+//     replaced by op4_in1's result.
+void testInlineRegionBlock(MlirContext ctx) {
+  // CHECK-LABEL: @testInlineRegionBlock
+  fprintf(stderr, "@testInlineRegionBlock\n");
+
+  const char *moduleString =
+      "\"dialect.op1\"() ({\n"
+      "  ^bb0(%arg0: index):\n"
+      "    \"dialect.op1_in1\"(%arg0) [^bb1] : (index) -> ()\n"
+      "  ^bb1():\n"
+      "    \"dialect.op1_in2\"() : () -> ()\n"
+      "}) : () -> ()\n"
+      "\"dialect.op2\"() ({^bb0:}) : () -> ()\n"
+      "\"dialect.op3\"() ({\n"
+      "  ^bb0(%arg0: index):\n"
+      "    \"dialect.op3_in1\"(%arg0) : (index) -> ()\n"
+      "  ^bb1():\n"
+      "    %x = \"dialect.op3_in2\"() : () -> index\n"
+      "    %y = \"dialect.op3_in3\"() : () -> index\n"
+      "}) : () -> ()\n"
+      "\"dialect.op4\"() ({\n"
+      "  ^bb0():\n"
+      "    \"dialect.op4_in1\"() : () -> index\n"
+      "  ^bb1(%arg0: index):\n"
+      "    \"dialect.op4_in2\"(%arg0) : (index) -> ()\n"
+      "}) : () -> ()\n";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  MlirOperation op1 = mlirBlockGetFirstOperation(body);
+  MlirRegion region1 = mlirOperationGetRegion(op1, 0);
+
+  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+  MlirRegion region2 = mlirOperationGetRegion(op2, 0);
+  MlirBlock block2 = mlirRegionGetFirstBlock(region2);
+
+  MlirOperation op3 = mlirOperationGetNextInBlock(op2);
+  MlirRegion region3 = mlirOperationGetRegion(op3, 0);
+  MlirBlock block3_1 = mlirRegionGetFirstBlock(region3);
+  MlirBlock block3_2 = mlirBlockGetNextInRegion(block3_1);
+  MlirOperation op3_in2 = mlirBlockGetFirstOperation(block3_2);
+  MlirValue op3_in2_res = mlirOperationGetResult(op3_in2, 0);
+  MlirOperation op3_in3 = mlirOperationGetNextInBlock(op3_in2);
+
+  MlirOperation op4 = mlirOperationGetNextInBlock(op3);
+  MlirRegion region4 = mlirOperationGetRegion(op4, 0);
+  MlirBlock block4_1 = mlirRegionGetFirstBlock(region4);
+  MlirOperation op4_in1 = mlirBlockGetFirstOperation(block4_1);
+  MlirValue op4_in1_res = mlirOperationGetResult(op4_in1, 0);
+  MlirBlock block4_2 = mlirBlockGetNextInRegion(block4_1);
+
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+  // Test these three functions
+  mlirRewriterBaseInlineRegionBefore(rewriter, region1, block2);
+  mlirRewriterBaseInlineBlockBefore(rewriter, block3_1, op3_in3, 1,
+                                    &op3_in2_res);
+  mlirRewriterBaseMergeBlocks(rewriter, block4_2, block4_1, 1, &op4_in1_res);
+
+  mlirOperationDump(op);
+  // clang-format off
+  // CHECK-NEXT: "builtin.module"() ({
+  // CHECK-NEXT:   "dialect.op1"() ({
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op2"() ({
+  // CHECK-NEXT:   ^{{.*}}(%{{.*}}: index):
+  // CHECK-NEXT:     "dialect.op1_in1"(%{{.*}})[^[[bb:.*]]] : (index) -> ()
+  // CHECK-NEXT:   ^[[bb]]:
+  // CHECK-NEXT:     "dialect.op1_in2"() : () -> ()
+  // CHECK-NEXT:   ^{{.*}}:  // no predecessors
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op3"() ({
+  // CHECK-NEXT:     %{{.*}} = "dialect.op3_in2"() : () -> index
+  // CHECK-NEXT:     "dialect.op3_in1"(%{{.*}}) : (index) -> ()
+  // CHECK-NEXT:     %{{.*}} = "dialect.op3_in3"() : () -> index
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op4"() ({
+  // CHECK-NEXT:     %{{.*}} = "dialect.op4_in1"() : () -> index
+  // CHECK-NEXT:     "dialect.op4_in2"(%{{.*}}) : (index) -> ()
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT: }) : () -> ()
+  // clang-format on
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests replacing an op by a list of existing values (op1 -> results 0 and 2
+// of create_values) and by another op's results (op2 -> op3).
+void testReplaceOp(MlirContext ctx) {
+  // CHECK-LABEL: @testReplaceOp
+  fprintf(stderr, "@testReplaceOp\n");
+
+  const char *moduleString =
+      "%x, %y, %z = \"dialect.create_values\"() : () -> (index, index, index)\n"
+      "%x_1, %y_1 = \"dialect.op1\"() : () -> (index, index)\n"
+      "\"dialect.use_op1\"(%x_1, %y_1) : (index, index) -> ()\n"
+      "%x_2, %y_2 = \"dialect.op2\"() : () -> (index, index)\n"
+      "%x_3, %y_3 = \"dialect.op3\"() : () -> (index, index)\n"
+      "\"dialect.use_op2\"(%x_2, %y_2) : (index, index) -> ()\n";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  // get a handle to all operations/values
+  MlirOperation createValues = mlirBlockGetFirstOperation(body);
+  MlirValue x = mlirOperationGetResult(createValues, 0);
+  MlirValue z = mlirOperationGetResult(createValues, 2);
+  MlirOperation op1 = mlirOperationGetNextInBlock(createValues);
+  MlirOperation useOp1 = mlirOperationGetNextInBlock(op1);
+  MlirOperation op2 = mlirOperationGetNextInBlock(useOp1);
+  MlirOperation op3 = mlirOperationGetNextInBlock(op2);
+
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+  // Test replace op with values
+  MlirValue xz[2] = {x, z};
+  mlirRewriterBaseReplaceOpWithValues(rewriter, op1, 2, xz);
+
+  // Test replace op with op
+  mlirRewriterBaseReplaceOpWithOperation(rewriter, op2, op3);
+
+  mlirOperationDump(op);
+  // clang-format off
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   %[[res:.*]]:3 = "dialect.create_values"() : () -> (index, index, index)
+  // CHECK-NEXT:   "dialect.use_op1"(%[[res]]#0, %[[res]]#2) : (index, index) -> ()
+  // CHECK-NEXT:   %[[res2:.*]]:2 = "dialect.op3"() : () -> (index, index)
+  // CHECK-NEXT:   "dialect.use_op2"(%[[res2]]#0, %[[res2]]#1) : (index, index) -> ()
+  // CHECK-NEXT: }
+  // clang-format on
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests erasing a top-level op and a block in the middle of op2's region.
+void testErase(MlirContext ctx) {
+  // CHECK-LABEL: @testErase
+  fprintf(stderr, "@testErase\n");
+
+  const char *moduleString = "\"dialect.op_to_erase\"() : () -> ()\n"
+                             "\"dialect.op2\"() ({\n"
+                             "^bb0():\n"
+                             "  \"dialect.op2_nested\"() : () -> ()"
+                             "^block_to_erase():\n"
+                             "  \"dialect.op2_nested\"() : () -> ()"
+                             "^bb1():\n"
+                             "  \"dialect.op2_nested\"() : () -> ()"
+                             "}) : () -> ()\n";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  // get a handle to all operations/values
+  MlirOperation opToErase = mlirBlockGetFirstOperation(body);
+  MlirOperation op2 = mlirOperationGetNextInBlock(opToErase);
+  MlirRegion op2Region = mlirOperationGetRegion(op2, 0);
+  MlirBlock bb0 = mlirRegionGetFirstBlock(op2Region);
+  MlirBlock blockToErase = mlirBlockGetNextInRegion(bb0);
+
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+  mlirRewriterBaseEraseOp(rewriter, opToErase);
+  mlirRewriterBaseEraseBlock(rewriter, blockToErase);
+
+  mlirOperationDump(op);
+  // CHECK-NEXT: module {
+  // CHECK-NEXT: "dialect.op2"() ({
+  // CHECK-NEXT:   "dialect.op2_nested"() : () -> ()
+  // CHECK-NEXT: ^{{.*}}:
+  // CHECK-NEXT:   "dialect.op2_nested"() : () -> ()
+  // CHECK-NEXT: }) : () -> ()
+  // CHECK-NEXT: }
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests moving ops before/after another op in the module body, and moving
+// op2's second block in front of its first.
+void testMove(MlirContext ctx) {
+  // CHECK-LABEL: @testMove
+  fprintf(stderr, "@testMove\n");
+
+  const char *moduleString = "\"dialect.op1\"() : () -> ()\n"
+                             "\"dialect.op2\"() ({\n"
+                             "^bb0(%arg0: index):\n"
+                             "  \"dialect.op2_1\"(%arg0) : (index) -> ()"
+                             "^bb1(%arg1: index):\n"
+                             "  \"dialect.op2_2\"(%arg1) : (index) -> ()"
+                             "}) : () -> ()\n"
+                             "\"dialect.op3\"() : () -> ()\n"
+                             "\"dialect.op4\"() : () -> ()\n";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  // get a handle to all operations/values
+  MlirOperation op1 = mlirBlockGetFirstOperation(body);
+  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+  MlirOperation op3 = mlirOperationGetNextInBlock(op2);
+  MlirOperation op4 = mlirOperationGetNextInBlock(op3);
+
+  MlirRegion region2 = mlirOperationGetRegion(op2, 0);
+  MlirBlock block0 = mlirRegionGetFirstBlock(region2);
+  MlirBlock block1 = mlirBlockGetNextInRegion(block0);
+
+  // Test move operations.
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+  mlirRewriterBaseMoveOpBefore(rewriter, op3, op1);
+  mlirRewriterBaseMoveOpAfter(rewriter, op4, op1);
+  mlirRewriterBaseMoveBlockBefore(rewriter, block1, block0);
+
+  mlirOperationDump(op);
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   "dialect.op3"() : () -> ()
+  // CHECK-NEXT:   "dialect.op1"() : () -> ()
+  // CHECK-NEXT:   "dialect.op4"() : () -> ()
+  // CHECK-NEXT:   "dialect.op2"() ({
+  // CHECK-NEXT:   ^{{.*}}(%[[arg0:.*]]: index):
+  // CHECK-NEXT:     "dialect.op2_2"(%[[arg0]]) : (index) -> ()
+  // CHECK-NEXT:   ^{{.*}}(%[[arg1:.*]]: index):  // no predecessors
+  // CHECK-NEXT:     "dialect.op2_1"(%[[arg1]]) : (index) -> ()
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT: }
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests the in-place modification protocol. A modification of op1 is started
+// and cancelled without changes. A modification of op2 is started, its
+// operand is rewired to op1's second result, and the modification is
+// finalized.
+void testOpModification(MlirContext ctx) {
+  // CHECK-LABEL: @testOpModification
+  fprintf(stderr, "@testOpModification\n");
+
+  const char *moduleString =
+      "%x, %y = \"dialect.op1\"() : () -> (index, index)\n"
+      "\"dialect.op2\"(%x) : (index) -> ()\n";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  // get a handle to all operations/values
+  MlirOperation op1 = mlirBlockGetFirstOperation(body);
+  MlirValue y = mlirOperationGetResult(op1, 1);
+  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+  mlirRewriterBaseStartOpModification(rewriter, op1);
+  mlirRewriterBaseCancelOpModification(rewriter, op1);
+
+  mlirRewriterBaseStartOpModification(rewriter, op2);
+  mlirOperationSetOperand(op2, 0, y);
+  mlirRewriterBaseFinalizeOpModification(rewriter, op2);
+
+  mlirOperationDump(op);
+  // CHECK-NEXT: module {
+  // CHECK-NEXT: %[[xy:.*]]:2 = "dialect.op1"() : () -> (index, index)
+  // CHECK-NEXT: "dialect.op2"(%[[xy]]#1) : (index) -> ()
+  // CHECK-NEXT: }
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+// Tests the replace-uses entry points of the rewriter: value -> value,
+// value range -> value range, op -> values, op -> op, op uses restricted to
+// one block, and value -> value except in one user op.
+void testReplaceUses(MlirContext ctx) {
+  // CHECK-LABEL: @testReplaceUses
+  fprintf(stderr, "@testReplaceUses\n");
+
+  const char *moduleString =
+      // Replace values with values
+      "%x1, %y1, %z1 = \"dialect.op1\"() : () -> (index, index, index)\n"
+      "%x2, %y2, %z2 = \"dialect.op2\"() : () -> (index, index, index)\n"
+      "\"dialect.op1_uses\"(%x1, %y1, %z1) : (index, index, index) -> ()\n"
+      // Replace op with values
+      "%x3 = \"dialect.op3\"() : () -> index\n"
+      "%x4 = \"dialect.op4\"() : () -> index\n"
+      "\"dialect.op3_uses\"(%x3) : (index) -> ()\n"
+      // Replace op with op
+      "%x5 = \"dialect.op5\"() : () -> index\n"
+      "%x6 = \"dialect.op6\"() : () -> index\n"
+      "\"dialect.op5_uses\"(%x5) : (index) -> ()\n"
+      // Replace op uses within a single block
+      "%x7 = \"dialect.op7\"() : () -> index\n"
+      "%x8 = \"dialect.op8\"() : () -> index\n"
+      "\"dialect.op9\"() ({\n"
+      "^bb0:\n"
+      "   \"dialect.op7_uses\"(%x7) : (index) -> ()\n"
+      "}): () -> ()\n"
+      "\"dialect.op7_uses\"(%x7) : (index) -> ()\n"
+      // Replace value with value except in op
+      "%x10 = \"dialect.op10\"() : () -> index\n"
+      "%x11 = \"dialect.op11\"() : () -> index\n"
+      "\"dialect.op10_uses\"(%x10) : (index) -> ()\n"
+      "\"dialect.op10_uses\"(%x10) : (index) -> ()\n";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Parsing can fail; report it instead of dereferencing a null module. The
+  // message also makes the FileCheck match below fail.
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "error: failed to parse test module\n");
+    return;
+  }
+  MlirOperation op = mlirModuleGetOperation(module);
+  MlirBlock body = mlirModuleGetBody(module);
+
+  // get a handle to all operations/values
+  MlirOperation op1 = mlirBlockGetFirstOperation(body);
+  MlirValue x1 = mlirOperationGetResult(op1, 0);
+  MlirValue y1 = mlirOperationGetResult(op1, 1);
+  MlirValue z1 = mlirOperationGetResult(op1, 2);
+  MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+  MlirValue x2 = mlirOperationGetResult(op2, 0);
+  MlirValue y2 = mlirOperationGetResult(op2, 1);
+  MlirValue z2 = mlirOperationGetResult(op2, 2);
+  MlirOperation op1Uses = mlirOperationGetNextInBlock(op2);
+
+  MlirOperation op3 = mlirOperationGetNextInBlock(op1Uses);
+  MlirOperation op4 = mlirOperationGetNextInBlock(op3);
+  MlirValue x4 = mlirOperationGetResult(op4, 0);
+  MlirOperation op3Uses = mlirOperationGetNextInBlock(op4);
+
+  MlirOperation op5 = mlirOperationGetNextInBlock(op3Uses);
+  MlirOperation op6 = mlirOperationGetNextInBlock(op5);
+  MlirOperation op5Uses = mlirOperationGetNextInBlock(op6);
+
+  MlirOperation op7 = mlirOperationGetNextInBlock(op5Uses);
+  MlirOperation op8 = mlirOperationGetNextInBlock(op7);
+  MlirValue x8 = mlirOperationGetResult(op8, 0);
+  MlirOperation op9 = mlirOperationGetNextInBlock(op8);
+  MlirRegion region9 = mlirOperationGetRegion(op9, 0);
+  MlirBlock block9 = mlirRegionGetFirstBlock(region9);
+  MlirOperation op7Uses = mlirOperationGetNextInBlock(op9);
+
+  MlirOperation op10 = mlirOperationGetNextInBlock(op7Uses);
+  MlirValue x10 = mlirOperationGetResult(op10, 0);
+  MlirOperation op11 = mlirOperationGetNextInBlock(op10);
+  MlirValue x11 = mlirOperationGetResult(op11, 0);
+  MlirOperation op10Uses1 = mlirOperationGetNextInBlock(op11);
+
+  MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+  // Replace values
+  mlirRewriterBaseReplaceAllUsesWith(rewriter, x1, x2);
+  MlirValue y1z1[2] = {y1, z1};
+  MlirValue y2z2[2] = {y2, z2};
+  mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter, 2, y1z1, y2z2);
+
+  // Replace op with values
+  mlirRewriterBaseReplaceOpWithValues(rewriter, op3, 1, &x4);
+
+  // Replace op with op
+  mlirRewriterBaseReplaceOpWithOperation(rewriter, op5, op6);
+
+  // Replace op7's uses with %x8 inside block9 only; the use after op9 keeps
+  // referring to op7.
+  mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter, op7, 1, &x8, block9);
+
+  // Replace value with value except in op
+  mlirRewriterBaseReplaceAllUsesExcept(rewriter, x10, x11, op10Uses1);
+
+  mlirOperationDump(op);
+  // clang-format off
+  // CHECK-NEXT: module {
+  // CHECK-NEXT:   %{{.*}}:3 = "dialect.op1"() : () -> (index, index, index)
+  // CHECK-NEXT:   %[[res2:.*]]:3 = "dialect.op2"() : () -> (index, index, index)
+  // CHECK-NEXT:   "dialect.op1_uses"(%[[res2]]#0, %[[res2]]#1, %[[res2]]#2) : (index, index, index) -> ()
+  // CHECK-NEXT:   %[[res4:.*]] = "dialect.op4"() : () -> index
+  // CHECK-NEXT:   "dialect.op3_uses"(%[[res4]]) : (index) -> ()
+  // CHECK-NEXT:   %[[res6:.*]] = "dialect.op6"() : () -> index
+  // CHECK-NEXT:   "dialect.op5_uses"(%[[res6]]) : (index) -> ()
+  // CHECK-NEXT:   %[[res7:.*]] = "dialect.op7"() : () -> index
+  // CHECK-NEXT:   %[[res8:.*]] = "dialect.op8"() : () -> index
+  // CHECK-NEXT:   "dialect.op9"() ({
+  // CHECK-NEXT:     "dialect.op7_uses"(%[[res8]]) : (index) -> ()
+  // CHECK-NEXT:   }) : () -> ()
+  // CHECK-NEXT:   "dialect.op7_uses"(%[[res7]]) : (index) -> ()
+  // CHECK-NEXT:   %[[res10:.*]] = "dialect.op10"() : () -> index
+  // CHECK-NEXT:   %[[res11:.*]] = "dialect.op11"() : () -> index
+  // CHECK-NEXT:   "dialect.op10_uses"(%[[res10]]) : (index) -> ()
+  // CHECK-NEXT:   "dialect.op10_uses"(%[[res11]]) : (index) -> ()
+  // CHECK-NEXT: }
+  // clang-format on
+
+  mlirIRRewriterDestroy(rewriter);
+  mlirModuleDestroy(module);
+}
+
+int main(void) {
+  // Tests run in this order. FileCheck matches the labels sequentially, so the
+  // order must match the order of the label lines in this file.
+  void (*const tests[])(MlirContext) = {
+      testInsertionPoint, testCreateBlock, testInlineRegionBlock,
+      testReplaceOp,      testErase,       testMove,
+      testOpModification, testReplaceUses,
+  };
+
+  MlirContext context = mlirContextCreate();
+  // The test IR uses unregistered "dialect.*" ops. The builtin dialect
+  // provides the module op that wraps them.
+  mlirContextSetAllowUnregisteredDialects(context, true);
+  mlirContextGetOrLoadDialect(context,
+                              mlirStringRefCreateFromCString("builtin"));
+
+  for (size_t i = 0; i < sizeof tests / sizeof tests[0]; ++i) {
+    tests[i](context);
+  }
+
+  mlirContextDestroy(context);
+  return 0;
+}

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 45009a78aa49f..df95e5db11f1e 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -105,6 +105,7 @@ set(MLIR_TEST_DEPENDS
   mlir-capi-llvm-test
   mlir-capi-pass-test
   mlir-capi-quant-test
+  mlir-capi-rewrite-test
   mlir-capi-sparse-tensor-test
   mlir-capi-transform-test
   mlir-capi-transform-interpreter-test

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 1175f87877f9e..98d0ddd9a2be1 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -106,6 +106,7 @@ def add_runtime(name):
     "mlir-capi-pass-test",
     "mlir-capi-pdl-test",
     "mlir-capi-quant-test",
+    "mlir-capi-rewrite-test",
     "mlir-capi-sparse-tensor-test",
     "mlir-capi-transform-test",
     "mlir-capi-transform-interpreter-test",


        


More information about the Mlir-commits mailing list