[Mlir-commits] [mlir] ee308c9 - [mlir][py] Fix infer return type invocation for variadics
Jacques Pienaar
llvmlistbot at llvm.org
Mon Feb 6 17:04:00 PST 2023
Author: Jacques Pienaar
Date: 2023-02-06T17:01:53-08:00
New Revision: ee308c99ed0877edc286870089219179a2c64a9e
URL: https://github.com/llvm/llvm-project/commit/ee308c99ed0877edc286870089219179a2c64a9e
DIFF: https://github.com/llvm/llvm-project/commit/ee308c99ed0877edc286870089219179a2c64a9e.diff
LOG: [mlir][py] Fix infer return type invocation for variadics
Previously only a flattened operand list was accepted, but the same
input is provided here as to buildGeneric, so flatten it accordingly. We
have less information here than in buildGeneric, so the error message is
more generic if unpacking fails.
Differential Revision: https://reviews.llvm.org/D143240
Added:
Modified:
mlir/lib/Bindings/Python/IRInterfaces.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/python/dialects/tensor.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index fed8a5066fe66..b917bf0c17b63 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,6 +12,7 @@
#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Interfaces.h"
+#include "llvm/ADT/STLExtras.h"
namespace py = pybind11;
@@ -183,9 +184,9 @@ class PyInferTypeOpInterface
}
/// Given the arguments required to build an operation, attempts to infer its
- /// return types. Throws value_error on faliure.
+ /// return types. Throws value_error on failure.
std::vector<PyType>
- inferReturnTypes(std::optional<std::vector<PyValue>> operands,
+ inferReturnTypes(std::optional<py::list> operandList,
std::optional<PyAttribute> attributes,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
@@ -193,10 +194,45 @@ class PyInferTypeOpInterface
llvm::SmallVector<MlirValue> mlirOperands;
llvm::SmallVector<MlirRegion> mlirRegions;
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue &value : *operands) {
- mlirOperands.push_back(value);
+ if (operandList && !operandList->empty()) {
+ // Note: as the list may contain other lists this may not be final size.
+ mlirOperands.reserve(operandList->size());
+ for (const auto& it : llvm::enumerate(*operandList)) {
+ PyValue* val;
+ try {
+ val = py::cast<PyValue *>(it.value());
+ if (!val)
+ throw py::cast_error();
+ mlirOperands.push_back(val->get());
+ continue;
+ } catch (py::cast_error &err) {
+ }
+
+ try {
+ auto vals = py::cast<py::sequence>(it.value());
+ for (py::object v : vals) {
+ try {
+ val = py::cast<PyValue *>(v);
+ if (!val)
+ throw py::cast_error();
+ mlirOperands.push_back(val->get());
+ } catch (py::cast_error &err) {
+ throw py::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " must be a Value or Sequence of Values (" + err.what() +
+ ")")
+ .str());
+ }
+ }
+ continue;
+ } catch (py::cast_error &err) {
+ throw py::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " must be a Value or Sequence of Values (" + err.what() + ")")
+ .str());
+ }
+
+ throw py::cast_error();
}
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index d3c8fde35ef88..9d30f2797c0e8 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -24,8 +24,8 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
ValueRange values) {
// Check static and dynamic offsets/sizes/strides does not overflow type.
if (staticVals.size() != numElements)
- return op->emitError("expected ")
- << numElements << " " << name << " values";
+ return op->emitError("expected ") << numElements << " " << name
+ << " values, got " << staticVals.size();
unsigned expectedNumDynamicEntries =
llvm::count_if(staticVals, [&](int64_t staticVal) {
return ShapedType::isDynamic(staticVal);
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 505946ca1a843..63a3125ec715d 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -667,7 +667,7 @@ class IndexType(Type):
class InferTypeOpInterface:
def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
- def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+ def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
@property
def operation(self) -> Operation: ...
@property
diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index f7f73a12ed4c5..d8ea426beadf2 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -74,3 +74,30 @@ def zero_d():
return tensor.EmptyOp([], f32)
print(module)
+
+
+# CHECK-LABEL: TEST: testInferTypesInsertSlice
+ at run
+def testInferTypesInsertSlice():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32Type = F32Type.get()
+ indexType = IndexType.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 1), f32Type),
+ RankedTensorType.get((1, 1), f32Type))
+ # CHECK: func @f
+ # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+ # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
+ def f(source, dest):
+ c0 = arith.ConstantOp(indexType, 0)
+ c1 = arith.ConstantOp(indexType, 1)
+ d0 = tensor.InsertSliceOp(source, dest, [], [], [],
+ DenseI64ArrayAttr.get([0, 0]),
+ DenseI64ArrayAttr.get([1, 1]),
+ DenseI64ArrayAttr.get([0, 0]))
+ return [d0.result]
+
+ print(module)
More information about the Mlir-commits
mailing list