[Mlir-commits] [mlir] [mlir][c] Expose AsmState. (PR #66693)

Jacques Pienaar llvmlistbot at llvm.org
Mon Sep 18 17:45:04 PDT 2023


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/66693

>From 49b47012814b06986be20bb53368c2ffcf23d641 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Mon, 18 Sep 2023 17:43:50 -0700
Subject: [PATCH] [mlir][c] Expose AsmState.

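Enable usage where capturing and reusing an AsmState is beneficial. The Python bindings now create an AsmState internally for Value.get_name, but AsmState itself is not exposed to Python yet.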
---
 mlir/include/mlir-c/IR.h            |  26 ++++-
 mlir/include/mlir/CAPI/IR.h         |   1 +
 mlir/lib/Bindings/Python/IRCore.cpp |   4 +-
 mlir/lib/CAPI/IR/IR.cpp             |  49 +++++++-
 mlir/test/CAPI/ir.c                 |   7 ++
 mlir/test/python/ir/value.py        | 175 ++++++++++++++--------------
 6 files changed, 171 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b5c6a3094bc67df..68eccab6dbacaef 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -48,6 +48,7 @@ extern "C" {
   };                                                                           \
   typedef struct name name
 
+DEFINE_C_API_STRUCT(MlirAsmState, void);
 DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
@@ -383,6 +384,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
 MLIR_CAPI_EXPORTED void
 mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
 
+//===----------------------------------------------------------------------===//
+// AsmState API.
+// An AsmState is created for an operation (or a value's enclosing op) and
+// passed to printing functions such as mlirValuePrintAsOperand. It is an
+// opaque heap-allocated object that must be explicitly destroyed.
+//===----------------------------------------------------------------------===//
+
+/// Creates new AsmState for the given operation. As with AsmState, the IR
+/// must not be mutated while this state is in use.
+/// Must be freed with a call to mlirAsmStateDestroy().
+// TODO: This should be expanded to handle location & resource map.
+MLIR_CAPI_EXPORTED MlirAsmState
+mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags);
+
+/// Creates new AsmState from value; returns null if no enclosing op is found.
+/// Must be freed with a call to mlirAsmStateDestroy().
+// TODO: This should be expanded to handle location & resource map.
+MLIR_CAPI_EXPORTED MlirAsmState
+mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags);
+
+/// Destroys an AsmState created by one of the mlirAsmStateCreate* functions.
+MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state);
+
 //===----------------------------------------------------------------------===//
 // Op Printing flags API.
 // While many of these are simple settings that could be represented in a
@@ -815,7 +839,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 
 /// Prints a value as an operand (i.e., the ValueID).
 MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
-                                                MlirOpPrintingFlags flags,
+                                                MlirAsmState state,
                                                 MlirStringCallback callback,
                                                 void *userData);
 
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index b8ccec896c27ba5..1836cb0acb67e7e 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 
+DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
 DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b06937bc285e206..af713547cccbb27 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3430,9 +3430,11 @@ void mlir::python::populateIRCore(py::module &m) {
             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
             if (useLocalScope)
               mlirOpPrintingFlagsUseLocalScope(flags);
-            mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(),
+            MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
+            mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
                                     printAccum.getUserData());
             mlirOpPrintingFlagsDestroy(flags);
+            mlirAsmStateDestroy(state);
             return printAccum.join();
           },
           py::arg("use_local_scope") = false, kGetNameAsOperand)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ef234a912490eea..7f5c2aaee67382b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -138,6 +138,51 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
   delete unwrap(registry);
 }
 
+//===----------------------------------------------------------------------===//
+// AsmState API.
+//===----------------------------------------------------------------------===//
+
+MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
+                                            MlirOpPrintingFlags flags) {
+  return wrap(new AsmState(unwrap(op), *unwrap(flags)));
+}
+
+static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
+  do {
+    // If we are printing local scope, stop at the first operation that is
+    // isolated from above.
+    if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
+
+    // Otherwise, traverse up to the next parent.
+    Operation *parentOp = op->getParentOp();
+    if (!parentOp)
+      break;
+    op = parentOp;
+  } while (true);
+  return op;
+}
+
+MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
+                                        MlirOpPrintingFlags flags) {
+  Operation *op;
+  mlir::Value val = unwrap(value);
+  if (auto result = llvm::dyn_cast<OpResult>(val)) {
+    op = result.getOwner();
+  } else {
+    op = llvm::cast<BlockArgument>(val).getOwner()->getParentOp();
+    if (!op) {
+      emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
+      return {nullptr};
+    }
+  }
+  op = findParent(op, unwrap(flags)->shouldUseLocalScope());
+  return wrap(new AsmState(op, *unwrap(flags)));
+}
+
+/// Destroys an AsmState created by one of the mlirAsmStateCreate* functions.
+void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); }
+
 //===----------------------------------------------------------------------===//
 // Printing flags API.
 //===----------------------------------------------------------------------===//
@@ -840,11 +885,11 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
   unwrap(value).print(stream);
 }
 
-void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags,
+void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
                              MlirStringCallback callback, void *userData) {
   detail::CallbackOstream stream(callback, userData);
   Value cppValue = unwrap(value);
-  cppValue.printAsOperand(stream, *unwrap(flags));
+  cppValue.printAsOperand(stream, *unwrap(state));
 }
 
 MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5725d05a3e132f7..2a40d97e7e651fe 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -487,6 +487,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
   // clang-format on
 
+  MlirAsmState state = mlirAsmStateCreateForOperation(parentOperation, flags);
+  fprintf(stderr, "With state: |");
+  mlirValuePrintAsOperand(value, state, printToStderr, NULL);
+  // CHECK: With state: |%0|
+  fprintf(stderr, "|\n");
+  mlirAsmStateDestroy(state);
+
   mlirOpPrintingFlagsDestroy(flags);
 }
 
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 46a50ac5291e8d9..8eaa75b2d3aebf4 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -151,93 +151,94 @@ def testValueReplaceAllUsesWith():
 # CHECK-LABEL: TEST: testValuePrintAsOperand
 @run
 def testValuePrintAsOperand():
-    ctx = Context()
-    ctx.allow_unregistered_dialects = True
-    with Location.unknown(ctx):
-        i32 = IntegerType.get_signless(32)
-        module = Module.create()
-        with InsertionPoint(module.body):
-            value = Operation.create("custom.op1", results=[i32]).results[0]
-            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
-            print(value)
-
-            value2 = Operation.create("custom.op2", results=[i32]).results[0]
-            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
-            print(value2)
-
-            f = func.FuncOp("test", ([i32, i32], []))
-            entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
-
-            with InsertionPoint(entry_block1):
-                value3 = Operation.create("custom.op3", results=[i32]).results[0]
-                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
-                print(value3)
-                value4 = Operation.create("custom.op4", results=[i32]).results[0]
-                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
-                print(value4)
-
-                f = func.FuncOp("test", ([i32, i32], []))
-                entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
-                with InsertionPoint(entry_block2):
-                    value5 = Operation.create("custom.op5", results=[i32]).results[0]
-                    # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
-                    print(value5)
-                    value6 = Operation.create("custom.op6", results=[i32]).results[0]
-                    # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
-                    print(value6)
-
-                    func.ReturnOp([])
-
-                func.ReturnOp([])
-
-        # CHECK: %[[VAL1]]
-        print(value.get_name())
-        # CHECK: %[[VAL2]]
-        print(value2.get_name())
-        # CHECK: %[[VAL3]]
-        print(value3.get_name())
-        # CHECK: %[[VAL4]]
-        print(value4.get_name())
-
-        # CHECK: %0
-        print(value3.get_name(use_local_scope=True))
-        # CHECK: %1
-        print(value4.get_name(use_local_scope=True))
-
-        # CHECK: %[[VAL5]]
-        print(value5.get_name())
-        # CHECK: %[[VAL6]]
-        print(value6.get_name())
-
-        # CHECK: %[[ARG0:.*]]
-        print(entry_block1.arguments[0].get_name())
-        # CHECK: %[[ARG1:.*]]
-        print(entry_block1.arguments[1].get_name())
-
-        # CHECK: %[[ARG2:.*]]
-        print(entry_block2.arguments[0].get_name())
-        # CHECK: %[[ARG3:.*]]
-        print(entry_block2.arguments[1].get_name())
-
-        # CHECK: module {
-        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
-        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
-        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
-        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
-        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
-        # CHECK:     func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
-        # CHECK:       %[[VAL5]] = "custom.op5"() : () -> i32
-        # CHECK:       %[[VAL6]] = "custom.op6"() : () -> i32
-        # CHECK:       return
-        # CHECK:     }
-        # CHECK:     return
-        # CHECK:   }
-        # CHECK: }
-        print(module)
-
-        value2.owner.detach_from_parent()
-        # CHECK: %0
-        print(value2.get_name())
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+            print(value)
+
+            value2 = Operation.create("custom.op2", results=[i32]).results[0]
+            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
+            print(value2)
+
+            f = func.FuncOp("test", ([i32, i32], []))
+            entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+
+            with InsertionPoint(entry_block1):
+                value3 = Operation.create("custom.op3", results=[i32]).results[0]
+                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
+                print(value3)
+                value4 = Operation.create("custom.op4", results=[i32]).results[0]
+                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
+                print(value4)
+
+                f = func.FuncOp("test", ([i32, i32], []))
+                entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+                with InsertionPoint(entry_block2):
+                    value5 = Operation.create("custom.op5", results=[i32]).results[0]
+                    # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
+                    print(value5)
+                    value6 = Operation.create("custom.op6", results=[i32]).results[0]
+                    # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
+                    print(value6)
+
+                    func.ReturnOp([])
+
+                func.ReturnOp([])
+
+        # CHECK: %[[VAL1]]
+        print(value.get_name())
+        # CHECK: %[[VAL2]]
+        print(value2.get_name())
+        # CHECK: %[[VAL3]]
+        print(value3.get_name())
+        # CHECK: %[[VAL4]]
+        print(value4.get_name())
+
+        # CHECK: %0
+        print(value3.get_name(use_local_scope=True))
+
+        # CHECK: %1
+        print(value4.get_name(use_local_scope=True))
+
+        # CHECK: %[[VAL5]]
+        print(value5.get_name())
+        # CHECK: %[[VAL6]]
+        print(value6.get_name())
+
+        # CHECK: %[[ARG0:.*]]
+        print(entry_block1.arguments[0].get_name())
+        # CHECK: %[[ARG1:.*]]
+        print(entry_block1.arguments[1].get_name())
+
+        # CHECK: %[[ARG2:.*]]
+        print(entry_block2.arguments[0].get_name())
+        # CHECK: %[[ARG3:.*]]
+        print(entry_block2.arguments[1].get_name())
+
+        # CHECK: module {
+        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
+        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
+        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
+        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
+        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
+        # CHECK:     func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
+        # CHECK:       %[[VAL5]] = "custom.op5"() : () -> i32
+        # CHECK:       %[[VAL6]] = "custom.op6"() : () -> i32
+        # CHECK:       return
+        # CHECK:     }
+        # CHECK:     return
+        # CHECK:   }
+        # CHECK: }
+        print(module)
+
+        value2.owner.detach_from_parent()
+        # CHECK: %0
+        print(value2.get_name())
 
 
 # CHECK-LABEL: TEST: testValueSetType



More information about the Mlir-commits mailing list