[llvm-branch-commits] [mlir] b62c7e0 - [mlir][python] Swap shape and element_type order for MemRefType.

Stella Laurenzo via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 19 16:08:57 PST 2021


Author: Stella Laurenzo
Date: 2021-01-19T16:03:19-08:00
New Revision: b62c7e047420026dcfe84ad66969f501698acbee

URL: https://github.com/llvm/llvm-project/commit/b62c7e047420026dcfe84ad66969f501698acbee
DIFF: https://github.com/llvm/llvm-project/commit/b62c7e047420026dcfe84ad66969f501698acbee.diff

LOG: [mlir][python] Swap shape and element_type order for MemRefType.

* Matches how all of the other shaped types are declared.
* There is no deeply principled reason for this ordering, beyond making the one type that differed consistent with the rest.
* Also matches ordering of things like ndarray, et al.

Reviewed By: ftynse, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94812

Added: 
    

Modified: 
    mlir/examples/python/linalg_matmul.py
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py
index e9be189bfaaf..0bd3c12a0378 100644
--- a/mlir/examples/python/linalg_matmul.py
+++ b/mlir/examples/python/linalg_matmul.py
@@ -31,9 +31,9 @@ def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
 
 
 def build_matmul_buffers_func(func_name, m, k, n, dtype):
-  lhs_type = MemRefType.get(dtype, [m, k])
-  rhs_type = MemRefType.get(dtype, [k, n])
-  result_type = MemRefType.get(dtype, [m, n])
+  lhs_type = MemRefType.get([m, k], dtype)
+  rhs_type = MemRefType.get([k, n], dtype)
+  result_type = MemRefType.get([m, n], dtype)
   # TODO: There should be a one-liner for this.
   func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
   _, entry = FuncOp(func_name, func_type)
@@ -49,8 +49,6 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype):
 
 
 def build_matmul_tensors_func(func_name, m, k, n, dtype):
-  # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
-  # from each other.
   lhs_type = RankedTensorType.get([m, k], dtype)
   rhs_type = RankedTensorType.get([k, n], dtype)
   result_type = RankedTensorType.get([m, n], dtype)

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 63bdd0c7a184..3c9f79e2a17a 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2832,7 +2832,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
          "get",
-         [](PyType &elementType, std::vector<int64_t> shape,
+         [](std::vector<int64_t> shape, PyType &elementType,
             std::vector<PyAffineMap> layout, unsigned memorySpace,
             DefaultingPyLocation loc) {
            SmallVector<MlirAffineMap> maps;
@@ -2856,7 +2856,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            }
            return PyMemRefType(elementType.getContext(), t);
          },
-         py::arg("element_type"), py::arg("shape"),
+         py::arg("shape"), py::arg("element_type"),
          py::arg("layout") = py::list(), py::arg("memory_space") = 0,
          py::arg("loc") = py::none(), "Create a memref type")
         .def_property_readonly("layout", &PyMemRefType::getLayout,

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 64b684ee99e9..7402c644a1c1 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -326,7 +326,7 @@ def testMemRefType():
     f32 = F32Type.get()
     shape = [2, 3]
     loc = Location.unknown()
-    memref = MemRefType.get(f32, shape, memory_space=2)
+    memref = MemRefType.get(shape, f32, memory_space=2)
     # CHECK: memref type: memref<2x3xf32, 2>
     print("memref type:", memref)
     # CHECK: number of affine layout maps: 0
@@ -335,7 +335,7 @@ def testMemRefType():
     print("memory space:", memref.memory_space)
 
     layout = AffineMap.get_permutation([1, 0])
-    memref_layout = MemRefType.get(f32, shape, [layout])
+    memref_layout = MemRefType.get(shape, f32, [layout])
     # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
     print("memref type:", memref_layout)
     assert len(memref_layout.layout) == 1
@@ -346,7 +346,7 @@ def testMemRefType():
 
     none = NoneType.get()
     try:
-      memref_invalid = MemRefType.get(none, shape)
+      memref_invalid = MemRefType.get(shape, none)
     except ValueError as e:
       # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
       # CHECK: or complex type.


        


More information about the llvm-branch-commits mailing list