[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Jan 29 06:03:11 PST 2024


================
@@ -1,5 +1,126 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
 
 from ._memref_ops_gen import *
+from .arith import ConstantOp
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _is_constant(i):
+    """Return True if ``i`` is an MLIR Value produced by an arith ConstantOp.
+
+    Non-Value inputs (e.g. Python ints) return False. Block arguments return
+    False as well instead of raising: their owner is a Block, which does not
+    expose ``opview``.
+    """
+    if not isinstance(i, Value):
+        return False
+    # ``Value.owner`` is an operation for op results but a Block for block
+    # arguments; only operations have an ``opview`` to inspect.
+    owner_opview = getattr(i.owner, "opview", None)
+    return isinstance(owner_opview, ConstantOp)
+
+
+def _is_static(i):
+    """Return True if ``i`` denotes a statically known value.
+
+    A Python ``int`` qualifies unless it is MLIR's dynamic-size sentinel
+    (as reported by ``ShapedType.is_dynamic_size``); otherwise ``i`` qualifies
+    only if ``_is_constant`` accepts it.
+    """
+    if isinstance(i, int) and not ShapedType.is_dynamic_size(i):
+        return True
+    # Dynamic ints and non-int inputs fall through to the constant-Value check.
+    return _is_constant(i)
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, offsets, static_sizes, static_strides
+):
+    """Infer the result MemRefType of a memref.subview from static sizes/strides.
+
+    Args:
+      source_memref_type: MemRefType of the memref being sliced.
+      offsets: per-dimension offsets; Python ints, arith.constant results, or
+        other Values (the latter make the result offset dynamic).
+      static_sizes: per-dimension sizes; Python ints or arith.constant results.
+      static_strides: per-dimension strides; Python ints or arith.constant
+        results.
+
+    Returns:
+      A tuple ``(offsets, static_sizes, static_strides, result_type)``. The
+      three lists are fresh copies in which arith.constant results have been
+      replaced by their ``literal_value``. ``result_type`` is built with no
+      explicit layout when the result strides equal the default (suffix
+      product) strides and the offset is 0, otherwise with a StridedLayoutAttr.
+
+    Raises:
+      ValueError: if offsets/sizes/strides do not all match the source rank, or
+        if any size, stride, or source stride is not statically known.
+    """
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    # "canonicalize" from tuple|list -> list. These are fresh copies, so the
+    # in-place constant folding below never mutates the caller's sequences.
+    offsets, static_sizes, static_strides, source_strides = map(
+        list, (offsets, static_sizes, static_strides, source_strides)
+    )
+
+    # zip() below would silently truncate mismatched lists and produce a wrong
+    # offset/stride computation, so reject them up front.
+    rank = len(source_strides)
+    if not len(offsets) == len(static_sizes) == len(static_strides) == rank:
+        raise ValueError(
+            f"Expected {rank} offsets, sizes and strides to match the source "
+            f"rank, got {len(offsets)}, {len(static_sizes)} and "
+            f"{len(static_strides)}"
+        )
+
+    # An explicit exception rather than `assert`, so the check still runs
+    # under `python -O`.
+    if not all(
+        all(_is_static(i) for i in s)
+        for s in (static_sizes, static_strides, source_strides)
+    ):
+        raise ValueError(
+            "Only inferring from python or mlir integer constant is supported"
+        )
+
+    # Replace arith.constant results by their literal values.
+    for s in (offsets, static_sizes, static_strides):
+        for idx, i in enumerate(s):
+            if _is_constant(i):
+                s[idx] = i.owner.opview.literal_value
+
+    if any(not _is_static(i) for i in offsets + [source_offset]):
+        target_offset = ShapedType.get_dynamic_size()
+    else:
+        target_offset = source_offset
+        for offset, source_stride in zip(offsets, source_strides):
+            target_offset += offset * source_stride
+
+    target_strides = [
+        source_stride * static_stride
+        for source_stride, static_stride in zip(source_strides, static_strides)
+    ]
+
+    # If default striding then no need to complicate things for downstream ops
+    # (e.g., expand_shape). Default strides are the suffix products of the
+    # sizes, with 1 for the innermost dimension.
+    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
+    if target_strides == default_strides and target_offset == 0:
+        layout = None
+    else:
+        layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return (
+        offsets,
+        static_sizes,
+        static_strides,
+        MemRefType.get(
+            static_sizes,
+            source_memref_type.element_type,
+            layout,
+            source_memref_type.memory_space,
+        ),
+    )
+
+
+# Keep a handle on the `subview` that is in scope at this point -- presumably
+# the generated builder brought in by `from ._memref_ops_gen import *` --
+# before the hand-written `def subview` below shadows that name.
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: MixedValues,
+    sizes: MixedValues,
+    strides: MixedValues,
+    *,
+    result_type: Optional[MemRefType] = None,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if sizes is None:
+        sizes = []
+    if strides is None:
+        strides = []
+    source_strides, source_offset = source.type.strides_and_offset
+    if result_type is None and all(
+        all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
+    ):
+        # If any are arith.constant results then this will canonicalize to python int
+        # (which can then be used to fully specific the subview).
+        (
+            offsets,
+            sizes,
+            strides,
+            result_type,
+        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+    else:
+        assert (
----------------
ftynse wrote:

Nit: should this rather be `raise ValueError`?

https://github.com/llvm/llvm-project/pull/79393


More information about the Mlir-commits mailing list