[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)
Maksim Levental
llvmlistbot at llvm.org
Thu Jan 25 16:55:25 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/79393
From eb0f298facf34cc6239e168890e6c5aff0d9e0be Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 24 Jan 2024 18:25:54 -0600
Subject: [PATCH] [mlir][python] enable memref.subview
---
mlir/include/mlir-c/BuiltinTypes.h | 7 ++
mlir/lib/Bindings/Python/IRTypes.cpp | 9 ++
mlir/lib/CAPI/IR/BuiltinTypes.cpp | 14 +++
mlir/python/mlir/dialects/memref.py | 121 +++++++++++++++++++
mlir/test/python/dialects/memref.py | 167 ++++++++++++++++++++++++++-
5 files changed, 316 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec3..2523bddc475d82 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
+/// Returns the strides and offset of the MemRef if its layout map is in
+/// strided form. Both strides and offset are out-params; `strides` must point
+/// to pre-allocated memory whose length equals the rank of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+ int64_t *strides,
+ int64_t *offset);
+
/// Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d3053796..86f01a6381ae4e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
return mlirMemRefTypeGetLayout(self);
},
"The layout of the MemRef type.")
+ .def_property_readonly(
+ "strides_and_offset",
+ [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+ std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+ int64_t offset;
+ mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+ return {strides, offset};
+ },
+ "The strides and offset of the MemRef type.")
.def_property_readonly(
"affine_map",
[](PyMemRefType &self) -> PyAffineMap {
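For reference, a minimal sketch of how the new strides_and_offset property can be exercised from Python; the parsed strided memref type is just an illustrative input:

    from mlir.ir import Context, MemRefType, Type

    with Context():
        # Parse a memref type whose layout is already in strided form.
        t = MemRefType(Type.parse("memref<4x5xf32, strided<[5, 1], offset: 2>>"))
        strides, offset = t.strides_and_offset
        print(strides, offset)  # expected: [5, 1] 2
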
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac861..6a3653d8baf304 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -16,6 +16,8 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
+#include <algorithm>
+
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+ int64_t *offset) {
+ MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+ std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+ getStridesAndOffset(memrefType);
+ assert(stridesOffsets.first.size() == memrefType.getRank() &&
+ "Strides and rank don't match for memref");
+ (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+ strides);
+ *offset = stridesOffsets.second;
+}
+
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
return wrap(UnrankedMemRefType::getTypeID());
}
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0d..6ab6e0602e7a95 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,126 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
from ._memref_ops_gen import *
+from .arith import ConstantOp
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _is_constant(i):
+ return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
+
+
+def _is_static(i):
+ return (isinstance(i, int) and not ShapedType.is_dynamic_size(i)) or _is_constant(i)
+
+
+def _infer_memref_subview_result_type(
+ source_memref_type, offsets, static_sizes, static_strides
+):
+ source_strides, source_offset = source_memref_type.strides_and_offset
+ # "canonicalize" from tuple|list -> list
+ offsets, static_sizes, static_strides, source_strides = map(
+ list, (offsets, static_sizes, static_strides, source_strides)
+ )
+
+ assert all(
+ all(_is_static(i) for i in s)
+ for s in [
+ static_sizes,
+ static_strides,
+ source_strides,
+ ]
+ ), f"Only inferring from python or mlir integer constant is supported"
+
+ for s in [offsets, static_sizes, static_strides]:
+ for idx, i in enumerate(s):
+ if _is_constant(i):
+ s[idx] = i.owner.opview.literal_value
+
+ if any(not _is_static(i) for i in offsets + [source_offset]):
+ target_offset = ShapedType.get_dynamic_size()
+ else:
+ target_offset = source_offset
+ for offset, target_stride in zip(offsets, source_strides):
+ target_offset += offset * target_stride
+
+ target_strides = []
+ for source_stride, static_stride in zip(source_strides, static_strides):
+ target_strides.append(source_stride * static_stride)
+
+ # If this is the default striding, omit the layout so downstream ops (e.g., expand_shape) are not needlessly complicated.
+ default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
+ if target_strides == default_strides and target_offset == 0:
+ layout = None
+ else:
+ layout = StridedLayoutAttr.get(target_offset, target_strides)
+ return (
+ offsets,
+ static_sizes,
+ static_strides,
+ MemRefType.get(
+ static_sizes,
+ source_memref_type.element_type,
+ layout,
+ source_memref_type.memory_space,
+ ),
+ )
+
+
+_generated_subview = subview
+
+
+def subview(
+ source: Value,
+ offsets: MixedValues,
+ sizes: MixedValues,
+ strides: MixedValues,
+ *,
+ result_type: Optional[MemRefType] = None,
+ loc=None,
+ ip=None,
+):
+ if offsets is None:
+ offsets = []
+ if sizes is None:
+ sizes = []
+ if strides is None:
+ strides = []
+ source_strides, source_offset = source.type.strides_and_offset
+ if result_type is None and all(
+ all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
+ ):
+ # If any are arith.constant results then this will canonicalize to python int
+ # (which can then be used to fully specify the subview).
+ (
+ offsets,
+ sizes,
+ strides,
+ result_type,
+ ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+ else:
+ assert (
+ result_type is not None
+ ), "mixed static/dynamic offset/sizes/strides requires explicit result type"
+
+ offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
+ sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
+ strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
+
+ return _generated_subview(
+ result_type,
+ source,
+ offsets,
+ sizes,
+ strides,
+ static_offsets,
+ static_sizes,
+ static_strides,
+ loc=loc,
+ ip=ip,
+ )
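
As a usage note: with purely static offsets, sizes, and strides the wrapper infers the result type by folding the offsets into the source layout (target offset = source offset + sum of offset_i * source_stride_i; target strides = source_stride_i * stride_i). A minimal sketch, mirroring the tests below:

    import mlir.dialects.memref as memref
    import mlir.extras.types as T
    from mlir.ir import Context, InsertionPoint, Location, Module

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            src = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # offsets [1, 1] over source strides [10, 1]: inferred offset = 1*10 + 1*1 = 11,
            # inferred strides = [10*1, 1*1], so the result type should be
            # memref<3x3xi32, strided<[10, 1], offset: 11>>.
            view = memref.subview(src, [1, 1], [3, 3], [1, 1])
        print(module)
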
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe16..205acbd6d42c4d 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -1,9 +1,11 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir.ir import *
-import mlir.dialects.func as func
+import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
+import numpy as np
+from mlir.dialects.memref import _infer_memref_subview_result_type
+from mlir.ir import *
def run(f):
@@ -88,3 +90,164 @@ def testMemRefAttr():
memref.global_("objFifo_in0", T.memref(16, T.i32()))
# CHECK: memref.global @objFifo_in0 : memref<16xi32>
print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnType
+@run
+def testSubViewOpInferReturnType():
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+ # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+ print(x.owner)
+
+ y = memref.subview(x, [1, 1], [3, 3], [1, 1])
+ assert y.owner.verify()
+ # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+ print(y.owner)
+
+ z = memref.subview(
+ x,
+ [arith.constant(T.index(), 1), 1],
+ [3, 3],
+ [1, 1],
+ )
+ # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+ print(z.owner)
+
+ z = memref.subview(
+ x,
+ [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+ [3, 3],
+ [1, 1],
+ )
+ # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
+ print(z.owner)
+
+ s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
+ z = memref.subview(
+ x,
+ [s, 0],
+ [3, 3],
+ [1, 1],
+ )
+ # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+ print(z)
+
+ try:
+ _infer_memref_subview_result_type(
+ x.type,
+ [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+ [ShapedType.get_dynamic_size(), 3],
+ [1, 1],
+ )
+ except AssertionError as e:
+ # CHECK: Only inferring from python or mlir integer constant is supported
+ print(e)
+
+ try:
+ memref.subview(
+ x,
+ [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+ [ShapedType.get_dynamic_size(), 3],
+ [1, 1],
+ )
+ except AssertionError as e:
+ # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
+ print(e)
+
+ layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
+ x = memref.alloc(
+ T.memref(
+ 10,
+ 10,
+ T.i32(),
+ layout=layout,
+ ),
+ [],
+ [arith.constant(T.index(), 42)],
+ )
+ # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
+ print(x.owner)
+ y = memref.subview(
+ x,
+ [1, 1],
+ [3, 3],
+ [1, 1],
+ result_type=T.memref(3, 3, T.i32(), layout=layout),
+ )
+ # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+ print(y.owner)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
+@run
+def testSubViewOpInferReturnTypeExtensiveSlicing():
+ def check_strides_offset(memref, np_view):
+ layout = memref.type.layout
+ dtype_size_in_bytes = np_view.dtype.itemsize
+ golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
+ golden_offset = (
+ np_view.ctypes.data - np_view.base.ctypes.data
+ ) // dtype_size_in_bytes
+
+ assert (layout.strides, layout.offset) == (golden_strides, golden_offset)
+
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ shape = (10, 22, 333, 4444)
+ golden_mem = np.zeros(shape, dtype=np.int32)
+ mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+
+ # fmt: off
+ check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, ...])
+ check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2])
+ check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
+ check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 333, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
+ check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 333, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
+ check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
+ check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
+ check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
+ check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
+ check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
+ check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
+ check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
+ check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
+ check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
+ # fmt: on
+
+ # Default strides and offset mean no StridedLayoutAttr, i.e., the layout is the identity AffineMapAttr.
+ assert memref.subview(
+ mem1, (0, 0, 0, 0), (10, 22, 333, 4444), (1, 1, 1, 1)
+ ).type.layout == AffineMapAttr.get(
+ AffineMap.get(
+ 4,
+ 0,
+ [
+ AffineDimExpr.get(0),
+ AffineDimExpr.get(1),
+ AffineDimExpr.get(2),
+ AffineDimExpr.get(3),
+ ],
+ )
+ )
+
+ shape = (7, 22, 333, 4444)
+ golden_mem = np.zeros(shape, dtype=np.int32)
+ mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+ # fmt: off
+ check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 333, 4444), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
+ check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 4444), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
+ check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
+ check_strides_offset(memref.subview(mem2, (0, 0, 100, 1000), (7, 22, 20, 20), (1, 1, 5, 50)), golden_mem[:, :, 100:200:5, 1000:2000:50])
+ # fmt: on
+
+ shape = (8, 8)
+ golden_mem = np.zeros(shape, dtype=np.int32)
+ # fmt: off
+ mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+ check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
+ check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
+ # fmt: on
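
For intuition, the golden strides and offset that check_strides_offset compares against come directly from NumPy view metadata; a standalone NumPy sketch of that equivalence:

    import numpy as np

    a = np.zeros((10, 10), dtype=np.int32)
    view = a[1:4, 1:4]  # NumPy analogue of offsets [1, 1], sizes [3, 3], strides [1, 1]
    itemsize = view.dtype.itemsize
    strides = [s // itemsize for s in view.strides]          # -> [10, 1]
    offset = (view.ctypes.data - a.ctypes.data) // itemsize  # -> 11
    print(strides, offset)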