[Mlir-commits] [mlir] [mlir] Add RewriterBase to the C API (PR #98962)
Fehr Mathieu
llvmlistbot at llvm.org
Mon Jul 15 14:04:22 PDT 2024
https://github.com/math-fehr created https://github.com/llvm/llvm-project/pull/98962
This exposes most of the `RewriterBase` methods to the C API.
This makes it possible to manipulate both the `IRRewriter` and the `PatternRewriter`. The
`IRRewriter` can be created from the C API, while the `PatternRewriter` cannot.
The missing methods are the ones taking `Block::iterator` or `Region::iterator`
parameters since, as far as I know, these types are not exposed by the C API yet.
Python bindings for these methods and classes are not implemented in this patch.
>From fd2fd347cb185722265f27b7a2aa8a32315e3cbf Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Mon, 15 Jul 2024 14:50:47 +0100
Subject: [PATCH] [mlir] Add RewriterBase to the C API
---
mlir/include/mlir-c/Rewrite.h | 257 +++++++++++++
mlir/include/mlir/CAPI/Rewrite.h | 23 ++
mlir/lib/CAPI/Transforms/Rewrite.cpp | 249 ++++++++++++
mlir/test/CAPI/CMakeLists.txt | 9 +
mlir/test/CAPI/rewrite.c | 551 +++++++++++++++++++++++++++
mlir/test/CMakeLists.txt | 1 +
mlir/test/lit.cfg.py | 1 +
7 files changed, 1091 insertions(+)
create mode 100644 mlir/include/mlir/CAPI/Rewrite.h
create mode 100644 mlir/test/CAPI/rewrite.c
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index bed93045f4b50..09f8a72a0c599 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -33,10 +33,263 @@ extern "C" {
}; \
typedef struct name name
+/// Opaque handle to an `mlir::RewriterBase`. It can refer either to an
+/// `IRRewriter`, which can be created with `mlirIRRewriterCreate`, or to a
+/// `PatternRewriter`, which cannot be created through the C API.
+DEFINE_C_API_STRUCT(MlirRewriterBase, void);
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+/// Get the MLIR context referenced by the rewriter.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+// Methods taking a `Block::iterator`, a `Region::iterator` or an
+// `InsertPoint` are not exposed, as these types are not yet available in the
+// C API.
+
+/// Reset the insertion point to no location. Creating an operation without a
+/// set insertion point is an error, but this can still be useful when the
+/// current insertion point a builder refers to is being removed.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter);
+
+/// Sets the insertion point to the specified operation, which will cause
+/// subsequent insertions to go right before it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+                                        MlirOperation op);
+
+/// Sets the insertion point to the node after the specified operation, which
+/// will cause subsequent insertions to go right after it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+                                       MlirOperation op);
+
+/// Sets the insertion point to the node after the specified value. If `value`
+/// has a defining operation, the insertion point is set to the node after
+/// that defining operation, so subsequent insertions go right after it.
+/// Otherwise, `value` is a block argument and the insertion point is set to
+/// the start of its block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+                                            MlirValue value);
+
+/// Sets the insertion point to the start of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+                                         MlirBlock block);
+
+/// Sets the insertion point to the end of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+                                       MlirBlock block);
+
+/// Return the block the current insertion point belongs to. Note that the
+/// insertion point is not necessarily the end of the block. The returned
+/// block is null when no insertion point is set (for example after
+/// `mlirRewriterBaseClearInsertionPoint`).
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
+
+/// Returns the current block of the rewriter. This is the same block as the
+/// one returned by `mlirRewriterBaseGetInsertionBlock`.
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+/// Add a new block with `nArgTypes` arguments of types `argTypes` and set the
+/// insertion point to the end of it. The block is placed before
+/// `insertBefore`. `locations` contains the locations of the block arguments
+/// and must also hold `nArgTypes` elements.
+MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore(
+    MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes,
+    MlirType const *argTypes, MlirLocation const *locations);
+
+/// Insert the given operation at the current insertion point and return it.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Create a deep copy of `op`, including its regions, insert it at the
+/// current insertion point and return it. Overloads taking an `IRMapping` are
+/// not exposed, as `IRMapping` is not yet available in the C API.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Create a copy of `op` whose regions are left empty, insert it at the
+/// current insertion point and return it. Overloads taking an `IRMapping` are
+/// not exposed, as `IRMapping` is not yet available in the C API.
+MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions(
+    MlirRewriterBase rewriter, MlirOperation op);
+
+/// Clone the blocks of `region` and insert the copies before the block
+/// `before`, which must belong to a different region. `region` itself is left
+/// unchanged. Overloads taking an `IRMapping` or a `Region::iterator` are not
+/// exposed, as these types are not yet available in the C API.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+                                  MlirBlock before);
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+/// Move the blocks that belong to `region` before the block `before`, which
+/// must belong to a different region. The caller is responsible for creating
+/// or updating the operation transferring flow of control to the region and
+/// passing it the correct block arguments.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+                                   MlirBlock before);
+
+/// Replace the results of the given (original) operation with the specified
+/// list of values (replacements). `nValues` must equal the number of results
+/// of `op`, and the result types of the given op and the replacements must
+/// match. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op,
+                                    intptr_t nValues, MlirValue const *values);
+
+/// Replace the results of the given (original) operation with the results of
+/// `newOp` (replacement). Both ops must have the same number of results, with
+/// matching types. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                       MlirOperation op, MlirOperation newOp);
+
+/// Erase `op`, which must not have any remaining uses.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op);
+
+/// Erase `block` together with every operation it contains.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block);
+
+/// Move the operations of `source` right before `op` and delete `source`,
+/// which must have no uses. The `nArgValues` values of `argValues` take the
+/// place of the block arguments of `source`.
+///
+/// `source` must not have successors, otherwise the resulting IR would contain
+/// unreachable operations.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseInlineBlockBefore(
+    MlirRewriterBase rewriter, MlirBlock source, MlirOperation op,
+    intptr_t nArgValues, MlirValue const *argValues);
+
+/// Append the operations of `source` to the end of `dest` and delete
+/// `source`, which must have no uses. The `nArgValues` values of `argValues`
+/// take the place of the block arguments of `source`.
+///
+/// `dest` must not have successors, otherwise the resulting IR would contain
+/// unreachable operations.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(
+    MlirRewriterBase rewriter, MlirBlock source, MlirBlock dest,
+    intptr_t nArgValues, MlirValue const *argValues);
+
+// splitBlock is not implemented as Block::iterator is not exposed by the CAPI
+
+/// Detach `op` from its current block and re-insert it immediately before
+/// `existingOp`, which may be in the same block or in another block of the
+/// same function.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
+                             MlirOperation existingOp);
+
+/// Detach `op` from its current block and re-insert it immediately after
+/// `existingOp`, which may be in the same block or in another block of the
+/// same function.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
+                            MlirOperation existingOp);
+
+/// Detach `block` from its region and re-insert it immediately before
+/// `existingBlock`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+                                MlirBlock existingBlock);
+
+/// Notify the rewriter that an in-place modification of `op` is about to
+/// happen. A call to this function *must* be followed by a call to either
+/// `mlirRewriterBaseFinalizeOpModification` or
+/// `mlirRewriterBaseCancelOpModification`. This is a minor efficiency win (it
+/// avoids creating a new operation and removing the old one) but also often
+/// allows simpler code in the client.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+                                    MlirOperation op);
+
+/// Signal the end of an in-place modification of `op`. This can only be
+/// called on operations that were passed to a previous call to
+/// `mlirRewriterBaseStartOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+                                       MlirOperation op);
+
+/// Cancel a pending in-place modification of `op`. This can only be called
+/// on operations that were passed to a previous call to
+/// `mlirRewriterBaseStartOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+                                     MlirOperation op);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from,
+                                   MlirValue to);
+
+/// Replace every use of `from[i]` with `to[i]`, for each `i` in
+/// `[0, nValues)`; `from` and `to` must both contain `nValues` values. Also
+/// notify the listener about every in-place op modification (for every use
+/// that was replaced).
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith(
+    MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from,
+    MlirValue const *to);
+
+/// Replace all uses of the results of `from` with the `nTo` values in `to`;
+/// `nTo` must equal the number of results of `from`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced)
+/// and that the `from` operation is about to be replaced. Unlike
+/// `mlirRewriterBaseReplaceOpWithValues`, `from` is not erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+                                               MlirOperation from, intptr_t nTo,
+                                               MlirValue const *to);
+
+/// Replace all uses of the results of `from` with the corresponding results
+/// of `to`. Also notify the listener about every in-place op modification
+/// (for every use that was replaced) and that the `from` operation is about
+/// to be replaced. `from` is not erased.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation(
+    MlirRewriterBase rewriter, MlirOperation from, MlirOperation to);
+
+/// Replace the uses of the results of `op` that are located within `block`
+/// with the `nNewValues` values in `newValues`; `nNewValues` must equal the
+/// number of results of `op`. Also notify the listener about every in-place
+/// op modification (for every use that was replaced). The C++ out-parameter
+/// reporting whether all uses were replaced is not exposed.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock(
+    MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues,
+    MlirValue const *newValues, MlirBlock block);
+
+/// Find uses of `from` and replace them with `to`, except if the user is
+/// `exceptedUser`. Also notify the listener about every in-place op
+/// modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from,
+                                     MlirValue to, MlirOperation exceptedUser);
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Create a new IRRewriter for `context`. Ownership goes to the caller, who
+/// must release it with `mlirIRRewriterDestroy`.
+MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context);
+
+/// Create a new IRRewriter whose insertion point is right before `op`.
+/// Ownership goes to the caller, who must release it with
+/// `mlirIRRewriterDestroy`.
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirIRRewriterCreateFromOp(MlirOperation op);
+
+/// Destroy an IRRewriter owned by the caller. Only rewriters returned by
+/// `mlirIRRewriterCreate` or `mlirIRRewriterCreateFromOp` may be passed here;
+/// passing any other kind of rewriter is not allowed.
+MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet op);
@@ -47,6 +300,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
new file mode 100644
index 0000000000000..0e6dcb2477626
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -0,0 +1,23 @@
+//===- Rewrite.h - C API Utils for Rewrite classes --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// rewriters and rewrite patterns. This file should not be included from C++
+// code other than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_REWRITE_H
+#define MLIR_CAPI_REWRITE_H
+
+// Provides the C handle type `MlirRewriterBase` wrapped below, so that this
+// header is self-contained.
+#include "mlir-c/Rewrite.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+
+DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase);
+
+#endif // MLIR_CAPI_REWRITE_H
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 0de1958398f63..7f3c833df0910 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -7,15 +7,260 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Rewrite.h"
+
#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+// Forwards to the builder's getContext.
+MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
+  MLIRContext *context = unwrap(rewriter)->getContext();
+  return wrap(context);
+}
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+// Each wrapper below unwraps its handles and forwards to the OpBuilder method
+// with the matching purpose.
+
+void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
+  RewriterBase *builder = unwrap(rewriter);
+  builder->clearInsertionPoint();
+}
+
+void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+                                             MlirOperation op) {
+  Operation *anchor = unwrap(op);
+  unwrap(rewriter)->setInsertionPoint(anchor);
+}
+
+void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+                                            MlirOperation op) {
+  Operation *anchor = unwrap(op);
+  unwrap(rewriter)->setInsertionPointAfter(anchor);
+}
+
+void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+                                                 MlirValue value) {
+  Value anchor = unwrap(value);
+  unwrap(rewriter)->setInsertionPointAfterValue(anchor);
+}
+
+void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+                                              MlirBlock block) {
+  Block *target = unwrap(block);
+  unwrap(rewriter)->setInsertionPointToStart(target);
+}
+
+void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+                                            MlirBlock block) {
+  Block *target = unwrap(block);
+  unwrap(rewriter)->setInsertionPointToEnd(target);
+}
+
+MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
+  Block *current = unwrap(rewriter)->getInsertionBlock();
+  return wrap(current);
+}
+
+MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
+  Block *current = unwrap(rewriter)->getBlock();
+  return wrap(current);
+}
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
+                                            MlirBlock insertBefore,
+                                            intptr_t nArgTypes,
+                                            MlirType const *argTypes,
+                                            MlirLocation const *locations) {
+  // `argTypes` and `locations` both hold `nArgTypes` elements.
+  SmallVector<Type, 4> args;
+  ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
+  SmallVector<Location, 4> locs;
+  ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
+  return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
+                                            unwrappedLocs));
+}
+
+MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
+                                     MlirOperation op) {
+  return wrap(unwrap(rewriter)->insert(unwrap(op)));
+}
+
+// Other methods of OpBuilder
+
+// Overloads taking an `IRMapping` are not exposed, as `IRMapping` is not yet
+// available in the C API.
+MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
+                                    MlirOperation op) {
+  return wrap(unwrap(rewriter)->clone(*unwrap(op)));
+}
+
+// Overloads taking an `IRMapping` are not exposed, as `IRMapping` is not yet
+// available in the C API.
+MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
+                                                  MlirOperation op) {
+  return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
+}
+
+// Only the `Block *` overload is exposed: `IRMapping` and `Region::iterator`
+// are not yet available in the C API.
+void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
+                                       MlirRegion region, MlirBlock before) {
+  unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+// Region::iterator is not yet exposed in the CAPI.
+void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
+                                        MlirRegion region, MlirBlock before) {
+  Region &source = *unwrap(region);
+  Block *insertionBlock = unwrap(before);
+  unwrap(rewriter)->inlineRegionBefore(source, insertionBlock);
+}
+
+void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
+                                         MlirOperation op, intptr_t nValues,
+                                         MlirValue const *values) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nValues, values, storage);
+  Operation *original = unwrap(op);
+  unwrap(rewriter)->replaceOp(original, replacements);
+}
+
+void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                            MlirOperation op,
+                                            MlirOperation newOp) {
+  Operation *original = unwrap(op);
+  Operation *replacement = unwrap(newOp);
+  unwrap(rewriter)->replaceOp(original, replacement);
+}
+
+void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
+  Operation *victim = unwrap(op);
+  unwrap(rewriter)->eraseOp(victim);
+}
+
+void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
+  Block *victim = unwrap(block);
+  unwrap(rewriter)->eraseBlock(victim);
+}
+
+void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
+                                       MlirBlock source, MlirOperation op,
+                                       intptr_t nArgValues,
+                                       MlirValue const *argValues) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nArgValues, argValues, storage);
+  unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), replacements);
+}
+
+void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
+                                 MlirBlock dest, intptr_t nArgValues,
+                                 MlirValue const *argValues) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nArgValues, argValues, storage);
+  unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), replacements);
+}
+
+// splitBlock is not implemented as Block::iterator is not exposed by the CAPI
+
+void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
+                                  MlirOperation existingOp) {
+  Operation *moved = unwrap(op);
+  Operation *anchor = unwrap(existingOp);
+  unwrap(rewriter)->moveOpBefore(moved, anchor);
+}
+
+void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
+                                 MlirOperation existingOp) {
+  Operation *moved = unwrap(op);
+  Operation *anchor = unwrap(existingOp);
+  unwrap(rewriter)->moveOpAfter(moved, anchor);
+}
+
+void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+                                     MlirBlock existingBlock) {
+  Block *moved = unwrap(block);
+  Block *anchor = unwrap(existingBlock);
+  unwrap(rewriter)->moveBlockBefore(moved, anchor);
+}
+
+// The three functions below bracket an in-place modification of `op`.
+
+void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+                                         MlirOperation op) {
+  Operation *target = unwrap(op);
+  unwrap(rewriter)->startOpModification(target);
+}
+
+void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+                                            MlirOperation op) {
+  Operation *target = unwrap(op);
+  unwrap(rewriter)->finalizeOpModification(target);
+}
+
+void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+                                          MlirOperation op) {
+  Operation *target = unwrap(op);
+  unwrap(rewriter)->cancelOpModification(target);
+}
+
+void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
+                                        MlirValue from, MlirValue to) {
+  Value oldValue = unwrap(from);
+  Value newValue = unwrap(to);
+  unwrap(rewriter)->replaceAllUsesWith(oldValue, newValue);
+}
+
+void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
+                                                  intptr_t nValues,
+                                                  MlirValue const *from,
+                                                  MlirValue const *to) {
+  // `from` and `to` are parallel arrays of `nValues` elements each.
+  SmallVector<Value, 4> oldStorage;
+  SmallVector<Value, 4> newStorage;
+  ArrayRef<Value> oldValues = unwrapList(nValues, from, oldStorage);
+  ArrayRef<Value> newValues = unwrapList(nValues, to, newStorage);
+  unwrap(rewriter)->replaceAllUsesWith(oldValues, newValues);
+}
+
+void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+                                                    MlirOperation from,
+                                                    intptr_t nTo,
+                                                    MlirValue const *to) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> newValues = unwrapList(nTo, to, storage);
+  Operation *original = unwrap(from);
+  unwrap(rewriter)->replaceAllOpUsesWith(original, newValues);
+}
+
+void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
+                                                   MlirOperation from,
+                                                   MlirOperation to) {
+  Operation *original = unwrap(from);
+  Operation *replacement = unwrap(to);
+  unwrap(rewriter)->replaceAllOpUsesWith(original, replacement);
+}
+
+void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
+                                              MlirOperation op,
+                                              intptr_t nNewValues,
+                                              MlirValue const *newValues,
+                                              MlirBlock block) {
+  SmallVector<Value, 4> storage;
+  ArrayRef<Value> replacements = unwrapList(nNewValues, newValues, storage);
+  Block *scope = unwrap(block);
+  unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), replacements, scope);
+}
+
+void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
+                                          MlirValue from, MlirValue to,
+                                          MlirOperation exceptedUser) {
+  Value oldValue = unwrap(from);
+  Value newValue = unwrap(to);
+  Operation *skipped = unwrap(exceptedUser);
+  unwrap(rewriter)->replaceAllUsesExcept(oldValue, newValue, skipped);
+}
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
+  // Heap-allocated; ownership is handed over to the C caller.
+  IRRewriter *created = new IRRewriter(unwrap(context));
+  return wrap(created);
+}
+
+MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
+  IRRewriter *created = new IRRewriter(unwrap(op));
+  return wrap(created);
+}
+
+void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
+  // The handle is required to wrap an IRRewriter, so delete it as one.
+  IRRewriter *owned = static_cast<IRRewriter *>(unwrap(rewriter));
+  delete owned;
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet and FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
@@ -54,6 +299,10 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
}
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
assert(module.ptr && "unexpected null module");
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index ad312764b3e06..e795672bce5d1 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -89,6 +89,15 @@ _add_capi_test_executable(mlir-capi-quant-test
MLIRCAPIQuant
)
+_add_capi_test_executable(mlir-capi-rewrite-test
+  rewrite.c
+  LINK_LIBS PRIVATE
+    MLIRCAPIIR
+    MLIRCAPIRegisterEverything
+    MLIRCAPITransforms
+)
+
_add_capi_test_executable(mlir-capi-transform-test
transform.c
LINK_LIBS PRIVATE
diff --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c
new file mode 100644
index 0000000000000..a8b593eabb781
--- /dev/null
+++ b/mlir/test/CAPI/rewrite.c
@@ -0,0 +1,551 @@
+//===- rewrite.c - Test of the rewriting C API ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// RUN: mlir-capi-rewrite-test 2>&1 | FileCheck %s
+
+#include "mlir-c/Rewrite.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
+
+#include <assert.h>
+#include <stdio.h>
+
+// Creates a detached operation called `name` that produces a single `index`
+// result; the tests then place it with mlirRewriterBaseInsert.
+MlirOperation createOperationWithName(MlirContext ctx, const char *name) {
+  MlirType resultType = mlirIndexTypeGet(ctx);
+  MlirOperationState state = mlirOperationStateGet(
+      mlirStringRefCreateFromCString(name), mlirLocationUnknownGet(ctx));
+  mlirOperationStateAddResults(&state, 1, &resultType);
+  return mlirOperationCreate(&state);
+}
+
+// Checks every insertion-point setter of the rewriter (before/after an op,
+// after a value, start/end of a block) and both block getters. Each op is
+// created detached by createOperationWithName and then placed with
+// mlirRewriterBaseInsert; the module dump below checks the resulting order.
+void testInsertionPoint(MlirContext ctx) {
+ // CHECK-LABEL: @testInsertionPoint
+ fprintf(stderr, "@testInsertionPoint\n");
+
+ const char *moduleString = "\"dialect.op1\"() : () -> ()\n";
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+ MlirOperation op = mlirModuleGetOperation(module);
+ MlirBlock body = mlirModuleGetBody(module);
+ MlirOperation op1 = mlirBlockGetFirstOperation(body);
+
+ // IRRewriter create; it is released with mlirIRRewriterDestroy at the end.
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+ // Insert before op
+ mlirRewriterBaseSetInsertionPointBefore(rewriter, op1);
+ MlirOperation op2 = createOperationWithName(ctx, "dialect.op2");
+ mlirRewriterBaseInsert(rewriter, op2);
+
+ // Insert after op
+ mlirRewriterBaseSetInsertionPointAfter(rewriter, op2);
+ MlirOperation op3 = createOperationWithName(ctx, "dialect.op3");
+ mlirRewriterBaseInsert(rewriter, op3);
+ MlirValue op3Res = mlirOperationGetResult(op3, 0);
+
+ // Insert after value: op3Res is defined by op3, so op4 lands right after op3
+ mlirRewriterBaseSetInsertionPointAfterValue(rewriter, op3Res);
+ MlirOperation op4 = createOperationWithName(ctx, "dialect.op4");
+ mlirRewriterBaseInsert(rewriter, op4);
+
+ // Insert at beginning of block
+ mlirRewriterBaseSetInsertionPointToStart(rewriter, body);
+ MlirOperation op5 = createOperationWithName(ctx, "dialect.op5");
+ mlirRewriterBaseInsert(rewriter, op5);
+
+ // Insert at end of block
+ mlirRewriterBaseSetInsertionPointToEnd(rewriter, body);
+ MlirOperation op6 = createOperationWithName(ctx, "dialect.op6");
+ mlirRewriterBaseInsert(rewriter, op6);
+
+ // Get insertion blocks: the last insertion point was set to the end of the
+ // module body, so both getters must return that block.
+ MlirBlock block1 = mlirRewriterBaseGetBlock(rewriter);
+ MlirBlock block2 = mlirRewriterBaseGetInsertionBlock(rewriter);
+ assert(body.ptr == block1.ptr);
+ assert(body.ptr == block2.ptr);
+
+ // clang-format off
+ // CHECK-NEXT: module {
+ // CHECK-NEXT: %{{.*}} = "dialect.op5"() : () -> index
+ // CHECK-NEXT: %{{.*}} = "dialect.op2"() : () -> index
+ // CHECK-NEXT: %{{.*}} = "dialect.op3"() : () -> index
+ // CHECK-NEXT: %{{.*}} = "dialect.op4"() : () -> index
+ // CHECK-NEXT: "dialect.op1"() : () -> ()
+ // CHECK-NEXT: %{{.*}} = "dialect.op6"() : () -> index
+ // CHECK-NEXT: }
+ // clang-format on
+ mlirOperationDump(op);
+
+ // The inserted ops are now part of the module body and are freed with it.
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+// Checks block creation and the three cloning helpers on a module holding two
+// ops, each with a single-block region.
+void testCreateBlock(MlirContext ctx) {
+ // CHECK-LABEL: @testCreateBlock
+ fprintf(stderr, "@testCreateBlock\n");
+
+ const char *moduleString = "\"dialect.op1\"() ({^bb0:}) : () -> ()\n"
+ "\"dialect.op2\"() ({^bb0:}) : () -> ()\n";
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+ MlirOperation op = mlirModuleGetOperation(module);
+ MlirBlock body = mlirModuleGetBody(module);
+
+ MlirOperation op1 = mlirBlockGetFirstOperation(body);
+ MlirRegion region1 = mlirOperationGetRegion(op1, 0);
+ MlirBlock block1 = mlirRegionGetFirstBlock(region1);
+
+ MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+ MlirRegion region2 = mlirOperationGetRegion(op2, 0);
+ MlirBlock block2 = mlirRegionGetFirstBlock(region2);
+
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+ // Create block before: adds a block with one `index` argument in front of
+ // block1. This also moves the insertion point to the end of the new block.
+ MlirType indexType = mlirIndexTypeGet(ctx);
+ MlirLocation unknown = mlirLocationUnknownGet(ctx);
+ mlirRewriterBaseCreateBlockBefore(rewriter, block1, 1, &indexType, &unknown);
+
+ // Move the insertion point back to the end of the module body so that the
+ // clones below are appended there.
+ mlirRewriterBaseSetInsertionPointToEnd(rewriter, body);
+
+ // Clone operation (op1 now has two blocks, and so does the clone)
+ mlirRewriterBaseClone(rewriter, op1);
+
+ // Clone without regions: the clone keeps a region, but it is empty
+ mlirRewriterBaseCloneWithoutRegions(rewriter, op1);
+
+ // Clone region before: copies op1's two blocks into op2's region, in front
+ // of block2, leaving op1 unchanged
+ mlirRewriterBaseCloneRegionBefore(rewriter, region1, block2);
+
+ mlirOperationDump(op);
+ // clang-format off
+ // CHECK-NEXT: "builtin.module"() ({
+ // CHECK-NEXT: "dialect.op1"() ({
+ // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
+ // CHECK-NEXT: ^{{.*}}:
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op2"() ({
+ // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
+ // CHECK-NEXT: ^{{.*}}:
+ // CHECK-NEXT: ^{{.*}}:
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op1"() ({
+ // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
+ // CHECK-NEXT: ^{{.*}}:
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op1"() ({
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: }) : () -> ()
+ // clang-format on
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Exercises the three inlining entry points of the rewriter C API:
+/// - mlirRewriterBaseInlineRegionBefore moves every block of op1's region in
+///   front of the single empty block of op2's region;
+/// - mlirRewriterBaseInlineBlockBefore splices op3's entry block in front of
+///   op3_in3, substituting its block argument with the result of op3_in2;
+/// - mlirRewriterBaseMergeBlocks appends op4's second block to its first one,
+///   substituting its block argument with the result of op4_in1.
+void testInlineRegionBlock(MlirContext ctx) {
+ // CHECK-LABEL: @testInlineRegionBlock
+ fprintf(stderr, "@testInlineRegionBlock\n");
+
+ const char *moduleString =
+ // op1: two-block region, the entry block branching to ^bb1.
+ "\"dialect.op1\"() ({\n"
+ " ^bb0(%arg0: index):\n"
+ " \"dialect.op1_in1\"(%arg0) [^bb1] : (index) -> ()\n"
+ " ^bb1():\n"
+ " \"dialect.op1_in2\"() : () -> ()\n"
+ "}) : () -> ()\n"
+ // op2: a single empty block that op1's region is inlined in front of.
+ "\"dialect.op2\"() ({^bb0:}) : () -> ()\n"
+ // op3: the entry block (with one argument) gets inlined into ^bb1.
+ "\"dialect.op3\"() ({\n"
+ " ^bb0(%arg0: index):\n"
+ " \"dialect.op3_in1\"(%arg0) : (index) -> ()\n"
+ " ^bb1():\n"
+ " %x = \"dialect.op3_in2\"() : () -> index\n"
+ " %y = \"dialect.op3_in3\"() : () -> index\n"
+ "}) : () -> ()\n"
+ // op4: ^bb1 (with one argument) gets merged into the end of ^bb0.
+ "\"dialect.op4\"() ({\n"
+ " ^bb0():\n"
+ " \"dialect.op4_in1\"() : () -> index\n"
+ " ^bb1(%arg0: index):\n"
+ " \"dialect.op4_in2\"(%arg0) : (index) -> ()\n"
+ "}) : () -> ()\n";
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+ MlirOperation op = mlirModuleGetOperation(module);
+ MlirBlock body = mlirModuleGetBody(module);
+
+ // All handles are captured up front: the rewrites below move blocks between
+ // regions and remove the inlined/merged source blocks.
+ MlirOperation op1 = mlirBlockGetFirstOperation(body);
+ MlirRegion region1 = mlirOperationGetRegion(op1, 0);
+
+ MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+ MlirRegion region2 = mlirOperationGetRegion(op2, 0);
+ MlirBlock block2 = mlirRegionGetFirstBlock(region2);
+
+ MlirOperation op3 = mlirOperationGetNextInBlock(op2);
+ MlirRegion region3 = mlirOperationGetRegion(op3, 0);
+ MlirBlock block3_1 = mlirRegionGetFirstBlock(region3);
+ MlirBlock block3_2 = mlirBlockGetNextInRegion(block3_1);
+ MlirOperation op3_in2 = mlirBlockGetFirstOperation(block3_2);
+ MlirValue op3_in2_res = mlirOperationGetResult(op3_in2, 0);
+ MlirOperation op3_in3 = mlirOperationGetNextInBlock(op3_in2);
+
+ MlirOperation op4 = mlirOperationGetNextInBlock(op3);
+ MlirRegion region4 = mlirOperationGetRegion(op4, 0);
+ MlirBlock block4_1 = mlirRegionGetFirstBlock(region4);
+ MlirOperation op4_in1 = mlirBlockGetFirstOperation(block4_1);
+ MlirValue op4_in1_res = mlirOperationGetResult(op4_in1, 0);
+ MlirBlock block4_2 = mlirBlockGetNextInRegion(block4_1);
+
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+ // Region inlining, block inlining (1 replacement value for the single block
+ // argument), then block merging (likewise 1 replacement value).
+ mlirRewriterBaseInlineRegionBefore(rewriter, region1, block2);
+ mlirRewriterBaseInlineBlockBefore(rewriter, block3_1, op3_in3, 1,
+ &op3_in2_res);
+ mlirRewriterBaseMergeBlocks(rewriter, block4_2, block4_1, 1, &op4_in1_res);
+
+ mlirOperationDump(op);
+ // clang-format off
+ // CHECK-NEXT: "builtin.module"() ({
+ // CHECK-NEXT: "dialect.op1"() ({
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op2"() ({
+ // CHECK-NEXT: ^{{.*}}(%{{.*}}: index):
+ // CHECK-NEXT: "dialect.op1_in1"(%{{.*}})[^[[bb:.*]]] : (index) -> ()
+ // CHECK-NEXT: ^[[bb]]:
+ // CHECK-NEXT: "dialect.op1_in2"() : () -> ()
+ // CHECK-NEXT: ^{{.*}}: // no predecessors
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op3"() ({
+ // CHECK-NEXT: %{{.*}} = "dialect.op3_in2"() : () -> index
+ // CHECK-NEXT: "dialect.op3_in1"(%{{.*}}) : (index) -> ()
+ // CHECK-NEXT: %{{.*}} = "dialect.op3_in3"() : () -> index
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op4"() ({
+ // CHECK-NEXT: %{{.*}} = "dialect.op4_in1"() : () -> index
+ // CHECK-NEXT: "dialect.op4_in2"(%{{.*}}) : (index) -> ()
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: }) : () -> ()
+ // clang-format on
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Covers both op-replacement helpers: op1 is replaced by two existing values
+/// (results #0 and #2 of "dialect.create_values"), and op2 is replaced by the
+/// results of op3.
+void testReplaceOp(MlirContext ctx) {
+ // CHECK-LABEL: @testReplaceOp
+ fprintf(stderr, "@testReplaceOp\n");
+
+ const char *source =
+ "%x, %y, %z = \"dialect.create_values\"() : () -> (index, index, index)\n"
+ "%x_1, %y_1 = \"dialect.op1\"() : () -> (index, index)\n"
+ "\"dialect.use_op1\"(%x_1, %y_1) : (index, index) -> ()\n"
+ "%x_2, %y_2 = \"dialect.op2\"() : () -> (index, index)\n"
+ "%x_3, %y_3 = \"dialect.op3\"() : () -> (index, index)\n"
+ "\"dialect.use_op2\"(%x_2, %y_2) : (index, index) -> ()\n";
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(source));
+ MlirOperation moduleOp = mlirModuleGetOperation(module);
+
+ // Top-level ops, collected in program order.
+ MlirOperation producer = mlirBlockGetFirstOperation(mlirModuleGetBody(module));
+ MlirOperation replacedByValues = mlirOperationGetNextInBlock(producer);
+ MlirOperation firstUser = mlirOperationGetNextInBlock(replacedByValues);
+ MlirOperation replacedByOp = mlirOperationGetNextInBlock(firstUser);
+ MlirOperation replacement = mlirOperationGetNextInBlock(replacedByOp);
+
+ // The producer's first and third results take over op1's two results.
+ MlirValue newValues[2] = {mlirOperationGetResult(producer, 0),
+ mlirOperationGetResult(producer, 2)};
+
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+ mlirRewriterBaseReplaceOpWithValues(rewriter, replacedByValues, 2, newValues);
+ mlirRewriterBaseReplaceOpWithOperation(rewriter, replacedByOp, replacement);
+
+ mlirOperationDump(moduleOp);
+ // clang-format off
+ // CHECK-NEXT: module {
+ // CHECK-NEXT: %[[res:.*]]:3 = "dialect.create_values"() : () -> (index, index, index)
+ // CHECK-NEXT: "dialect.use_op1"(%[[res]]#0, %[[res]]#2) : (index, index) -> ()
+ // CHECK-NEXT: %[[res2:.*]]:2 = "dialect.op3"() : () -> (index, index)
+ // CHECK-NEXT: "dialect.use_op2"(%[[res2]]#0, %[[res2]]#1) : (index, index) -> ()
+ // CHECK-NEXT: }
+ // clang-format on
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Covers erasure through the rewriter: mlirRewriterBaseEraseOp removes the
+/// first top-level op, and mlirRewriterBaseEraseBlock removes the middle block
+/// of op2's region, leaving the other two blocks in place.
+void testErase(MlirContext ctx) {
+ // CHECK-LABEL: @testErase
+ fprintf(stderr, "@testErase\n");
+
+ // Every line of IR is newline-terminated, like the other tests in this file;
+ // previously each nested op was fused with the following block label.
+ const char *moduleString = "\"dialect.op_to_erase\"() : () -> ()\n"
+ "\"dialect.op2\"() ({\n"
+ "^bb0():\n"
+ " \"dialect.op2_nested\"() : () -> ()\n"
+ "^block_to_erase():\n"
+ " \"dialect.op2_nested\"() : () -> ()\n"
+ "^bb1():\n"
+ " \"dialect.op2_nested\"() : () -> ()\n"
+ "}) : () -> ()\n";
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+ // Report a parse failure instead of walking a null module below.
+ if (mlirModuleIsNull(module)) {
+   fprintf(stderr, "error: failed to parse module in testErase\n");
+   return;
+ }
+ MlirOperation op = mlirModuleGetOperation(module);
+ MlirBlock body = mlirModuleGetBody(module);
+
+ // get a handle to all operations/values
+ MlirOperation opToErase = mlirBlockGetFirstOperation(body);
+ MlirOperation op2 = mlirOperationGetNextInBlock(opToErase);
+ MlirRegion op2Region = mlirOperationGetRegion(op2, 0);
+ MlirBlock bb0 = mlirRegionGetFirstBlock(op2Region);
+ MlirBlock blockToErase = mlirBlockGetNextInRegion(bb0);
+
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+ mlirRewriterBaseEraseOp(rewriter, opToErase);
+ mlirRewriterBaseEraseBlock(rewriter, blockToErase);
+
+ mlirOperationDump(op);
+ // CHECK-NEXT: module {
+ // CHECK-NEXT: "dialect.op2"() ({
+ // CHECK-NEXT: "dialect.op2_nested"() : () -> ()
+ // CHECK-NEXT: ^{{.*}}:
+ // CHECK-NEXT: "dialect.op2_nested"() : () -> ()
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: }
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Covers the move helpers: op3 is moved before op1, op4 is moved after op1,
+/// and the second block of op2's region is moved in front of the first one.
+void testMove(MlirContext ctx) {
+ // CHECK-LABEL: @testMove
+ fprintf(stderr, "@testMove\n");
+
+ // Every line of IR is newline-terminated, like the other tests in this file;
+ // previously each nested op was fused with the following block label.
+ const char *moduleString = "\"dialect.op1\"() : () -> ()\n"
+ "\"dialect.op2\"() ({\n"
+ "^bb0(%arg0: index):\n"
+ " \"dialect.op2_1\"(%arg0) : (index) -> ()\n"
+ "^bb1(%arg1: index):\n"
+ " \"dialect.op2_2\"(%arg1) : (index) -> ()\n"
+ "}) : () -> ()\n"
+ "\"dialect.op3\"() : () -> ()\n"
+ "\"dialect.op4\"() : () -> ()\n";
+
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+ // Report a parse failure instead of walking a null module below.
+ if (mlirModuleIsNull(module)) {
+   fprintf(stderr, "error: failed to parse module in testMove\n");
+   return;
+ }
+ MlirOperation op = mlirModuleGetOperation(module);
+ MlirBlock body = mlirModuleGetBody(module);
+
+ // get a handle to all operations/values
+ MlirOperation op1 = mlirBlockGetFirstOperation(body);
+ MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+ MlirOperation op3 = mlirOperationGetNextInBlock(op2);
+ MlirOperation op4 = mlirOperationGetNextInBlock(op3);
+
+ MlirRegion region2 = mlirOperationGetRegion(op2, 0);
+ MlirBlock block0 = mlirRegionGetFirstBlock(region2);
+ MlirBlock block1 = mlirBlockGetNextInRegion(block0);
+
+ // Top-level order becomes op3, op1, op4, op2; op2's blocks are swapped.
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+ mlirRewriterBaseMoveOpBefore(rewriter, op3, op1);
+ mlirRewriterBaseMoveOpAfter(rewriter, op4, op1);
+ mlirRewriterBaseMoveBlockBefore(rewriter, block1, block0);
+
+ mlirOperationDump(op);
+ // CHECK-NEXT: module {
+ // CHECK-NEXT: "dialect.op3"() : () -> ()
+ // CHECK-NEXT: "dialect.op1"() : () -> ()
+ // CHECK-NEXT: "dialect.op4"() : () -> ()
+ // CHECK-NEXT: "dialect.op2"() ({
+ // CHECK-NEXT: ^{{.*}}(%[[arg0:.*]]: index):
+ // CHECK-NEXT: "dialect.op2_2"(%[[arg0]]) : (index) -> ()
+ // CHECK-NEXT: ^{{.*}}(%[[arg1:.*]]: index): // no predecessors
+ // CHECK-NEXT: "dialect.op2_1"(%[[arg1]]) : (index) -> ()
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: }
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Covers the in-place modification protocol: a start/cancel pair that leaves
+/// op1 untouched, and a start/finalize pair around an operand update of op2.
+void testOpModification(MlirContext ctx) {
+ // CHECK-LABEL: @testOpModification
+ fprintf(stderr, "@testOpModification\n");
+
+ const char *source =
+ "%x, %y = \"dialect.op1\"() : () -> (index, index)\n"
+ "\"dialect.op2\"(%x) : (index) -> ()\n";
+
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(source));
+ MlirOperation moduleOp = mlirModuleGetOperation(module);
+ MlirOperation producer = mlirBlockGetFirstOperation(mlirModuleGetBody(module));
+ MlirOperation consumer = mlirOperationGetNextInBlock(producer);
+ MlirValue secondResult = mlirOperationGetResult(producer, 1);
+
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+ // Aborted modification: op1 is announced as modified, then rolled back.
+ mlirRewriterBaseStartOpModification(rewriter, producer);
+ mlirRewriterBaseCancelOpModification(rewriter, producer);
+
+ // Committed modification: op2 switches its operand from %x to %y.
+ mlirRewriterBaseStartOpModification(rewriter, consumer);
+ mlirOperationSetOperand(consumer, 0, secondResult);
+ mlirRewriterBaseFinalizeOpModification(rewriter, consumer);
+
+ mlirOperationDump(moduleOp);
+ // CHECK-NEXT: module {
+ // CHECK-NEXT: %[[xy:.*]]:2 = "dialect.op1"() : () -> (index, index)
+ // CHECK-NEXT: "dialect.op2"(%[[xy]]#1) : (index) -> ()
+ // CHECK-NEXT: }
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Covers the use-replacement helpers of the rewriter C API. The input module
+/// has five independent sections, one per helper (or pair of helpers); each
+/// section's comment in the IR string names what is rewritten there.
+void testReplaceUses(MlirContext ctx) {
+ // CHECK-LABEL: @testReplaceUses
+ fprintf(stderr, "@testReplaceUses\n");
+
+ const char *moduleString =
+ // 1. Replace values with values: op1's results are redirected to op2's.
+ "%x1, %y1, %z1 = \"dialect.op1\"() : () -> (index, index, index)\n"
+ "%x2, %y2, %z2 = \"dialect.op2\"() : () -> (index, index, index)\n"
+ "\"dialect.op1_uses\"(%x1, %y1, %z1) : (index, index, index) -> ()\n"
+ // 2. Replace op3 with the value %x4 (op3 disappears).
+ "%x3 = \"dialect.op3\"() : () -> index\n"
+ "%x4 = \"dialect.op4\"() : () -> index\n"
+ "\"dialect.op3_uses\"(%x3) : (index) -> ()\n"
+ // 3. Replace op5 with op6's results (op5 disappears).
+ "%x5 = \"dialect.op5\"() : () -> index\n"
+ "%x6 = \"dialect.op6\"() : () -> index\n"
+ "\"dialect.op5_uses\"(%x5) : (index) -> ()\n"
+ // 4. Replace op7's uses only inside op9's block; the use after op9 is kept.
+ "%x7 = \"dialect.op7\"() : () -> index\n"
+ "%x8 = \"dialect.op8\"() : () -> index\n"
+ "\"dialect.op9\"() ({\n"
+ "^bb0:\n"
+ " \"dialect.op7_uses\"(%x7) : (index) -> ()\n"
+ "}): () -> ()\n"
+ "\"dialect.op7_uses\"(%x7) : (index) -> ()\n"
+ // 5. Replace %x10 with %x11 everywhere except in the first op10_uses.
+ "%x10 = \"dialect.op10\"() : () -> index\n"
+ "%x11 = \"dialect.op11\"() : () -> index\n"
+ "\"dialect.op10_uses\"(%x10) : (index) -> ()\n"
+ "\"dialect.op10_uses\"(%x10) : (index) -> ()\n";
+
+ MlirModule module =
+ mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+ MlirOperation op = mlirModuleGetOperation(module);
+ MlirBlock body = mlirModuleGetBody(module);
+
+ // Walk the module body in program order; every top-level op is visited so
+ // the handles line up with the sections of the IR string above.
+ MlirOperation op1 = mlirBlockGetFirstOperation(body);
+ MlirValue x1 = mlirOperationGetResult(op1, 0);
+ MlirValue y1 = mlirOperationGetResult(op1, 1);
+ MlirValue z1 = mlirOperationGetResult(op1, 2);
+ MlirOperation op2 = mlirOperationGetNextInBlock(op1);
+ MlirValue x2 = mlirOperationGetResult(op2, 0);
+ MlirValue y2 = mlirOperationGetResult(op2, 1);
+ MlirValue z2 = mlirOperationGetResult(op2, 2);
+ MlirOperation op1Uses = mlirOperationGetNextInBlock(op2);
+
+ MlirOperation op3 = mlirOperationGetNextInBlock(op1Uses);
+ MlirOperation op4 = mlirOperationGetNextInBlock(op3);
+ MlirValue x4 = mlirOperationGetResult(op4, 0);
+ MlirOperation op3Uses = mlirOperationGetNextInBlock(op4);
+
+ MlirOperation op5 = mlirOperationGetNextInBlock(op3Uses);
+ MlirOperation op6 = mlirOperationGetNextInBlock(op5);
+ MlirOperation op5Uses = mlirOperationGetNextInBlock(op6);
+
+ MlirOperation op7 = mlirOperationGetNextInBlock(op5Uses);
+ MlirOperation op8 = mlirOperationGetNextInBlock(op7);
+ MlirValue x8 = mlirOperationGetResult(op8, 0);
+ MlirOperation op9 = mlirOperationGetNextInBlock(op8);
+ MlirRegion region9 = mlirOperationGetRegion(op9, 0);
+ MlirBlock block9 = mlirRegionGetFirstBlock(region9);
+ MlirOperation op7Uses = mlirOperationGetNextInBlock(op9);
+
+ MlirOperation op10 = mlirOperationGetNextInBlock(op7Uses);
+ MlirValue x10 = mlirOperationGetResult(op10, 0);
+ MlirOperation op11 = mlirOperationGetNextInBlock(op10);
+ MlirValue x11 = mlirOperationGetResult(op11, 0);
+ MlirOperation op10Uses1 = mlirOperationGetNextInBlock(op11);
+
+ MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx);
+
+ // Section 1: a single value (x1 -> x2), then a two-value range
+ // (y1, z1 -> y2, z2).
+ mlirRewriterBaseReplaceAllUsesWith(rewriter, x1, x2);
+ MlirValue y1z1[2] = {y1, z1};
+ MlirValue y2z2[2] = {y2, z2};
+ mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter, 2, y1z1, y2z2);
+
+ // Section 2: replace op3 with values
+ mlirRewriterBaseReplaceOpWithValues(rewriter, op3, 1, &x4);
+
+ // Section 3: replace op5 with op6
+ mlirRewriterBaseReplaceOpWithOperation(rewriter, op5, op6);
+
+ // Section 4: rewrite op7's uses to %x8, but only those inside block9
+ mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter, op7, 1, &x8, block9);
+
+ // Section 5: rewrite uses of %x10 to %x11, except in op10Uses1
+ mlirRewriterBaseReplaceAllUsesExcept(rewriter, x10, x11, op10Uses1);
+
+ mlirOperationDump(op);
+ // clang-format off
+ // CHECK-NEXT: module {
+ // CHECK-NEXT: %{{.*}}:3 = "dialect.op1"() : () -> (index, index, index)
+ // CHECK-NEXT: %[[res2:.*]]:3 = "dialect.op2"() : () -> (index, index, index)
+ // CHECK-NEXT: "dialect.op1_uses"(%[[res2]]#0, %[[res2]]#1, %[[res2]]#2) : (index, index, index) -> ()
+ // CHECK-NEXT: %[[res4:.*]] = "dialect.op4"() : () -> index
+ // CHECK-NEXT: "dialect.op3_uses"(%[[res4]]) : (index) -> ()
+ // CHECK-NEXT: %[[res6:.*]] = "dialect.op6"() : () -> index
+ // CHECK-NEXT: "dialect.op5_uses"(%[[res6]]) : (index) -> ()
+ // CHECK-NEXT: %[[res7:.*]] = "dialect.op7"() : () -> index
+ // CHECK-NEXT: %[[res8:.*]] = "dialect.op8"() : () -> index
+ // CHECK-NEXT: "dialect.op9"() ({
+ // CHECK-NEXT: "dialect.op7_uses"(%[[res8]]) : (index) -> ()
+ // CHECK-NEXT: }) : () -> ()
+ // CHECK-NEXT: "dialect.op7_uses"(%[[res7]]) : (index) -> ()
+ // CHECK-NEXT: %[[res10:.*]] = "dialect.op10"() : () -> index
+ // CHECK-NEXT: %[[res11:.*]] = "dialect.op11"() : () -> index
+ // CHECK-NEXT: "dialect.op10_uses"(%[[res10]]) : (index) -> ()
+ // CHECK-NEXT: "dialect.op10_uses"(%[[res11]]) : (index) -> ()
+ // CHECK-NEXT: }
+ // clang-format on
+
+ mlirIRRewriterDestroy(rewriter);
+ mlirModuleDestroy(module);
+}
+
+/// Builds the context shared by every test: unregistered dialects are allowed
+/// so the "dialect.*" ops parse, and the builtin dialect is loaded.
+static MlirContext createRewriteTestContext(void) {
+ MlirContext context = mlirContextCreate();
+ mlirContextSetAllowUnregisteredDialects(context, true);
+ mlirContextGetOrLoadDialect(context,
+   mlirStringRefCreateFromCString("builtin"));
+ return context;
+}
+
+int main(void) {
+ MlirContext context = createRewriteTestContext();
+
+ // The order matters: the expected output in this file is matched in
+ // sequence, one labelled section per test.
+ testInsertionPoint(context);
+ testCreateBlock(context);
+ testInlineRegionBlock(context);
+ testReplaceOp(context);
+ testErase(context);
+ testMove(context);
+ testOpModification(context);
+ testReplaceUses(context);
+
+ mlirContextDestroy(context);
+ return 0;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 45009a78aa49f..df95e5db11f1e 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -105,6 +105,7 @@ set(MLIR_TEST_DEPENDS
mlir-capi-llvm-test
mlir-capi-pass-test
mlir-capi-quant-test
+ mlir-capi-rewrite-test
mlir-capi-sparse-tensor-test
mlir-capi-transform-test
mlir-capi-transform-interpreter-test
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 1175f87877f9e..98d0ddd9a2be1 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -106,6 +106,7 @@ def add_runtime(name):
"mlir-capi-pass-test",
"mlir-capi-pdl-test",
"mlir-capi-quant-test",
+ "mlir-capi-rewrite-test",
"mlir-capi-sparse-tensor-test",
"mlir-capi-transform-test",
"mlir-capi-transform-interpreter-test",
More information about the Mlir-commits
mailing list