[Mlir-commits] [mlir] [mlir][python] enable memref.subview (PR #79393)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Jan 29 06:03:11 PST 2024
================
@@ -1,5 +1,126 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
from ._memref_ops_gen import *
+from .arith import ConstantOp
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _is_constant(i):
+    return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
+
+
+def _is_static(i):
+    return (isinstance(i, int) and not ShapedType.is_dynamic_size(i)) or _is_constant(i)
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, offsets, static_sizes, static_strides
+):
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    # "canonicalize" from tuple|list -> list
+    offsets, static_sizes, static_strides, source_strides = map(
+        list, (offsets, static_sizes, static_strides, source_strides)
+    )
+
+    assert all(
+        all(_is_static(i) for i in s)
+        for s in [
+            static_sizes,
+            static_strides,
+            source_strides,
+        ]
+    ), f"Only inferring from python or mlir integer constant is supported"
----------------
ftynse wrote:
Nit: end Python error messages with a period.
Also nit: why the `f` prefix here? The string has no placeholders.
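
For illustration only, a minimal sketch of what both nits amount to. The helper `check_all_static` and its `isinstance(v, int)` check are hypothetical stand-ins, not the code in this PR; only the message text is taken from the diff above.

```python
# Hypothetical stand-in for the assertion in the diff; not the MLIR Python API.
def check_all_static(values):
    # Plain string literal: there are no placeholders, so no `f` prefix is
    # needed, and the message now ends with a period.
    assert all(isinstance(v, int) for v in values), (
        "Only inferring from python or mlir integer constant is supported."
    )
    return True


# An f-string without placeholders evaluates to the same str as the plain
# literal, so dropping the prefix does not change behavior.
assert f"no placeholders." == "no placeholders."
```

The `f` prefix would only be needed if the message interpolated a value, for example the offending operand.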
https://github.com/llvm/llvm-project/pull/79393