[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)
llvmlistbot at llvm.org
Wed Jan 24 16:27:28 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/79393.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/BuiltinTypes.h (+7)
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+9)
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+14)
- (modified) mlir/python/mlir/dialects/memref.py (+69)
- (modified) mlir/test/python/dialects/memref.py (+14)
``````````diff
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec35..2523bddc475d823 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
+/// Returns the strides and offset of the MemRef type through the `strides`
+/// and `offset` out-params, assuming the layout map is in strided form.
+/// `strides` must point to pre-allocated memory of length equal to the rank
+/// of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                          int64_t *strides,
+                                                          int64_t *offset);
+
/// Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d3053796e..86f01a6381ae4e0 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               return mlirMemRefTypeGetLayout(self);
             },
             "The layout of the MemRef type.")
+        .def_property_readonly(
+            "strides_and_offset",
+            [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+              std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+              int64_t offset;
+              mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+              return {strides, offset};
+            },
+            "The strides and offset of the MemRef type.")
         .def_property_readonly(
             "affine_map",
             [](PyMemRefType &self) -> PyAffineMap {
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac8616..6a3653d8baf304a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -16,6 +16,8 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
+#include <algorithm>
+
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+                                       int64_t *offset) {
+  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+      getStridesAndOffset(memrefType);
+  assert(static_cast<int64_t>(stridesOffsets.first.size()) ==
+             memrefType.getRank() &&
+         "Strides and rank don't match for memref");
+  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+                  strides);
+  *offset = stridesOffsets.second;
+}
+
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
return wrap(UnrankedMemRefType::getTypeID());
}
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0db..8023cbccd7a4183 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,74 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional, Sequence
from ._memref_ops_gen import *
+from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, static_offsets, static_sizes, static_strides
+):
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    target_offset = source_offset
+    for static_offset, target_stride in zip(static_offsets, source_strides):
+        target_offset += static_offset * target_stride
+
+    target_strides = []
+    for source_stride, static_stride in zip(source_strides, static_strides):
+        target_strides.append(source_stride * static_stride)
+
+    layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return MemRefType.get(
+        static_sizes,
+        source_memref_type.element_type,
+        layout,
+        source_memref_type.memory_space,
+    )
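+# For example: a source memref<10x10xi32> with identity layout has strides
+# [10, 1] and offset 0, so static_offsets [1, 1], static_sizes [3, 3], and
+# static_strides [1, 1] yield target_offset = 0 + 1*10 + 1*1 = 11 and
+# target_strides = [1*10, 1*1] = [10, 1], i.e. the inferred result type is
+# memref<3x3xi32, strided<[10, 1], offset: 11>>, as in the test below.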
+
+
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: Optional[Sequence[Value]] = None,
+    strides: Optional[Sequence[Value]] = None,
+    static_offsets: Optional[Sequence[int]] = None,
+    static_sizes: Optional[Sequence[int]] = None,
+    static_strides: Optional[Sequence[int]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if static_offsets is None:
+        static_offsets = []
+    if strides is None:
+        strides = []
+    if static_strides is None:
+        static_strides = []
+    assert static_sizes, "this convenience method only handles static sizes"
+    sizes = []
+    S = ShapedType.get_dynamic_size()
+    # When dynamic offsets/strides are passed as SSA values, every
+    # corresponding static entry must be the dynamic-size sentinel S.
+    if offsets and static_offsets:
+        assert all(s == S for s in static_offsets)
+    if strides and static_strides:
+        assert all(s == S for s in static_strides)
+    result_type = _infer_memref_subview_result_type(
+        source.type, static_offsets, static_sizes, static_strides
+    )
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe161..47c8ff86d30097c 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -88,3 +88,17 @@ def testMemRefAttr():
             memref.global_("objFifo_in0", T.memref(16, T.i32()))
         # CHECK: memref.global @objFifo_in0 : memref<16xi32>
         print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypes
+@run
+def testSubViewOpInferReturnTypes():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+            print(x.owner)
+            y = memref.subview(x, [], [], [1, 1], [3, 3], [1, 1])
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(y.owner)
``````````
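
For reference, a minimal sketch of querying the new `strides_and_offset` property (assuming a build that includes this patch):

```python
from mlir.ir import Context, IntegerType, Location, MemRefType

with Context(), Location.unknown():
    t = MemRefType.get([10, 10], IntegerType.get_signless(32))
    strides, offset = t.strides_and_offset
    # An identity-layout memref<10x10xi32> has row-major strides and no offset.
    assert strides == [10, 1] and offset == 0
```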
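And a hedged sketch of the `subview` convenience wrapper; the offsets and sizes here are illustrative (not from the PR), chosen to show a non-trivial inferred layout:

```python
from mlir.dialects import memref
from mlir.ir import (
    Context,
    InsertionPoint,
    IntegerType,
    Location,
    MemRefType,
    Module,
)

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        t = MemRefType.get([10, 10], IntegerType.get_signless(32))
        x = memref.alloc(t, [], [])
        # offset = 2*10 + 3*1 = 23 and strides = [1*10, 1*1] = [10, 1], so the
        # inferred type is memref<4x4xi32, strided<[10, 1], offset: 23>>.
        y = memref.subview(x, [], [], [2, 3], [4, 4], [1, 1])
    print(module)
```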
</details>
https://github.com/llvm/llvm-project/pull/79393