[Mlir-commits] [mlir] [mlir][c] Expose AsmState. (PR #66693)
Jacques Pienaar
llvmlistbot at llvm.org
Mon Sep 18 17:51:01 PDT 2023
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/66693
>From e331c033bda32f81faf663dd4531efe31fe7ef51 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Mon, 18 Sep 2023 17:50:30 -0700
Subject: [PATCH] [mlir][c] Expose AsmState.
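Enables C API clients to create an AsmState once and reuse it across printing calls. AsmState is not yet exposed as a Python API; the Python bindings are only updated for the changed mlirValuePrintAsOperand signature.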
---
mlir/include/mlir-c/IR.h | 26 ++++++++++++++-
mlir/include/mlir/CAPI/IR.h | 1 +
mlir/lib/Bindings/Python/IRCore.cpp | 4 ++-
mlir/lib/CAPI/IR/IR.cpp | 49 +++++++++++++++++++++++++++--
mlir/test/CAPI/ir.c | 7 +++++
5 files changed, 83 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b5c6a3094bc67df..68eccab6dbacaef 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -48,6 +48,7 @@ extern "C" {
}; \
typedef struct name name
+/// Opaque C handle to an mlir::AsmState.
+DEFINE_C_API_STRUCT(MlirAsmState, void);
DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
@@ -383,6 +384,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
MLIR_CAPI_EXPORTED void
mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
+//===----------------------------------------------------------------------===//
+// AsmState API.
+// An MlirAsmState wraps a heap-allocated mlir::AsmState so that printing
+// state can be created once and reused across printing calls such as
+// mlirValuePrintAsOperand().
+//===----------------------------------------------------------------------===//
+
+/// Creates a new AsmState rooted at `op`. As with mlir::AsmState, the IR must
+/// not be mutated while this state is in use.
+/// Must be freed with a call to mlirAsmStateDestroy().
+// TODO: This should be expanded to handle location & resource map.
+MLIR_CAPI_EXPORTED MlirAsmState
+mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags);
+
+/// Creates a new AsmState for printing `value`. The state is rooted at an
+/// ancestor of the operation that defines `value` (or owns its block): the
+/// nearest isolated-from-above ancestor when `flags` requests local scope,
+/// otherwise the top-most ancestor. If `value` is a block argument whose block
+/// has no parent operation, an error is emitted and a null state is returned.
+/// Must be freed with a call to mlirAsmStateDestroy().
+// TODO: This should be expanded to handle location & resource map.
+MLIR_CAPI_EXPORTED MlirAsmState
+mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags);
+
+/// Destroys an AsmState created with mlirAsmStateCreateForOperation() or
+/// mlirAsmStateCreateForValue().
+MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state);
+
//===----------------------------------------------------------------------===//
// Op Printing flags API.
// While many of these are simple settings that could be represented in a
@@ -815,7 +839,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
/// Prints a value as an operand (i.e., the ValueID).
+/// `state` supplies the printing state (e.g. one created with
+/// mlirAsmStateCreateForValue()); it is not destroyed by this call and remains
+/// owned by the caller.
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
- MlirOpPrintingFlags flags,
+ MlirAsmState state,
MlirStringCallback callback,
void *userData);
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index b8ccec896c27ba5..1836cb0acb67e7e 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -21,6 +21,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
+// wrap()/unwrap() between the opaque MlirAsmState handle and mlir::AsmState *.
+DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b06937bc285e206..af713547cccbb27 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3430,9 +3430,11 @@ void mlir::python::populateIRCore(py::module &m) {
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
- mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(),
+ MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
+ mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
printAccum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
+ mlirAsmStateDestroy(state);
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ef234a912490eea..7f5c2aaee67382b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -138,6 +138,51 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
delete unwrap(registry);
}
+//===----------------------------------------------------------------------===//
+// AsmState API.
+//===----------------------------------------------------------------------===//
+
+MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
+                                            MlirOpPrintingFlags flags) {
+  // The AsmState is heap-allocated; the caller owns it and must release it
+  // with mlirAsmStateDestroy().
+  const auto &printFlags = *unwrap(flags);
+  Operation *root = unwrap(op);
+  return wrap(new AsmState(root, printFlags));
+}
+
+/// Returns the operation from which printing state for `op` should be built:
+/// the nearest ancestor (including `op` itself) that is isolated from above
+/// when `shouldUseLocalScope` is set, otherwise the top-most ancestor.
+static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
+  Operation *current = op;
+  while (true) {
+    // With local scope, the first isolated-from-above op bounds the scope.
+    bool atScopeBoundary =
+        shouldUseLocalScope &&
+        current->hasTrait<OpTrait::IsIsolatedFromAbove>();
+    if (atScopeBoundary)
+      return current;
+
+    // Otherwise climb one level; stop once there is no parent operation.
+    Operation *next = current->getParentOp();
+    if (next == nullptr)
+      return current;
+    current = next;
+  }
+}
+
+/// Creates an AsmState for printing `value`, rooted at the operation chosen by
+/// findParent() starting from the op that defines `value` (for op results) or
+/// the parent op of its owning block (for block arguments). Emits an error at
+/// the value's location and returns a null state if no such op exists.
+MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
+                                        MlirOpPrintingFlags flags) {
+  mlir::Value val = unwrap(value);
+
+  Operation *anchor = nullptr;
+  if (auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
+    anchor = blockArg.getOwner()->getParentOp();
+    if (anchor == nullptr) {
+      emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
+      return {nullptr};
+    }
+  } else {
+    anchor = llvm::cast<OpResult>(val).getOwner();
+  }
+
+  // Flags are only dereferenced once an anchor op is known.
+  const auto &printFlags = *unwrap(flags);
+  Operation *root = findParent(anchor, printFlags.shouldUseLocalScope());
+  return wrap(new AsmState(root, printFlags));
+}
+
+/// Destroys an AsmState created with mlirAsmStateCreateForOperation() or
+/// mlirAsmStateCreateForValue(). A null state is accepted: deleting a null
+/// pointer is a no-op.
+void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); }
+
//===----------------------------------------------------------------------===//
// Printing flags API.
//===----------------------------------------------------------------------===//
@@ -840,11 +885,11 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
unwrap(value).print(stream);
}
-void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags,
+void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
MlirStringCallback callback, void *userData) {
detail::CallbackOstream stream(callback, userData);
Value cppValue = unwrap(value);
- cppValue.printAsOperand(stream, *unwrap(flags));
+ cppValue.printAsOperand(stream, *unwrap(state));
}
MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5725d05a3e132f7..2a40d97e7e651fe 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -487,6 +487,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
// clang-format on
+  // Print the value using an AsmState built once for the enclosing operation
+  // (mlirAsmStateCreate does not exist; the API is mlirAsmStateCreateFor*).
+  MlirAsmState state = mlirAsmStateCreateForOperation(parentOperation, flags);
+  fprintf(stderr, "With state: |");
+  mlirValuePrintAsOperand(value, state, printToStderr, NULL);
+  // CHECK: With state: |%0|
+  fprintf(stderr, "|\n");
+  mlirAsmStateDestroy(state);
+
mlirOpPrintingFlagsDestroy(flags);
}
More information about the Mlir-commits
mailing list