[Mlir-commits] [mlir] Folding extract_strided_metadata input into reinterpret_cast on constant layout (PR #134845)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 9 04:29:41 PDT 2025
================
@@ -2045,44 +2050,65 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
if (!extractStridedMetadata)
return failure();
+
// Check if the reinterpret cast reconstructs a memref with the exact same
// properties as the extract strided metadata.
-
- // First, check that the strides are the same.
SmallVector<OpFoldResult> extractStridesOfr =
extractStridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> reinterpretStridesOfr =
op.getConstifiedMixedStrides();
- if (extractStridesOfr.size() != reinterpretStridesOfr.size())
- return failure();
+ auto isReinterpretCastNoop = [&]() -> bool {
+ // First, check that the strides are the same.
+ if (extractStridesOfr.size() != reinterpretStridesOfr.size())
+ return false;
- unsigned rank = op.getType().getRank();
- for (unsigned i = 0; i < rank; ++i) {
- if (extractStridesOfr[i] != reinterpretStridesOfr[i])
- return failure();
- }
+ unsigned rank = op.getType().getRank();
+ for (unsigned i = 0; i < rank; ++i) {
+ if (extractStridesOfr[i] != reinterpretStridesOfr[i])
+ return false;
+ }
- // Second, check the sizes.
- assert(extractStridedMetadata.getSizes().size() ==
- op.getMixedSizes().size() &&
- "Strides and sizes rank must match");
- SmallVector<OpFoldResult> extractSizesOfr =
- extractStridedMetadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> reinterpretSizesOfr =
- op.getConstifiedMixedSizes();
- for (unsigned i = 0; i < rank; ++i) {
- if (extractSizesOfr[i] != reinterpretSizesOfr[i])
- return failure();
+ // Second, check the sizes.
+ assert(extractStridedMetadata.getSizes().size() ==
+ op.getMixedSizes().size() &&
+ "Strides and sizes rank must match");
+ SmallVector<OpFoldResult> extractSizesOfr =
+ extractStridedMetadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> reinterpretSizesOfr =
+ op.getConstifiedMixedSizes();
+ for (unsigned i = 0; i < rank; ++i) {
+ if (extractSizesOfr[i] != reinterpretSizesOfr[i])
+ return false;
+ }
+ // Finally, check the offset.
+ assert(op.getMixedOffsets().size() == 1 &&
+ "reinterpret_cast with more than one offset should have been "
+ "rejected by the verifier");
+ OpFoldResult extractOffsetOfr =
+ extractStridedMetadata.getConstifiedMixedOffset();
+ OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+ return extractOffsetOfr == reinterpretOffsetOfr;
+ };
+
+ if (!isReinterpretCastNoop()) {
+ // If the extract_strided_metadata / reinterpret_cast pair can't be
+ // completely folded, then we could fold the input of the
+ // extract_strided_metadata into the input of the reinterpret_cast.
+ // For some cases (e.g., static dimensions) the
+ // extract_strided_metadata is eliminated by dead code elimination.
+ //
+ // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+ //
+ // We can always fold the input of an extract_strided_metadata operator
+ // to the input of a reinterpret_cast operator, because they point to
+ // the same memory. Note that the reinterpret_cast does not use the
+ // layout of its input memref, only its base memory pointer which is
+ // the same as the base pointer returned by the extract_strided_metadata
+ // operator and the base pointer of the extract_strided_metadata memref
+ // input.
+ op.setOperand(0, extractStridedMetadata.getSource());
----------------
ivangarcia44 wrote:
Done both, thanks
https://github.com/llvm/llvm-project/pull/134845
More information about the Mlir-commits
mailing list