[Mlir-commits] [mlir] 63d16d0 - [mlir] Support setting operand values in C and Python APIs.

Mike Urbach llvmlistbot at llvm.org
Tue Apr 27 19:17:52 PDT 2021


Author: Mike Urbach
Date: 2021-04-27T20:17:47-06:00
New Revision: 63d16d06f5b8f71382033b5ea4aa668f8150817a

URL: https://github.com/llvm/llvm-project/commit/63d16d06f5b8f71382033b5ea4aa668f8150817a
DIFF: https://github.com/llvm/llvm-project/commit/63d16d06f5b8f71382033b5ea4aa668f8150817a.diff

LOG: [mlir] Support setting operand values in C and Python APIs.

This adds `mlirOperationSetOperand` to the IR C API, mirroring the existing
`mlirOperationGetOperand` accessor.

In the Python API, this adds `operands[index] = value` assignment syntax,
mirroring the existing `operands[index]` read syntax.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D101398

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_operation.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 8e92510aecdb..1b243165cbb3 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -366,6 +366,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
 MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
                                                      intptr_t pos);
 
+/// Sets the `pos`-th operand of the operation to `newValue`; `pos` must be valid.
+MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                                                MlirValue newValue);
+
 /// Returns the number of results of the operation.
 MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op);
 

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b93786e05f15..0945753f9fc9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1640,6 +1640,15 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
     return PyOpOperandList(operation, startIndex, length, step);
   }
 
+  void dunderSetItem(intptr_t index, PyValue value) { // operands[index] = value
+    index = wrapIndex(index); // Wraps negatives; IndexError if out of bounds.
+    mlirOperationSetOperand(operation->get(), index, value.get()); // NOTE(review): not linearized with step/startIndex as dunderGetItem is, so a sliced list may set the wrong operand -- confirm and fix.
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem); // operands[i] = v
+  }
+
 private:
   PyOperationRef operation;
 };

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 0cea24482dfe..7a9b8ecb9b01 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -215,6 +215,16 @@ class Sliceable {
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
+  intptr_t wrapIndex(intptr_t index) { // Python-style index -> [0, length).
+    if (index < 0)
+      index = length + index; // Negative indices count from the end.
+    if (index < 0 || index >= length) {
+      throw python::SetPyError(PyExc_IndexError,
+                               "attempt to access out of bounds");
+    }
+    return index; // Still slice-relative; step/startIndex not applied here.
+  }
+
 public:
   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
       : startIndex(startIndex), length(length), step(step) {
@@ -228,12 +238,7 @@ class Sliceable {
   /// by taking elements in inverse order. Throws if the index is out of bounds.
   ElementTy dunderGetItem(intptr_t index) {
     // Negative indices mean we count from the end.
-    if (index < 0)
-      index = length + index;
-    if (index < 0 || index >= length) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
-    }
+    index = wrapIndex(index);
 
     // Compute the linear index given the current slice properties.
     int linearIndex = index * step + startIndex;

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 000b8f565bb5..4e21835164ab 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -351,6 +351,11 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
 }
 
+void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                             MlirValue newValue) { // Mirrors GetOperand.
+  unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue)); // No range check on pos at this layer.
+}
+
 intptr_t mlirOperationGetNumResults(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumResults());
 }

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index f7036cde771e..746cd3e6ddbf 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -215,6 +215,38 @@ def testOperationOperandsSlice():
 run(testOperationOperandsSlice)
 
 
+# CHECK-LABEL: TEST: testOperationOperandsSet
+def testOperationOperandsSet():  # Exercises PyOpOperandList.__setitem__.
+  with Context() as ctx, Location.unknown(ctx):
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1() {
+        %0 = "test.producer0"() : () -> i64
+        %1 = "test.producer1"() : () -> i64
+        %2 = "test.producer2"() : () -> i64
+        "test.consumer"(%0) : (i64) -> ()
+        return
+      }""")
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    producer1 = entry_block.operations[1]
+    producer2 = entry_block.operations[2]
+    consumer = entry_block.operations[3]
+    assert len(consumer.operands) == 1
+    type = consumer.operands[0].type  # NOTE(review): unused; shadows builtin.
+
+    # CHECK: test.producer1
+    consumer.operands[0] = producer1.result
+    print(consumer.operands[0])
+
+    # CHECK: test.producer2
+    consumer.operands[-1] = producer2.result  # -1 wraps to the only operand.
+    print(consumer.operands[0])
+
+
+run(testOperationOperandsSet)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = Context()

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c5eb174ac2ca..cb9aa5de523e 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1511,6 +1511,71 @@ static int testBackreferences() {
   return 0;
 }
 
+/// Tests mlirOperationGetOperand/mlirOperationSetOperand; always returns 0.
+int testOperands() {
+  fprintf(stderr, "@testOperands\n");
+  // CHECK-LABEL: @testOperands
+
+  MlirContext ctx = mlirContextCreate();
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+
+  // Create some constants to use as operands.
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
+
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexOneLiteral);
+  MlirOperationState constOneState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constOneState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
+  MlirOperation constOne = mlirOperationCreate(&constOneState);
+  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
+
+  // Create the operation under test.
+  MlirOperationState opState =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
+  MlirValue initialOperands[] = {constZeroValue};
+  mlirOperationStateAddOperands(&opState, 1, initialOperands);
+  MlirOperation op = mlirOperationCreate(&opState);
+
+  // Test operand APIs (intptr_t is cast to long so %ld is portable).
+  intptr_t numOperands = mlirOperationGetNumOperands(op);
+  fprintf(stderr, "Num Operands: %ld\n", (long)numOperands);
+  // CHECK: Num Operands: 1
+
+  MlirValue opOperand = mlirOperationGetOperand(op, 0);
+  fprintf(stderr, "Original operand: ");
+  mlirValuePrint(opOperand, printToStderr, NULL);
+  // CHECK: Original operand: {{.+}} {value = 0 : index}
+
+  mlirOperationSetOperand(op, 0, constOneValue);
+  opOperand = mlirOperationGetOperand(op, 0);
+  fprintf(stderr, "Updated operand: ");
+  mlirValuePrint(opOperand, printToStderr, NULL);
+  // CHECK: Updated operand: {{.+}} {value = 1 : index}
+
+  mlirOperationDestroy(op);
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(constOne);
+  mlirContextDestroy(ctx);
+
+  return 0;
+}
+
 // Wraps a diagnostic into additional text we can match against.
 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
   fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
@@ -1588,6 +1653,8 @@ int main() {
     return 9;
   if (testBackreferences())
     return 10;
+  if (testOperands())
+    return 11;
 
   mlirContextDestroy(ctx);
 


        


More information about the Mlir-commits mailing list