[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Jan 29 06:03:11 PST 2024
================
@@ -1,5 +1,126 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
from ._memref_ops_gen import *
+from .arith import ConstantOp
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
def _is_constant(i):
    """Return True if ``i`` is an MLIR ``Value`` produced by an ``arith.constant`` op.

    Block arguments (e.g. a loop induction variable) are ``Value``s too, but
    their ``owner`` is a ``Block`` rather than an operation, so they are
    reported as non-constant instead of raising ``AttributeError``.
    """
    if not isinstance(i, Value):
        return False
    owner = i.owner
    # Op results are owned by an operation whose `.opview` gives the typed
    # view; some binding versions hand back the view directly. Block
    # arguments are owned by a `Block`, which has neither, so fall back to the
    # owner itself and let the isinstance check reject it.
    view = getattr(owner, "opview", owner)
    return isinstance(view, ConstantOp)
+
+
def _is_static(i):
    """Tell whether ``i`` denotes a statically known integer.

    Accepted forms are a Python ``int`` that is not MLIR's dynamic-size
    sentinel, or anything ``_is_constant`` accepts (an ``arith.constant``
    result).
    """
    if isinstance(i, int) and not ShapedType.is_dynamic_size(i):
        return True
    return _is_constant(i)
+
+
def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infer the result type of a ``memref.subview`` from (mostly) static operands.

    Args:
        source_memref_type: ``MemRefType`` of the memref being viewed.
        offsets: Per-dimension offsets; Python ints or MLIR ``Value``s.
            ``arith.constant`` results are folded to Python ints; any other
            non-static offset makes the result offset dynamic.
        static_sizes: Per-dimension sizes; Python ints or ``arith.constant``
            results.
        static_strides: Per-dimension strides; Python ints or
            ``arith.constant`` results.

    Returns:
        A tuple ``(offsets, static_sizes, static_strides, result_type)``. The
        three lists are copies of the inputs with ``arith.constant`` values
        replaced by their Python literals. ``result_type`` is the inferred
        ``MemRefType``: it has no explicit layout when the view is row-major
        contiguous with offset 0, and a ``StridedLayoutAttr`` otherwise.

    Raises:
        ValueError: if any size, any stride, or any stride of the source
            memref is not statically known.
    """
    source_strides, source_offset = source_memref_type.strides_and_offset
    # "canonicalize" from tuple|list -> list. These are fresh copies, so the
    # in-place constant folding below never mutates the caller's sequences.
    offsets, static_sizes, static_strides, source_strides = map(
        list, (offsets, static_sizes, static_strides, source_strides)
    )

    # Raise rather than assert: assertions are stripped under `python -O`.
    if not all(
        all(_is_static(i) for i in s)
        for s in (static_sizes, static_strides, source_strides)
    ):
        raise ValueError(
            "Only inferring from python or mlir integer constant is supported"
        )

    # Fold `arith.constant` results into their Python literal values.
    for s in (offsets, static_sizes, static_strides):
        for idx, i in enumerate(s):
            if _is_constant(i):
                s[idx] = i.owner.opview.literal_value

    if any(not _is_static(i) for i in offsets + [source_offset]):
        # An offset is a stride-or-offset quantity, so use that sentinel.
        target_offset = ShapedType.get_dynamic_stride_or_offset()
    else:
        target_offset = source_offset
        for offset, source_stride in zip(offsets, source_strides):
            target_offset += offset * source_stride

    target_strides = [
        source_stride * static_stride
        for source_stride, static_stride in zip(source_strides, static_strides)
    ]

    # Identity (row-major contiguous) strides for `static_sizes`: the innermost
    # stride is 1 and each outer stride is the product of all inner sizes.
    # A rank-0 memref has no strides at all.
    if static_sizes:
        default_strides = list(
            accumulate([1, *reversed(static_sizes[1:])], operator.mul)
        )[::-1]
    else:
        default_strides = []

    # If default striding then no need to complicate things for downstream ops
    # (e.g., expand_shape).
    if target_strides == default_strides and target_offset == 0:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)
    return (
        offsets,
        static_sizes,
        static_strides,
        MemRefType.get(
            static_sizes,
            source_memref_type.element_type,
            layout,
            source_memref_type.memory_space,
        ),
    )
+
+
+_generated_subview = subview  # generated `subview` builder, saved before the wrapper below shadows the name
+
+
+def subview(
+ source: Value,
+ offsets: MixedValues,
+ sizes: MixedValues,
+ strides: MixedValues,
+ *,
+ result_type: Optional[MemRefType] = None,
+ loc=None,
+ ip=None,
+):
+ if offsets is None:
+ offsets = []
+ if sizes is None:
+ sizes = []
+ if strides is None:
+ strides = []
+ source_strides, source_offset = source.type.strides_and_offset
+ if result_type is None and all(
+ all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
+ ):
+ # If any are arith.constant results then this will canonicalize to python int
+ # (which can then be used to fully specific the subview).
+ (
+ offsets,
+ sizes,
+ strides,
+ result_type,
+ ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+ else:
+ assert (
----------------
ftynse wrote:
Nit: should this rather be `raise ValueError`?
https://github.com/llvm/llvm-project/pull/79393
More information about the Mlir-commits
mailing list