[Mlir-commits] [mlir] 63d16d0 - [mlir] Support setting operand values in C and Python APIs.
Mike Urbach
llvmlistbot at llvm.org
Tue Apr 27 19:17:52 PDT 2021
Author: Mike Urbach
Date: 2021-04-27T20:17:47-06:00
New Revision: 63d16d06f5b8f71382033b5ea4aa668f8150817a
URL: https://github.com/llvm/llvm-project/commit/63d16d06f5b8f71382033b5ea4aa668f8150817a
DIFF: https://github.com/llvm/llvm-project/commit/63d16d06f5b8f71382033b5ea4aa668f8150817a.diff
LOG: [mlir] Support setting operand values in C and Python APIs.
This adds `mlirOperationSetOperand` to the IR C API, mirroring the existing
`mlirOperationGetOperand` function that reads an operand.
In the Python API, this adds `operands[index] = value` syntax, mirroring the
existing `operands[index]` syntax for reading an operand.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D101398
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/PybindUtils.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/ir_operation.py
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 8e92510aecdb..1b243165cbb3 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -366,6 +366,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
intptr_t pos);
+/// Sets the `pos`-th operand of the operation to `newValue`. This is the
+/// setter counterpart of `mlirOperationGetOperand`. No bounds check is done
+/// on `pos` at this layer, so callers should keep it within
+/// [0, mlirOperationGetNumOperands(op)).
+MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+ MlirValue newValue);
+
/// Returns the number of results of the operation.
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b93786e05f15..0945753f9fc9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1640,6 +1640,15 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
return PyOpOperandList(operation, startIndex, length, step);
}
+ /// Implements `operands[index] = value` from Python. Negative indices are
+ /// resolved against the list length by `wrapIndex`, which raises IndexError
+ /// when the index is out of range.
+ ///
+ /// NOTE(review): after wrapIndex, `index` is relative to this (possibly
+ /// sliced) list, but it is passed straight to mlirOperationSetOperand as an
+ /// absolute operand position. dunderGetItem first maps it through
+ /// `index * step + startIndex`; without the same mapping, assigning through
+ /// a slice such as `op.operands[1:]` updates the wrong operand. The fix
+ /// presumably needs a helper on Sliceable exposing that mapping -- TODO.
+ void dunderSetItem(intptr_t index, PyValue value) {
+ index = wrapIndex(index);
+ mlirOperationSetOperand(operation->get(), index, value.get());
+ }
+
+  /// Adds the operand-list-specific `__setitem__` method to the bound Python
+  /// class `cls`, so that `operands[i] = value` works from Python.
+  static void bindDerived(ClassTy &cls) {
+    cls.def("__setitem__", &PyOpOperandList::dunderSetItem);
+  }
+
private:
PyOperationRef operation;
};
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 0cea24482dfe..7a9b8ecb9b01 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -215,6 +215,16 @@ class Sliceable {
protected:
using ClassTy = pybind11::class_<Derived>;
+  /// Converts a Python-style index, where negative values count back from the
+  /// end, into a position in [0, length). Throws IndexError (via SetPyError)
+  /// when the resulting position falls outside that range.
+  intptr_t wrapIndex(intptr_t index) {
+    const intptr_t wrapped = index < 0 ? length + index : index;
+    if (wrapped >= 0 && wrapped < length)
+      return wrapped;
+    throw python::SetPyError(PyExc_IndexError,
+                             "attempt to access out of bounds");
+  }
+
public:
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
: startIndex(startIndex), length(length), step(step) {
@@ -228,12 +238,7 @@ class Sliceable {
/// by taking elements in inverse order. Throws if the index is out of bounds.
ElementTy dunderGetItem(intptr_t index) {
// Negative indices mean we count from the end.
- if (index < 0)
- index = length + index;
- if (index < 0 || index >= length) {
- throw python::SetPyError(PyExc_IndexError,
- "attempt to access out of bounds");
- }
+ index = wrapIndex(index);
// Compute the linear index given the current slice properties.
int linearIndex = index * step + startIndex;
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 000b8f565bb5..4e21835164ab 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -351,6 +351,11 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
}
+/// C API entry point: unwraps `op` and `newValue` and forwards to the
+/// operation's setOperand, narrowing `pos` to `unsigned` exactly as
+/// mlirOperationGetOperand does. `pos` is not range-checked here.
+void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                             MlirValue newValue) {
+  auto *operation = unwrap(op);
+  const auto operandIndex = static_cast<unsigned>(pos);
+  operation->setOperand(operandIndex, unwrap(newValue));
+}
+
intptr_t mlirOperationGetNumResults(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getNumResults());
}
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index f7036cde771e..746cd3e6ddbf 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -215,6 +215,38 @@ def testOperationOperandsSlice():
run(testOperationOperandsSlice)
+# CHECK-LABEL: TEST: testOperationOperandsSet
+def testOperationOperandsSet():
+ with Context() as ctx, Location.unknown(ctx):
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(r"""
+ func @f1() {
+ %0 = "test.producer0"() : () -> i64
+ %1 = "test.producer1"() : () -> i64
+ %2 = "test.producer2"() : () -> i64
+ "test.consumer"(%0) : (i64) -> ()
+ return
+ }""")
+ func = module.body.operations[0]
+ entry_block = func.regions[0].blocks[0]
+ producer1 = entry_block.operations[1]
+ producer2 = entry_block.operations[2]
+ consumer = entry_block.operations[3]
+ assert len(consumer.operands) == 1
+ type = consumer.operands[0].type
+
+ # CHECK: test.producer1
+ consumer.operands[0] = producer1.result
+ print(consumer.operands[0])
+
+ # CHECK: test.producer2
+ consumer.operands[-1] = producer2.result
+ print(consumer.operands[0])
+
+
+run(testOperationOperandsSet)
+
+
# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
ctx = Context()
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c5eb174ac2ca..cb9aa5de523e 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1511,6 +1511,71 @@ static int testBackreferences() {
return 0;
}
+/// Tests operand APIs.
+/// Tests the operand getter/setter C APIs. Builds two `std.constant` index
+/// values and a `dummy.op` that initially uses the first one, then replaces
+/// that operand with the second via mlirOperationSetOperand.
+static int testOperands() {
+  fprintf(stderr, "@testOperands\n");
+  // CHECK-LABEL: @testOperands
+
+  MlirContext ctx = mlirContextCreate();
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+
+  // Create some constants to use as operands.
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
+
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexOneLiteral);
+  MlirOperationState constOneState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constOneState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
+  MlirOperation constOne = mlirOperationCreate(&constOneState);
+  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
+
+  // Create the operation under test.
+  MlirOperationState opState =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
+  MlirValue initialOperands[] = {constZeroValue};
+  mlirOperationStateAddOperands(&opState, 1, initialOperands);
+  MlirOperation op = mlirOperationCreate(&opState);
+
+  // Test operand APIs.
+  intptr_t numOperands = mlirOperationGetNumOperands(op);
+  // intptr_t is not `long` on every platform (e.g. LLP64 Windows), so cast
+  // explicitly before passing it to %ld.
+  fprintf(stderr, "Num Operands: %ld\n", (long)numOperands);
+  // CHECK: Num Operands: 1
+
+  MlirValue opOperand = mlirOperationGetOperand(op, 0);
+  fprintf(stderr, "Original operand: ");
+  mlirValuePrint(opOperand, printToStderr, NULL);
+  // CHECK: Original operand: {{.+}} {value = 0 : index}
+
+  mlirOperationSetOperand(op, 0, constOneValue);
+  opOperand = mlirOperationGetOperand(op, 0);
+  fprintf(stderr, "Updated operand: ");
+  mlirValuePrint(opOperand, printToStderr, NULL);
+  // CHECK: Updated operand: {{.+}} {value = 1 : index}
+
+  // Destroy the user first so the constants have no remaining uses.
+  mlirOperationDestroy(op);
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(constOne);
+  mlirContextDestroy(ctx);
+
+  return 0;
+}
+
// Wraps a diagnostic into additional text we can match against.
MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
@@ -1588,6 +1653,8 @@ int main() {
return 9;
if (testBackreferences())
return 10;
+ if (testOperands())
+ return 11;
mlirContextDestroy(ctx);
More information about the Mlir-commits
mailing list