[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)

Maksim Levental llvmlistbot at llvm.org
Thu Jan 25 00:25:38 PST 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/79393

From b2a3b31d5867d6037ddd4a4f9f0d93d3e4268ca0 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 24 Jan 2024 18:25:54 -0600
Subject: [PATCH] [mlir][python] enable memref.subview

---
 mlir/include/mlir-c/BuiltinTypes.h   |   7 ++
 mlir/lib/Bindings/Python/IRTypes.cpp |   9 +++
 mlir/lib/CAPI/IR/BuiltinTypes.cpp    |  14 ++++
 mlir/python/mlir/dialects/memref.py  | 100 +++++++++++++++++++++++++++
 mlir/test/python/dialects/memref.py  |  96 +++++++++++++++++++++++++
 5 files changed, 226 insertions(+)

diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec3..2523bddc475d82 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 
+/// Returns the strides and offset of the given MemRef type if its layout map
+/// is in strided form. Both strides and offset are out params; strides must
+/// point to pre-allocated memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                          int64_t *strides,
+                                                          int64_t *offset);
+
 /// Returns the memory space of the given Unranked MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d3053796..86f01a6381ae4e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               return mlirMemRefTypeGetLayout(self);
             },
             "The layout of the MemRef type.")
+        .def_property_readonly(
+            "strides_and_offset",
+            [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+              std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+              int64_t offset;
+              mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+              return {strides, offset};
+            },
+            "The strides and offset of the MemRef type.")
         .def_property_readonly(
             "affine_map",
             [](PyMemRefType &self) -> PyAffineMap {
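
With the property above in place, strides and offset become queryable directly
from Python. A minimal sketch (hypothetical usage of the new property, assuming
this patch is applied and an active context):

    from mlir.ir import Context, MemRefType, F32Type

    with Context():
        t = MemRefType.get([4, 8], F32Type.get())  # identity (row-major) layout
        strides, offset = t.strides_and_offset
        print(strides, offset)  # expect [8, 1] and 0 for this layout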
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac861..6a3653d8baf304 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -16,6 +16,8 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 
+#include <algorithm>
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+                                       int64_t *offset) {
+  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+      getStridesAndOffset(memrefType);
+  assert(stridesOffsets.first.size() ==
+             static_cast<size_t>(memrefType.getRank()) &&
+         "Strides and rank don't match for memref");
+  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+                  strides);
+  *offset = stridesOffsets.second;
+}
+
 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
   return wrap(UnrankedMemRefType::getTypeID());
 }
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0d..2c7afd60e897aa 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,105 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
 
 from ._memref_ops_gen import *
+from .arith import ConstantOp, _is_integer_like_type
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, static_offsets, static_sizes, static_strides
+):
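+    """Infers the result type of a subview from integer-constant operands.
+
+    All offsets, sizes, and strides, as well as the strides and offset of the
+    source layout, must be Python ints or `arith.constant` integer values; the
+    strided layout of the result is computed from them.
+    """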
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    assert all(
+        all(
+            (isinstance(i, int) and not ShapedType.is_dynamic_size(i))
+            or (
+                isinstance(i, Value)
+                and isinstance(i.owner.opview, ConstantOp)
+                and _is_integer_like_type(i.type)
+            )
+            for i in s
+        )
+        for s in [
+            static_offsets,
+            static_sizes,
+            static_strides,
+            source_strides,
+            [source_offset],
+        ]
+    ), "Only inferring from Python or MLIR integer constants is supported"
+    for s in [static_offsets, static_sizes, static_strides]:
+        for idx, i in enumerate(s):
+            if isinstance(i, Value):
+                s[idx] = i.owner.opview.literal_value
+
+    target_offset = source_offset
+    for static_offset, source_stride in zip(static_offsets, source_strides):
+        target_offset += static_offset * source_stride
+
+    target_strides = []
+    for source_stride, static_stride in zip(source_strides, static_strides):
+        target_strides.append(source_stride * static_stride)
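+    # E.g., for a 10x10 identity-layout source (strides [10, 1], offset 0),
+    # offsets [1, 1] and strides [1, 1] give target_offset = 0 + 1*10 + 1*1 = 11
+    # and target_strides = [10, 1], i.e. strided<[10, 1], offset: 11>.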
+
+    layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return MemRefType.get(
+        static_sizes,
+        source_memref_type.element_type,
+        layout,
+        source_memref_type.memory_space,
+    )
+
+
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: MixedValues,
+    sizes: MixedValues,
+    strides: MixedValues,
+    *,
+    result_type: Optional[MemRefType] = None,
+    loc=None,
+    ip=None,
+):
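+    """Creates a subview of `source` at the given offsets, sizes, and strides.
+
+    If all offsets, sizes, and strides, as well as the strides and offset of
+    the source layout, are static integers or `arith.constant` results, the
+    result type is inferred; otherwise an explicit `result_type` is required.
+    """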
+    if offsets is None:
+        offsets = []
+    if sizes is None:
+        sizes = []
+    if strides is None:
+        strides = []
+
+    source_strides, source_offset = source.type.strides_and_offset
+    if all(
+        all(
+            (isinstance(i, int) and not ShapedType.is_dynamic_size(i))
+            or (isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp))
+            for i in s
+        )
+        for s in [offsets, sizes, strides, source_strides, [source_offset]]
+    ):
+        result_type = _infer_memref_subview_result_type(
+            source.type, offsets, sizes, strides
+        )
+    else:
+        assert (
+            result_type is not None
+        ), "mixed static/dynamic offset/sizes/strides requires explicit result type"
+
+    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
+    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
+    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
+
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
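
Taken together, the wrapper supports two calling modes. A hypothetical sketch
(`src`, `dyn_off`, and `rt` are illustrative names, not from this patch):

    # Fully static operands: the result type is inferred.
    v = memref.subview(src, [1, 1], [3, 3], [1, 1])

    # Mixed static/dynamic operands: an explicit result type is required.
    v = memref.subview(src, [dyn_off, 0], [3, 3], [1, 1], result_type=rt)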
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe16..1ab39fdb6dc35d 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -3,6 +3,8 @@
 from mlir.ir import *
 import mlir.dialects.func as func
 import mlir.dialects.memref as memref
+from mlir.dialects.memref import _infer_memref_subview_result_type
+import mlir.dialects.arith as arith
 import mlir.extras.types as T
 
 
@@ -88,3 +90,97 @@ def testMemRefAttr():
             memref.global_("objFifo_in0", T.memref(16, T.i32()))
         # CHECK: memref.global @objFifo_in0 : memref<16xi32>
         print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnType
+@run
+def testSubViewOpInferReturnType():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+            print(x.owner)
+
+            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(y.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 1), 1],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(z.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
+            print(z.owner)
+
+            try:
+                memref.subview(
+                    x,
+                    [
+                        arith.addi(
+                            arith.constant(T.index(), 3), arith.constant(T.index(), 4)
+                        ),
+                        0,
+                    ],
+                    [3, 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
+                print(e)
+
+            try:
+                _infer_memref_subview_result_type(
+                    x.type,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: Only inferring from Python or MLIR integer constants is supported
+                print(e)
+
+            try:
+                memref.subview(
+                    x,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
+                print(e)
+
+            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
+            x = memref.alloc(
+                T.memref(
+                    10,
+                    10,
+                    T.i32(),
+                    layout=layout,
+                ),
+                [],
+                [arith.constant(T.index(), 42)],
+            )
+            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
+            print(x.owner)
+            y = memref.subview(
+                x,
+                [1, 1],
+                [3, 3],
+                [1, 1],
+                result_type=T.memref(3, 3, T.i32(), layout=layout),
+            )
+            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(y.owner)


