[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)

Aiden Grossman llvmlistbot at llvm.org
Mon Jan 29 11:40:19 PST 2024


https://github.com/boomanaiden154 updated https://github.com/llvm/llvm-project/pull/79393

>From a192a09976827dc052a42caac3d144bd5542f1e4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 24 Jan 2024 18:25:54 -0600
Subject: [PATCH 1/2] [mlir][python] enable memref.subview

---
 mlir/include/mlir-c/BuiltinTypes.h   |   7 ++
 mlir/lib/Bindings/Python/IRTypes.cpp |   9 ++
 mlir/lib/CAPI/IR/BuiltinTypes.cpp    |  14 +++
 mlir/python/mlir/dialects/memref.py  | 121 +++++++++++++++++++
 mlir/test/python/dialects/memref.py  | 166 ++++++++++++++++++++++++++-
 5 files changed, 315 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec35..2523bddc475d823 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 
+/// Returns the strides of the MemRef if the layout map is in strided form.
+/// Both strides and offset are out params. strides must point to pre-allocated
+/// memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                          int64_t *strides,
+                                                          int64_t *offset);
+
 /// Returns the memory space of the given Unranked MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d3053796e..86f01a6381ae4e0 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               return mlirMemRefTypeGetLayout(self);
             },
             "The layout of the MemRef type.")
+        .def_property_readonly(
+            "strides_and_offset",
+            [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+              std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+              int64_t offset;
+              mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+              return {strides, offset};
+            },
+            "The strides and offset of the MemRef type.")
         .def_property_readonly(
             "affine_map",
             [](PyMemRefType &self) -> PyAffineMap {
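
For reference, a minimal usage sketch of the new binding (not part of the
patch; assumes an active Context and the existing Type.parse API):

    from mlir.ir import Context, Type, MemRefType

    with Context():
        t = MemRefType(Type.parse("memref<10x10xi32>"))
        # The identity layout is in strided form.
        strides, offset = t.strides_and_offset
        assert (strides, offset) == ([10, 1], 0)
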
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac8616..6a3653d8baf304a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -16,6 +16,8 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 
+#include <algorithm>
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+                                       int64_t *offset) {
+  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+      getStridesAndOffset(memrefType);
+  assert(stridesOffsets.first.size() == memrefType.getRank() &&
+         "Strides and rank don't match for memref");
+  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+                  strides);
+  *offset = stridesOffsets.second;
+}
+
 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
   return wrap(UnrankedMemRefType::getTypeID());
 }
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0db..6ab6e0602e7a95d 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,126 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
 
 from ._memref_ops_gen import *
+from .arith import ConstantOp
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _is_constant(i):
+    return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
+
+
+def _is_static(i):
+    return (isinstance(i, int) and not ShapedType.is_dynamic_size(i)) or _is_constant(i)
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, offsets, static_sizes, static_strides
+):
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    # "canonicalize" from tuple|list -> list
+    offsets, static_sizes, static_strides, source_strides = map(
+        list, (offsets, static_sizes, static_strides, source_strides)
+    )
+
+    assert all(
+        all(_is_static(i) for i in s)
+        for s in [
+            static_sizes,
+            static_strides,
+            source_strides,
+        ]
+    ), f"Only inferring from python or mlir integer constant is supported"
+
+    for s in [offsets, static_sizes, static_strides]:
+        for idx, i in enumerate(s):
+            if _is_constant(i):
+                s[idx] = i.owner.opview.literal_value
+
+    if any(not _is_static(i) for i in offsets + [source_offset]):
+        target_offset = ShapedType.get_dynamic_size()
+    else:
+        target_offset = source_offset
+        for offset, target_stride in zip(offsets, source_strides):
+            target_offset += offset * target_stride
+
+    target_strides = []
+    for source_stride, static_stride in zip(source_strides, static_strides):
+        target_strides.append(source_stride * static_stride)
+
+    # With default (row-major) striding and zero offset, omit the layout so downstream ops (e.g., expand_shape) stay simple.
+    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
+    if target_strides == default_strides and target_offset == 0:
+        layout = None
+    else:
+        layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return (
+        offsets,
+        static_sizes,
+        static_strides,
+        MemRefType.get(
+            static_sizes,
+            source_memref_type.element_type,
+            layout,
+            source_memref_type.memory_space,
+        ),
+    )
+
+
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: MixedValues,
+    sizes: MixedValues,
+    strides: MixedValues,
+    *,
+    result_type: Optional[MemRefType] = None,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if sizes is None:
+        sizes = []
+    if strides is None:
+        strides = []
+    source_strides, source_offset = source.type.strides_and_offset
+    if result_type is None and all(
+        all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
+    ):
+        # If any are arith.constant results then this will canonicalize to python int
+        # (which can then be used to fully specific the subview).
+        (
+            offsets,
+            sizes,
+            strides,
+            result_type,
+        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+    else:
+        assert (
+            result_type is not None
+        ), "mixed static/dynamic offset/sizes/strides requires explicit result type"
+
+    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
+    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
+    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
+
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
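
A short usage sketch of the wrapper (mirroring the tests below; assumes an
active Context, Location, and InsertionPoint): when offsets, sizes, and
strides are all static, the result type is inferred automatically.

    import mlir.extras.types as T
    from mlir.dialects import memref

    x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
    y = memref.subview(x, offsets=[1, 1], sizes=[3, 3], strides=[1, 1])
    # Inferred result type: memref<3x3xi32, strided<[10, 1], offset: 11>>
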
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe161..0cf2fe15384fbd1 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -1,9 +1,11 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.ir import *
-import mlir.dialects.func as func
+import mlir.dialects.arith as arith
 import mlir.dialects.memref as memref
 import mlir.extras.types as T
+import numpy as np
+from mlir.dialects.memref import _infer_memref_subview_result_type
+from mlir.ir import *
 
 
 def run(f):
@@ -88,3 +89,164 @@ def testMemRefAttr():
             memref.global_("objFifo_in0", T.memref(16, T.i32()))
         # CHECK: memref.global @objFifo_in0 : memref<16xi32>
         print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
+ at run
+def testSubViewOpInferReturnTypeSemantics():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+            print(x.owner)
+
+            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
+            assert y.owner.verify()
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(y.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 1), 1],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(z.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
+            print(z.owner)
+
+            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
+            z = memref.subview(
+                x,
+                [s, 0],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(z)
+
+            try:
+                _infer_memref_subview_result_type(
+                    x.type,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: Only inferring from python or mlir integer constant is supported
+                print(e)
+
+            try:
+                memref.subview(
+                    x,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
+                print(e)
+
+            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
+            x = memref.alloc(
+                T.memref(
+                    10,
+                    10,
+                    T.i32(),
+                    layout=layout,
+                ),
+                [],
+                [arith.constant(T.index(), 42)],
+            )
+            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
+            print(x.owner)
+            y = memref.subview(
+                x,
+                [1, 1],
+                [3, 3],
+                [1, 1],
+                result_type=T.memref(3, 3, T.i32(), layout=layout),
+            )
+            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(y.owner)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
+ at run
+def testSubViewOpInferReturnTypeExtensiveSlicing():
+    def check_strides_offset(mem, np_view):
+        layout = mem.type.layout
+        dtype_size_in_bytes = np_view.dtype.itemsize
+        golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
+        golden_offset = (
+            np_view.ctypes.data - np_view.base.ctypes.data
+        ) // dtype_size_in_bytes
+
+        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)
+
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            shape = (10, 22, 333, 4444)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+
+            # fmt: off
+            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, ...])
+            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 333, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 333, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
+            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
+            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
+            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
+            # fmt: on
+
+            # Default strides and offset mean no StridedLayoutAttr, i.e., an AffineMap (identity) layout.
+            assert memref.subview(
+                mem1, (0, 0, 0, 0), (10, 22, 333, 4444), (1, 1, 1, 1)
+            ).type.layout == AffineMapAttr.get(
+                AffineMap.get(
+                    4,
+                    0,
+                    [
+                        AffineDimExpr.get(0),
+                        AffineDimExpr.get(1),
+                        AffineDimExpr.get(2),
+                        AffineDimExpr.get(3),
+                    ],
+                )
+            )
+
+            shape = (7, 22, 333, 4444)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+            # fmt: off
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 333, 4444), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 4444), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
+            check_strides_offset(memref.subview(mem2, (0, 0, 100, 1000), (7, 22, 20, 20), (1, 1, 5, 50)), golden_mem[:, :, 100:200:5, 1000:2000:50])
+            # fmt: on
+
+            shape = (8, 8)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            # fmt: off
+            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
+            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
+            # fmt: on

>From 3d5939e4ea3f6cb062fb57f8e249a65a9b18aca8 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 29 Jan 2024 13:40:07 -0600
Subject: [PATCH 2/2] incorporate comments

---
 mlir/include/mlir-c/BuiltinTypes.h            |   5 +-
 mlir/lib/Bindings/Python/IRTypes.cpp          |  11 +-
 mlir/lib/CAPI/IR/BuiltinTypes.cpp             |  18 +-
 mlir/python/mlir/dialects/_ods_common.py      | 174 +++++++++++++++++-
 mlir/python/mlir/dialects/memref.py           |  23 ++-
 .../mlir/dialects/transform/structured.py     | 169 ++---------------
 mlir/test/python/dialects/memref.py           |   4 +-
 7 files changed, 222 insertions(+), 182 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 2523bddc475d823..881b6dad2b84d77 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -411,9 +411,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 /// Returns the strides of the MemRef if the layout map is in strided form.
 /// Both strides and offset are out params. strides must point to pre-allocated
 /// memory of length equal to the rank of the memref.
-MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
-                                                          int64_t *strides,
-                                                          int64_t *offset);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
+    MlirType type, int64_t *strides, int64_t *offset);
 
 /// Returns the memory space of the given Unranked MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 86f01a6381ae4e0..c87f791e93fb84d 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Support.h"
+
 #include <optional>
 
 namespace py = pybind11;
@@ -618,12 +620,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               return mlirMemRefTypeGetLayout(self);
             },
             "The layout of the MemRef type.")
-        .def_property_readonly(
-            "strides_and_offset",
+        .def(
+            "get_strides_and_offset",
             [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
               std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
               int64_t offset;
-              mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+              if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+                      self, strides.data(), &offset)))
+                throw std::runtime_error(
+                    "failed to extract strides and offset from memref");
               return {strides, offset};
             },
             "The strides and offset of the MemRef type.")
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6a3653d8baf304a..f31fd14eb773167 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -9,12 +9,14 @@
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include <algorithm>
 
@@ -428,16 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
-void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
-                                       int64_t *offset) {
+MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                    int64_t *strides,
+                                                    int64_t *offset) {
   MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
-  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
-      getStridesAndOffset(memrefType);
+  SmallVector<int64_t> strides_;
+  if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
+    return mlirLogicalResultFailure();
+
-  assert(stridesOffsets.first.size() == memrefType.getRank() &&
-         "Strides and rank don't match for memref");
+  assert(strides_.size() == static_cast<size_t>(memrefType.getRank()) &&
+         "Strides and rank don't match for memref");
-  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
-                  strides);
-  *offset = stridesOffsets.second;
+  (void)std::copy(strides_.begin(), strides_.end(), strides);
+  return mlirLogicalResultSuccess();
 }
 
 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 1685124fbccdc9f..3af3b5ce73bc60a 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -2,16 +2,31 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-# Provide a convenient name for sub-packages to resolve the main C-extension
-# with a relative import.
-from .._mlir_libs import _mlir as _cext
 from typing import (
+    Any as _Any,
+    List as _List,
+    Optional as _Optional,
     Sequence as _Sequence,
+    Tuple as _Tuple,
     Type as _Type,
     TypeVar as _TypeVar,
     Union as _Union,
 )
 
+from .._mlir_libs import _mlir as _cext
+from ..ir import (
+    ArrayAttr,
+    Attribute,
+    BoolAttr,
+    DenseI64ArrayAttr,
+    IntegerAttr,
+    IntegerType,
+    OpView,
+    Operation,
+    ShapedType,
+    Value,
+)
+
 __all__ = [
     "equally_sized_accessor",
     "get_default_loc_context",
@@ -138,3 +152,157 @@ def get_op_result_or_op_results(
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]
 VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
+
+StaticIntLike = _Union[int, IntegerAttr]
+ValueLike = _Union[Operation, OpView, Value]
+MixedInt = _Union[StaticIntLike, ValueLike]
+
+IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
+OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
+
+BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
+OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
+
+MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
+
+DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
+
+
+def _dispatch_dynamic_index_list(
+    indices: _Union[DynamicIndexList, ArrayAttr],
+) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
+    """Dispatches a list of indices to the appropriate form.
+
+    This is similar to the custom `DynamicIndexList` directive upstream:
+    provided indices may be in the form of dynamic SSA values or static values,
+    and they may be scalable (i.e., as a singleton list) or not. This function
+    dispatches each index into its respective form. It also extracts the SSA
+    values and static indices from various similar structures, respectively.
+    """
+    dynamic_indices = []
+    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+    scalable_indices = [False] * len(indices)
+
+    # ArrayAttr: Extract index values.
+    if isinstance(indices, ArrayAttr):
+        indices = [idx for idx in indices]
+
+    def process_nonscalable_index(i, index):
+        """Processes any form of non-scalable index.
+
+        Returns False if the given index was scalable and thus remains
+        unprocessed; True otherwise.
+        """
+        if isinstance(index, int):
+            static_indices[i] = index
+        elif isinstance(index, IntegerAttr):
+            static_indices[i] = index.value  # pytype: disable=attribute-error
+        elif isinstance(index, (Operation, Value, OpView)):
+            dynamic_indices.append(index)
+        else:
+            return False
+        return True
+
+    # Process the indices one at a time.
+    for i, index in enumerate(indices):
+        if not process_nonscalable_index(i, index):
+            # If it wasn't processed, it must be a scalable index, which is
+            # provided as a Sequence of one value, so extract and process that.
+            scalable_indices[i] = True
+            assert len(index) == 1
+            ret = process_nonscalable_index(i, index[0])
+            assert ret
+
+    return dynamic_indices, static_indices, scalable_indices
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+#   - `dynamic_values`: a list of `Value`s, potentially from op results;
+#   - `packed_values`: a value handle, potentially from an op result, associated
+#                      to one or more payload operations of integer type;
+#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
+#                      `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
+# If the input is in the form for `packed_values`, only that result is set and
+# the other two are empty. Otherwise, the input can be a mix of the other two
+# forms; for each dynamic value, a dynamic-size sentinel is added to `static_values`.
+def _dispatch_mixed_values(
+    values: MixedValues,
+) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+    dynamic_values = []
+    packed_values = None
+    static_values = None
+    if isinstance(values, ArrayAttr):
+        static_values = values
+    elif isinstance(values, (Operation, Value, OpView)):
+        packed_values = values
+    else:
+        static_values = []
+        for size in values or []:
+            if isinstance(size, int):
+                static_values.append(size)
+            else:
+                static_values.append(ShapedType.get_dynamic_size())
+                dynamic_values.append(size)
+        static_values = DenseI64ArrayAttr.get(static_values)
+
+    return (dynamic_values, packed_values, static_values)
+
+
+def _get_value_or_attribute_value(
+    value_or_attr: _Union[_Any, Attribute, ArrayAttr]
+) -> _Any:
+    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
+        return value_or_attr.value
+    if isinstance(value_or_attr, ArrayAttr):
+        return _get_value_list(value_or_attr)
+    return value_or_attr
+
+
+def _get_value_list(
+    sequence_or_array_attr: _Union[_Sequence[_Any], ArrayAttr]
+) -> _Sequence[_Any]:
+    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(
+    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
+) -> ArrayAttr:
+    if values is None:
+        return None
+
+    # Turn into a Python list of Python ints.
+    values = _get_value_list(values)
+
+    # Make an ArrayAttr of IntegerAttrs out of it.
+    return ArrayAttr.get(
+        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+    )
+
+
+def _get_int_array_array_attr(
+    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
+) -> ArrayAttr:
+    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
+
+    The input has to be a collection of collections of integers, where any
+    Python Sequence and ArrayAttr are admissible collections and Python ints and
+    any IntegerAttr are admissible integers. Both levels of collections are
+    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
+    If the input is None, None is returned.
+    """
+    if values is None:
+        return None
+
+    # Make sure the outer level is a list.
+    values = _get_value_list(values)
+
+    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+    # Sequences. Make sure the nested values are all lists.
+    values = [_get_value_list(nested) for nested in values]
+
+    # Turn each nested list into an ArrayAttr.
+    values = [_get_int_array_attr(nested) for nested in values]
+
+    # Turn the outer list into an ArrayAttr.
+    return ArrayAttr.get(values)
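
To illustrate the now-shared helper, a sketch (assuming an active Context
and an index-typed SSA value `c`, e.g. from arith.constant):

    dynamic, packed, static = _dispatch_mixed_values([1, c, 3])
    # dynamic == [c]; packed is None; static is a DenseI64ArrayAttr holding
    # [1, ShapedType.get_dynamic_size(), 3].
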
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 6ab6e0602e7a95d..3ff032bb02f4ecd 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -6,8 +6,8 @@
 from typing import Optional
 
 from ._memref_ops_gen import *
+from ._ods_common import _dispatch_mixed_values, MixedValues
 from .arith import ConstantOp
-from .transform.structured import _dispatch_mixed_values, MixedValues
 from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
 
 
@@ -22,20 +22,23 @@ def _is_static(i):
 def _infer_memref_subview_result_type(
     source_memref_type, offsets, static_sizes, static_strides
 ):
-    source_strides, source_offset = source_memref_type.strides_and_offset
+    source_strides, source_offset = source_memref_type.get_strides_and_offset()
     # "canonicalize" from tuple|list -> list
     offsets, static_sizes, static_strides, source_strides = map(
         list, (offsets, static_sizes, static_strides, source_strides)
     )
 
-    assert all(
+    if not all(
         all(_is_static(i) for i in s)
         for s in [
             static_sizes,
             static_strides,
             source_strides,
         ]
-    ), f"Only inferring from python or mlir integer constant is supported"
+    ):
+        raise ValueError(
+            "Only inferring from python or mlir integer constant is supported."
+        )
 
     for s in [offsets, static_sizes, static_strides]:
         for idx, i in enumerate(s):
@@ -91,22 +94,22 @@ def subview(
         sizes = []
     if strides is None:
         strides = []
-    source_strides, source_offset = source.type.strides_and_offset
+    source_strides, source_offset = source.type.get_strides_and_offset()
     if result_type is None and all(
         all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
     ):
         # If any are arith.constant results then this will canonicalize to python int
-        # (which can then be used to fully specific the subview).
+        # (which can then be used to fully specify the subview).
         (
             offsets,
             sizes,
             strides,
             result_type,
         ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
-    else:
-        assert (
-            result_type is not None
-        ), "mixed static/dynamic offset/sizes/strides requires explicit result type"
+    elif result_type is None:
+        raise ValueError(
+            "mixed static/dynamic offset/sizes/strides requires explicit result type."
+        )
 
     offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
     sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
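
Calling-side sketch of the new error behavior (mirroring the updated tests
below): a dynamic size without an explicit result type now raises ValueError
rather than tripping an assert.

    try:
        memref.subview(x, [0, 0], [ShapedType.get_dynamic_size(), 3], [1, 1])
    except ValueError as e:
        # mixed static/dynamic offset/sizes/strides requires explicit result type.
        print(e)
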
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 284c93823acbd34..d7b41c0bd2207d1 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -9,163 +9,24 @@
 try:
     from ...ir import *
     from ...dialects import transform
-    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        DynamicIndexList,
+        IntOrAttrList,
+        MixedValues,
+        OptionalBoolList,
+        OptionalIntList,
+        _cext as _ods_cext,
+        _dispatch_dynamic_index_list,
+        _dispatch_mixed_values,
+        _get_int_array_array_attr,
+        _get_int_array_attr,
+        _get_value_list,
+        _get_value_or_attribute_value,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import List, Optional, Sequence, Tuple, Union, overload
-
-StaticIntLike = Union[int, IntegerAttr]
-ValueLike = Union[Operation, OpView, Value]
-MixedInt = Union[StaticIntLike, ValueLike]
-
-IntOrAttrList = Sequence[Union[IntegerAttr, int]]
-OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
-
-BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
-OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
-
-MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
-
-DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
-
-
-def _dispatch_dynamic_index_list(
-    indices: Union[DynamicIndexList, ArrayAttr],
-) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
-    """Dispatches a list of indices to the appropriate form.
-
-    This is similar to the custom `DynamicIndexList` directive upstream:
-    provided indices may be in the form of dynamic SSA values or static values,
-    and they may be scalable (i.e., as a singleton list) or not. This function
-    dispatches each index into its respective form. It also extracts the SSA
-    values and static indices from various similar structures, respectively.
-    """
-    dynamic_indices = []
-    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
-    scalable_indices = [False] * len(indices)
-
-    # ArrayAttr: Extract index values.
-    if isinstance(indices, ArrayAttr):
-        indices = [idx for idx in indices]
-
-    def process_nonscalable_index(i, index):
-        """Processes any form of non-scalable index.
-
-        Returns False if the given index was scalable and thus remains
-        unprocessed; True otherwise.
-        """
-        if isinstance(index, int):
-            static_indices[i] = index
-        elif isinstance(index, IntegerAttr):
-            static_indices[i] = index.value  # pytype: disable=attribute-error
-        elif isinstance(index, (Operation, Value, OpView)):
-            dynamic_indices.append(index)
-        else:
-            return False
-        return True
-
-    # Process each index at a time.
-    for i, index in enumerate(indices):
-        if not process_nonscalable_index(i, index):
-            # If it wasn't processed, it must be a scalable index, which is
-            # provided as a Sequence of one value, so extract and process that.
-            scalable_indices[i] = True
-            assert len(index) == 1
-            ret = process_nonscalable_index(i, index[0])
-            assert ret
-
-    return dynamic_indices, static_indices, scalable_indices
-
-
-# Dispatches `MixedValues` that all represents integers in various forms into
-# the following three categories:
-#   - `dynamic_values`: a list of `Value`s, potentially from op results;
-#   - `packed_values`: a value handle, potentially from an op result, associated
-#                      to one or more payload operations of integer type;
-#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
-#                      `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
-# The input is in the form for `packed_values`, only that result is set and the
-# other two are empty. Otherwise, the input can be a mix of the other two forms,
-# and for each dynamic value, a special value is added to the `static_values`.
-def _dispatch_mixed_values(
-    values: MixedValues,
-) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
-    dynamic_values = []
-    packed_values = None
-    static_values = None
-    if isinstance(values, ArrayAttr):
-        static_values = values
-    elif isinstance(values, (Operation, Value, OpView)):
-        packed_values = values
-    else:
-        static_values = []
-        for size in values or []:
-            if isinstance(size, int):
-                static_values.append(size)
-            else:
-                static_values.append(ShapedType.get_dynamic_size())
-                dynamic_values.append(size)
-        static_values = DenseI64ArrayAttr.get(static_values)
-
-    return (dynamic_values, packed_values, static_values)
-
-
-def _get_value_or_attribute_value(
-    value_or_attr: Union[any, Attribute, ArrayAttr]
-) -> any:
-    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
-        return value_or_attr.value
-    if isinstance(value_or_attr, ArrayAttr):
-        return _get_value_list(value_or_attr)
-    return value_or_attr
-
-
-def _get_value_list(
-    sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
-) -> Sequence[any]:
-    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
-
-
-def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
-    if values is None:
-        return None
-
-    # Turn into a Python list of Python ints.
-    values = _get_value_list(values)
-
-    # Make an ArrayAttr of IntegerAttrs out of it.
-    return ArrayAttr.get(
-        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
-    )
-
-
-def _get_int_array_array_attr(
-    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
-) -> ArrayAttr:
-    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
-
-    The input has to be a collection of collection of integers, where any
-    Python Sequence and ArrayAttr are admissible collections and Python ints and
-    any IntegerAttr are admissible integers. Both levels of collections are
-    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
-    If the input is None, an empty ArrayAttr is returned.
-    """
-    if values is None:
-        return None
-
-    # Make sure the outer level is a list.
-    values = _get_value_list(values)
-
-    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
-    # Sequences. Make sure the nested values are all lists.
-    values = [_get_value_list(nested) for nested in values]
-
-    # Turn each nested list into an ArrayAttr.
-    values = [_get_int_array_attr(nested) for nested in values]
-
-    # Turn the outer list into an ArrayAttr.
-    return ArrayAttr.get(values)
+from typing import List, Optional, Sequence, Union, overload
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0cf2fe15384fbd1..162c22aedbdc863 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -142,7 +142,7 @@ def testSubViewOpInferReturnTypeSemantics():
                     [ShapedType.get_dynamic_size(), 3],
                     [1, 1],
                 )
-            except AssertionError as e:
+            except ValueError as e:
                 # CHECK: Only inferring from python or mlir integer constant is supported
                 print(e)
 
@@ -153,7 +153,7 @@ def testSubViewOpInferReturnTypeSemantics():
                     [ShapedType.get_dynamic_size(), 3],
                     [1, 1],
                 )
-            except AssertionError as e:
+            except ValueError as e:
                 # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                 print(e)
 


