[Mlir-commits] [mlir] ee308c9 - [mlir][py] Fix infer return type invocation for variadics

Jacques Pienaar llvmlistbot at llvm.org
Mon Feb 6 17:04:00 PST 2023


Author: Jacques Pienaar
Date: 2023-02-06T17:01:53-08:00
New Revision: ee308c99ed0877edc286870089219179a2c64a9e

URL: https://github.com/llvm/llvm-project/commit/ee308c99ed0877edc286870089219179a2c64a9e
DIFF: https://github.com/llvm/llvm-project/commit/ee308c99ed0877edc286870089219179a2c64a9e.diff

LOG: [mlir][py] Fix infer return type invocation for variadics

Previously only a flattened list of operands was accepted, but this
receives the same input as buildGeneric, so flatten it accordingly. We
have less information here than in buildGeneric, so the error is more
generic if unpacking fails.

Differential Revision: https://reviews.llvm.org/D143240

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRInterfaces.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/dialects/tensor.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index fed8a5066fe66..b917bf0c17b63 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,6 +12,7 @@
 #include "IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Interfaces.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace py = pybind11;
 
@@ -183,9 +184,9 @@ class PyInferTypeOpInterface
   }
 
   /// Given the arguments required to build an operation, attempts to infer its
-  /// return types. Throws value_error on faliure.
+  /// return types. Throws value_error on failure.
   std::vector<PyType>
-  inferReturnTypes(std::optional<std::vector<PyValue>> operands,
+  inferReturnTypes(std::optional<py::list> operandList,
                    std::optional<PyAttribute> attributes,
                    std::optional<std::vector<PyRegion>> regions,
                    DefaultingPyMlirContext context,
@@ -193,10 +194,45 @@ class PyInferTypeOpInterface
     llvm::SmallVector<MlirValue> mlirOperands;
     llvm::SmallVector<MlirRegion> mlirRegions;
 
-    if (operands) {
-      mlirOperands.reserve(operands->size());
-      for (PyValue &value : *operands) {
-        mlirOperands.push_back(value);
+    if (operandList && !operandList->empty()) {
+      // Note: as the list may contain other lists this may not be final size.
+      mlirOperands.reserve(operandList->size());
+      for (const auto& it : llvm::enumerate(*operandList)) {
+        PyValue* val;
+        try {
+          val = py::cast<PyValue *>(it.value());
+          if (!val)
+            throw py::cast_error();
+          mlirOperands.push_back(val->get());
+          continue;
+        } catch (py::cast_error &err) {
+        }
+
+        try {
+          auto vals = py::cast<py::sequence>(it.value());
+          for (py::object v : vals) {
+            try {
+              val = py::cast<PyValue *>(v);
+              if (!val)
+                throw py::cast_error();
+              mlirOperands.push_back(val->get());
+            } catch (py::cast_error &err) {
+              throw py::value_error(
+                  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+                   " must be a Value or Sequence of Values (" + err.what() +
+                   ")")
+                      .str());
+            }
+          }
+          continue;
+        } catch (py::cast_error &err) {
+          throw py::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " must be a Value or Sequence of Values (" + err.what() + ")")
+                  .str());
+        }
+
+        throw py::cast_error();
       }
     }
 

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index d3c8fde35ef88..9d30f2797c0e8 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -24,8 +24,8 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
                                                    ValueRange values) {
   // Check static and dynamic offsets/sizes/strides does not overflow type.
   if (staticVals.size() != numElements)
-    return op->emitError("expected ")
-           << numElements << " " << name << " values";
+    return op->emitError("expected ") << numElements << " " << name
+                                      << " values, got " << staticVals.size();
   unsigned expectedNumDynamicEntries =
       llvm::count_if(staticVals, [&](int64_t staticVal) {
         return ShapedType::isDynamic(staticVal);

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 505946ca1a843..63a3125ec715d 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -667,7 +667,7 @@ class IndexType(Type):
 
 class InferTypeOpInterface:
     def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
-    def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+    def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
     @property
     def operation(self) -> Operation: ...
     @property

diff  --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index f7f73a12ed4c5..d8ea426beadf2 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -74,3 +74,30 @@ def zero_d():
         return tensor.EmptyOp([], f32)
 
   print(module)
+
+
+# CHECK-LABEL: TEST: testInferTypesInsertSlice
+ at run
+def testInferTypesInsertSlice():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32Type = F32Type.get()
+    indexType = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @func.FuncOp.from_py_func(
+          RankedTensorType.get((1, 1), f32Type),
+          RankedTensorType.get((1, 1), f32Type))
+      # CHECK: func @f
+      # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+      # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
+      def f(source, dest):
+        c0 = arith.ConstantOp(indexType, 0)
+        c1 = arith.ConstantOp(indexType, 1)
+        d0 = tensor.InsertSliceOp(source, dest, [], [], [],
+                                  DenseI64ArrayAttr.get([0, 0]),
+                                  DenseI64ArrayAttr.get([1, 1]),
+                                  DenseI64ArrayAttr.get([0, 0]))
+        return [d0.result]
+
+  print(module)


        


More information about the Mlir-commits mailing list