[llvm-branch-commits] [mlir] b62c7e0 - [mlir][python] Swap shape and element_type order for MemRefType.
Stella Laurenzo via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 16:08:57 PST 2021
Author: Stella Laurenzo
Date: 2021-01-19T16:03:19-08:00
New Revision: b62c7e047420026dcfe84ad66969f501698acbee
URL: https://github.com/llvm/llvm-project/commit/b62c7e047420026dcfe84ad66969f501698acbee
DIFF: https://github.com/llvm/llvm-project/commit/b62c7e047420026dcfe84ad66969f501698acbee.diff
LOG: [mlir][python] Swap shape and element_type order for MemRefType.
* Matches how all of the other shaped types are declared.
* There is no deeply principled reason for this ordering beyond making the one type that differed consistent with the rest.
* Also matches the (shape, dtype) argument ordering used by ndarray-style APIs.
Reviewed By: ftynse, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D94812
Added:
Modified:
mlir/examples/python/linalg_matmul.py
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py
index e9be189bfaaf..0bd3c12a0378 100644
--- a/mlir/examples/python/linalg_matmul.py
+++ b/mlir/examples/python/linalg_matmul.py
@@ -31,9 +31,9 @@ def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
def build_matmul_buffers_func(func_name, m, k, n, dtype):
- lhs_type = MemRefType.get(dtype, [m, k])
- rhs_type = MemRefType.get(dtype, [k, n])
- result_type = MemRefType.get(dtype, [m, n])
+ lhs_type = MemRefType.get([m, k], dtype)
+ rhs_type = MemRefType.get([k, n], dtype)
+ result_type = MemRefType.get([m, n], dtype)
# TODO: There should be a one-liner for this.
func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
_, entry = FuncOp(func_name, func_type)
@@ -49,8 +49,6 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype):
def build_matmul_tensors_func(func_name, m, k, n, dtype):
- # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
- # from each other.
lhs_type = RankedTensorType.get([m, k], dtype)
rhs_type = RankedTensorType.get([k, n], dtype)
result_type = RankedTensorType.get([m, n], dtype)
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 63bdd0c7a184..3c9f79e2a17a 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2832,7 +2832,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType &elementType, std::vector<int64_t> shape,
+ [](std::vector<int64_t> shape, PyType &elementType,
std::vector<PyAffineMap> layout, unsigned memorySpace,
DefaultingPyLocation loc) {
SmallVector<MlirAffineMap> maps;
@@ -2856,7 +2856,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
}
return PyMemRefType(elementType.getContext(), t);
},
- py::arg("element_type"), py::arg("shape"),
+ py::arg("shape"), py::arg("element_type"),
py::arg("layout") = py::list(), py::arg("memory_space") = 0,
py::arg("loc") = py::none(), "Create a memref type")
.def_property_readonly("layout", &PyMemRefType::getLayout,
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 64b684ee99e9..7402c644a1c1 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -326,7 +326,7 @@ def testMemRefType():
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
- memref = MemRefType.get(f32, shape, memory_space=2)
+ memref = MemRefType.get(shape, f32, memory_space=2)
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
# CHECK: number of affine layout maps: 0
@@ -335,7 +335,7 @@ def testMemRefType():
print("memory space:", memref.memory_space)
layout = AffineMap.get_permutation([1, 0])
- memref_layout = MemRefType.get(f32, shape, [layout])
+ memref_layout = MemRefType.get(shape, f32, [layout])
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
print("memref type:", memref_layout)
assert len(memref_layout.layout) == 1
@@ -346,7 +346,7 @@ def testMemRefType():
none = NoneType.get()
try:
- memref_invalid = MemRefType.get(none, shape)
+ memref_invalid = MemRefType.get(shape, none)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.
More information about the llvm-branch-commits
mailing list