[Mlir-commits] [mlir] e0ca7e9 - [MLIR][python bindings] Fix inferReturnTypes + AttrSizedOperandSegments for optional operands
llvmlistbot@llvm.org
Fri May 26 12:52:23 PDT 2023
Author: max
Date: 2023-05-26T14:50:51-05:00
New Revision: e0ca7e99914609bbed0f30f4834a93d33dcef085
URL: https://github.com/llvm/llvm-project/commit/e0ca7e99914609bbed0f30f4834a93d33dcef085
DIFF: https://github.com/llvm/llvm-project/commit/e0ca7e99914609bbed0f30f4834a93d33dcef085.diff
LOG: [MLIR][python bindings] Fix inferReturnTypes + AttrSizedOperandSegments for optional operands
Currently, `inferTypeOpInterface.inferReturnTypes` fails because it casts operands to `py::sequence`, and that cast throws a `TypeError` when it reaches a `None`. The generated builder inserts `None` into `operands` for every omitted optional operand:
```
operands.append(_get_op_result_or_value(start) if start is not None else None)
operands.append(_get_op_result_or_value(stop) if stop is not None else None)
operands.append(_get_op_result_or_value(step) if step is not None else None)
```
Simply not appending omitted operands to `operands` does not work either, because [build generic](https://github.com/llvm/llvm-project/blob/27c37327da67020f938aabf0f6405f57d688441e/mlir/lib/Bindings/Python/IRCore.cpp#L1585) checks the list against the expected number of operand segments.
Currently, the only workaround is to build the operation by hand with `ir.Operation.create`.
Reviewed By: rkayaith
Differential Revision: https://reviews.llvm.org/D151409
Added:
Modified:
mlir/lib/Bindings/Python/IRInterfaces.cpp
mlir/test/python/dialects/python_test.py
mlir/test/python/python_test_ops.td
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 25fcaccd236d..dd4190016e19 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -53,6 +53,9 @@ llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
// Note: as the list may contain other lists this may not be final size.
mlirOperands.reserve(operandList->size());
for (const auto &&it : llvm::enumerate(*operandList)) {
+ if (it.value().is_none())
+ continue;
+
PyValue *val;
try {
val = py::cast<PyValue *>(it.value());
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 5346955c906f..37e508fb979e 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -4,6 +4,7 @@
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
+import mlir.dialects.arith as arith
def run(f):
@@ -467,3 +468,22 @@ def type_caster(pytype):
print(d.type)
# CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
print(repr(d.type))
+
+
+# CHECK-LABEL: TEST: testInferTypeOpInterface
+ at run
+def testInferTypeOpInterface():
+ with Context() as ctx, Location.unknown(ctx):
+ test.register_python_test_dialect(ctx)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i64 = IntegerType.get_signless(64)
+ zero = arith.ConstantOp(i64, 0)
+
+ one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
+ # CHECK: i32
+ print(one_operand.result.type)
+
+ two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
+ # CHECK: f32
+ print(two_operands.result.type)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 21bb95dd8278..2fc78cbddcd5 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -101,6 +101,31 @@ def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
}];
}
+// Type constraint satisfied by either the i32 or the f32 type.
+def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
+ "i32 or f32">;
+
+// Test op with two optional operands whose single result type is inferred
+// from how many operands are present: one operand -> i32, two -> f32.
+def InferResultsVariadicInputsOp : TestOp<"infer_results_variadic_inputs_op",
+    [InferTypeOpInterface, AttrSizedOperandSegments]> {
+  let arguments = (ins Optional<I64>:$single, Optional<I64>:$doubled);
+  let results = (outs I32OrF32:$res);
+
+  let extraClassDeclaration = [{
+    // Infers the result type from the number of operands actually supplied.
+    // The Python bindings drop omitted (None) optional operands before calling
+    // this, so `operands.size()` may be 0, 1 or 2.
+    static ::mlir::LogicalResult inferReturnTypes(
+        ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+        ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+        ::mlir::OpaqueProperties,
+        ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      ::mlir::Builder b(context);
+      switch (operands.size()) {
+      case 1:
+        inferredReturnTypes.push_back(b.getI32Type());
+        return ::mlir::success();
+      case 2:
+        inferredReturnTypes.push_back(b.getF32Type());
+        return ::mlir::success();
+      default:
+        // No rule covers this operand count. Report failure rather than
+        // "succeeding" while leaving the single result without a type.
+        return ::mlir::emitOptionalError(
+            location, "expected 1 or 2 operands, got ", operands.size());
+      }
+    }
+  }];
+}
+
// If all result types are buildable, the InferTypeOpInterface is implied and is
// autogenerated by C++ ODS.
def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {
More information about the Mlir-commits
mailing list