[Mlir-commits] [mlir] Folding extract_strided_metadata input into reinterpret_cast on constant layout (PR #134845)

Matthias Springer llvmlistbot at llvm.org
Tue Apr 8 23:59:06 PDT 2025


================
@@ -2045,44 +2050,65 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
         op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
     if (!extractStridedMetadata)
       return failure();
+
     // Check if the reinterpret cast reconstructs a memref with the exact same
     // properties as the extract strided metadata.
-
-    // First, check that the strides are the same.
     SmallVector<OpFoldResult> extractStridesOfr =
         extractStridedMetadata.getConstifiedMixedStrides();
     SmallVector<OpFoldResult> reinterpretStridesOfr =
         op.getConstifiedMixedStrides();
-    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
-      return failure();
+    auto isReinterpretCastNoop = [&]() -> bool {
+      // First, check that the strides are the same.
+      if (extractStridesOfr.size() != reinterpretStridesOfr.size())
+        return false;
 
-    unsigned rank = op.getType().getRank();
-    for (unsigned i = 0; i < rank; ++i) {
-      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
-        return failure();
-    }
+      unsigned rank = op.getType().getRank();
+      for (unsigned i = 0; i < rank; ++i) {
+        if (extractStridesOfr[i] != reinterpretStridesOfr[i])
+          return false;
+      }
 
-    // Second, check the sizes.
-    assert(extractStridedMetadata.getSizes().size() ==
-               op.getMixedSizes().size() &&
-           "Strides and sizes rank must match");
-    SmallVector<OpFoldResult> extractSizesOfr =
-        extractStridedMetadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> reinterpretSizesOfr =
-        op.getConstifiedMixedSizes();
-    for (unsigned i = 0; i < rank; ++i) {
-      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
-        return failure();
+      // Second, check the sizes.
+      assert(extractStridedMetadata.getSizes().size() ==
+                op.getMixedSizes().size() &&
+            "Strides and sizes rank must match");
+      SmallVector<OpFoldResult> extractSizesOfr =
+          extractStridedMetadata.getConstifiedMixedSizes();
+      SmallVector<OpFoldResult> reinterpretSizesOfr =
+          op.getConstifiedMixedSizes();
+      for (unsigned i = 0; i < rank; ++i) {
+        if (extractSizesOfr[i] != reinterpretSizesOfr[i])
+          return false;
+      }
+      // Finally, check the offset.
+      assert(op.getMixedOffsets().size() == 1 &&
+            "reinterpret_cast with more than one offset should have been "
+            "rejected by the verifier");
+      OpFoldResult extractOffsetOfr =
+          extractStridedMetadata.getConstifiedMixedOffset();
+      OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+      return extractOffsetOfr == reinterpretOffsetOfr;
+    };
+
+    if (!isReinterpretCastNoop()) {
+      // If the extract_strided_metadata / reinterpret_cast pair can't be
+      // completely folded, we can still forward the source of the
+      // extract_strided_metadata directly into the input of the
+      // reinterpret_cast. In some cases (e.g., static dimensions) the
+      // extract_strided_metadata then becomes dead and is removed by dead
+      // code elimination.
+      //
+      // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+      //
+      // We can always fold the input of an extract_strided_metadata operator
+      // into the input of a reinterpret_cast operator, because they point to
+      // the same memory. Note that the reinterpret_cast does not use the
+      // layout of its input memref, only its base memory pointer which is
+      // the same as the base pointer returned by the extract_strided_metadata
+      // operator and the base pointer of the extract_strided_metadata memref
+      // input.
+      op.setOperand(0, extractStridedMetadata.getSource());
----------------
matthias-springer wrote:

Wrap this in `rewriter.modifyOpInPlace`.

https://github.com/llvm/llvm-project/pull/134845


More information about the Mlir-commits mailing list