[all-commits] [llvm/llvm-project] 5083e8: Folding extract_strided_metadata input into reinte...
ivangarcia44 via All-commits
all-commits at lists.llvm.org
Wed Apr 9 07:51:18 PDT 2025
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: 5083e80c14a5c1f0ab40b5df95771ebbdda1adb2
https://github.com/llvm/llvm-project/commit/5083e80c14a5c1f0ab40b5df95771ebbdda1adb2
Author: ivangarcia44 <36650061+ivangarcia44 at users.noreply.github.com>
Date: 2025-04-09 (Wed, 09 Apr 2025)
Changed paths:
M mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
M mlir/test/Dialect/MemRef/canonicalize.mlir
Log Message:
-----------
Folding extract_strided_metadata input into reinterpret_cast (#134845)
We can always fold the input of an extract_strided_metadata operation into
the source operand of a reinterpret_cast operation, because both refer to
the same memory. Note that reinterpret_cast does not use the layout of its
input memref, only its base memory pointer. That pointer is the same as the
base pointer returned by the extract_strided_metadata operation, which in
turn is the base pointer of the extract_strided_metadata memref input.
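The following is a minimal Python sketch of a strided-memref descriptor that makes the "same memory" argument concrete. It is a toy model, not MLIR's API. The class StridedMemRef and the helpers extract_strided_metadata, reinterpret_cast, and load are hypothetical stand-ins for the ops, and a shared Python list plays the role of the base pointer. Under those assumptions, casting from the extracted base buffer and casting from the original memref produce identical descriptors, even when the input has a nonzero offset, because the cast ignores the layout of its source.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class StridedMemRef:
    """Hypothetical toy descriptor: `buffer` stands in for the base pointer,
    offset/sizes/strides describe the layout applied on top of it."""
    buffer: list
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]


def extract_strided_metadata(m: StridedMemRef):
    # The base buffer is a rank-0 view of the same storage with offset 0.
    base = StridedMemRef(m.buffer, 0, (), ())
    return base, m.offset, m.sizes, m.strides


def reinterpret_cast(src: StridedMemRef, offset, sizes, strides) -> StridedMemRef:
    # Only src.buffer (the base pointer) is read; src's own layout is ignored.
    return StridedMemRef(src.buffer, offset, tuple(sizes), tuple(strides))


def load(m: StridedMemRef, *idx: int):
    # Linearized address: offset + sum(i_k * stride_k).
    return m.buffer[m.offset + sum(i * s for i, s in zip(idx, m.strides))]


buf = list(range(16))
# A 2x4 row-major view that starts at element 3 of the buffer.
inp = StridedMemRef(buf, 3, (2, 4), (4, 1))
base, off, sizes, strides = extract_strided_metadata(inp)

# Collapse 2x4 -> 8 either from the base buffer (lowered form)
# or directly from the input memref (folded form).
via_base = reinterpret_cast(base, off, (8,), (1,))
via_input = reinterpret_cast(inp, off, (8,), (1,))

print(via_base == via_input)  # True
```

Both calls read only `src.buffer`, so which source is passed cannot affect the result. That is the property the fold relies on.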
Operations like expand_shape, collapse_shape, and subview are lowered to
a pair of extract_strided_metadata and reinterpret_cast like this:
  %base_buffer, %offset, %sizes:2, %strides:2 =
      memref.extract_strided_metadata %input_memref :
      memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to
      offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
      memref<f32> to memref<OD1x...xODNxBaseType>
In many cases the input of the extract_strided_metadata operation can be
passed directly as the source of the reinterpret_cast operation, as in the
following (compared with the reinterpret_cast above, %base_buffer is
replaced by %input_memref and the source type is updated accordingly):
  %base_buffer, %offset, %sizes:2, %strides:2 =
      memref.extract_strided_metadata %input_memref :
      memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %input_memref to
      offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
      memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType>
When dealing with static dimensions, the extract_strided_metadata operation
becomes dead code and only the reinterpret_cast remains:
  %reinterpret_cast = memref.reinterpret_cast %input_memref to
      offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
      memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType>
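The following Python sketch models the fold and the dead-code step on a tiny hypothetical IR. Value, Op, make_op, fold_cast_source, drop_dead_metadata, and lower_and_fold are illustrative names only, not MLIR's C++ API. The toy assumes that static offsets, sizes, and strides are attributes rather than operands, so in the static case the cast's only operand is its source. After the source is rewired, nothing uses the metadata op and it can be erased. In the dynamic case the cast still reads %offset and a size, so the metadata op survives.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class Value:
    name: str
    owner: Optional["Op"] = None  # defining op; None for block arguments
    index: int = 0                # result number within the owner


@dataclass(eq=False)
class Op:
    kind: str
    operands: List[Value]
    results: List[Value] = field(default_factory=list)


def make_op(kind: str, operands: List[Value], result_names: List[str]) -> Op:
    op = Op(kind, list(operands))
    op.results = [Value(n, op, i) for i, n in enumerate(result_names)]
    return op


def fold_cast_source(ops: List[Op]) -> None:
    # Toy fold: a reinterpret_cast whose source is the base buffer (result 0)
    # of an extract_strided_metadata reads that op's input memref instead.
    for op in ops:
        if op.kind != "reinterpret_cast":
            continue
        src = op.operands[0]
        producer = src.owner
        if (producer is not None
                and producer.kind == "extract_strided_metadata"
                and src.index == 0):
            op.operands[0] = producer.operands[0]


def drop_dead_metadata(ops: List[Op]) -> List[Op]:
    # extract_strided_metadata has no side effects, so it can be erased
    # once none of its results are used by any op.
    used = {id(v) for op in ops for v in op.operands}
    return [op for op in ops
            if op.kind != "extract_strided_metadata"
            or any(id(r) in used for r in op.results)]


def lower_and_fold(dynamic: bool) -> List[Op]:
    inp = Value("%input_memref")
    ex = make_op("extract_strided_metadata", [inp],
                 ["%base_buffer", "%offset", "%size0", "%size1",
                  "%stride0", "%stride1"])
    # Dynamic layout: some metadata results feed the cast as extra operands.
    extra = [ex.results[1], ex.results[2]] if dynamic else []
    cast = make_op("reinterpret_cast", [ex.results[0]] + extra,
                   ["%reinterpret_cast"])
    ops = [ex, cast]
    fold_cast_source(ops)
    assert cast.operands[0] is inp  # the cast now reads %input_memref
    return drop_dead_metadata(ops)


print([op.kind for op in lower_and_fold(dynamic=False)])  # ['reinterpret_cast']
print([op.kind for op in lower_and_fold(dynamic=True)])
```

The fold and the cleanup are kept as separate steps on purpose. The rewrite itself is always applied, and whether the metadata op disappears depends only on whether anything else still uses its results.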
Note that reinterpret_cast only reads the base memory pointer from its
input memref (%input_memref above), which is equivalent to the
%base_buffer returned by the extract_strided_metadata operation. Hence
it is always legal to use the extract_strided_metadata input memref
directly in the reinterpret_cast. Because only a pointer is involved, the
rewrite remains legal even when the memory behind the base pointer is
modified between the two operations.
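The following Python sketch illustrates the last point. It assumes a hypothetical toy model in which a view is a (storage, offset, strides) tuple over a shared list, and base_buffer, reinterpret, and addr are illustrative helpers only. Neither form copies data, so a store made between the metadata extraction and the cast is observed identically through the lowered cast and the folded cast.

```python
from typing import List, Tuple

# Hypothetical toy model: a view is (storage, offset, strides) over a shared list.
View = Tuple[List[float], int, Tuple[int, ...]]


def base_buffer(v: View) -> View:
    # Like the base buffer result of extract_strided_metadata: same storage, offset 0.
    return (v[0], 0, ())


def reinterpret(src: View, offset: int, strides: Tuple[int, ...]) -> View:
    # Only the storage (the base pointer) of src is used.
    return (src[0], offset, strides)


def addr(v: View, idx: Tuple[int, ...]) -> int:
    # Linearized address: offset + sum(i_k * stride_k).
    return v[1] + sum(i * s for i, s in zip(idx, v[2]))


storage = [0.0] * 8
inp: View = (storage, 2, (1,))
base = base_buffer(inp)               # extract_strided_metadata
storage[addr(inp, (3,))] = 42.0       # write between the pair; no copy exists
lowered = reinterpret(base, 2, (1,))  # cast from %base_buffer
folded = reinterpret(inp, 2, (1,))    # cast from %input_memref
print(lowered[0][addr(lowered, (3,))], folded[0][addr(folded, (3,))])  # 42.0 42.0
```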
@matthias-springer
@joker-eph
@sahas3
@Hanumanth04
@dixinzhou
@rafaelubalmw
---------
Co-authored-by: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>