[Mlir-commits] [mlir] [mlir] Implement a memory-space cast bubbling-down transform (PR #159454)

Fabian Mora llvmlistbot at llvm.org
Tue Sep 23 10:49:00 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/159454

From a15b8ca9bd287d4ad6af320cae7ae01afb4234bc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Wed, 17 Sep 2025 21:07:38 +0000
Subject: [PATCH 01/10] [mlir] Implement memory-space cast operand fusion into
 consumers

This commit adds functionality to fuse memory-space casts into consumer operations,
allowing those operations to act directly on the original memory space rather than
on a value first cast to a different memory space.

Key changes:
- Introduce `MemorySpaceCastOpInterface` to model memory-space cast operations
- Create a `FuseMemorySpaceCastsIntoConsumers` pass that identifies and fuses eligible casts
- Add implementations for memref and vector operations to handle memory-space cast fusion
- Add a `fuseCastOperands` method to the relevant operations to support the fusion

Note that the current implementation only fuses memory-space casts into the
default memory space.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --fuse-memory-space-casts-into-consumers
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```

Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRef.h  |   1 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  19 +-
 .../mlir/Dialect/Vector/IR/VectorOps.h        |   1 +
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  16 +-
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   1 +
 .../include/mlir/Interfaces/MemOpInterfaces.h |  37 +++
 .../mlir/Interfaces/MemOpInterfaces.td        | 114 +++++++
 .../FuseMemorySpaceCastsIntoConsumers.h       |  20 ++
 mlir/include/mlir/Transforms/Passes.h         |   1 +
 mlir/include/mlir/Transforms/Passes.td        |  40 +++
 mlir/lib/Dialect/MemRef/IR/CMakeLists.txt     |   1 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 158 ++++++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  64 ++++
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 mlir/lib/Interfaces/MemOpInterfaces.cpp       |  73 +++++
 mlir/lib/Transforms/CMakeLists.txt            |   2 +
 .../FuseMemorySpaceCastsIntoConsumers.cpp     |  73 +++++
 .../test-fuse-casts-into-consumers.mlir       | 281 ++++++++++++++++++
 18 files changed, 897 insertions(+), 7 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.h
 create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.td
 create mode 100644 mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
 create mode 100644 mlir/lib/Interfaces/MemOpInterfaces.cpp
 create mode 100644 mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
 create mode 100644 mlir/test/Transforms/test-fuse-casts-into-consumers.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bdec699eb4ce4..30f33ed2fd1d6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 671cc05e963b4..238a767ac8b73 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure,
       ViewLikeOpInterface,
-      SameOperandsAndResultType
+      SameOperandsAndResultType,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
 def MemRef_CastOp : MemRef_Op<"cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
 def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
 def MemRef_ReinterpretCastOp
   : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       MemRefsNormalizable,
       Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
 
 def MemRef_ReshapeOp: MemRef_Op<"reshape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     Pure,
     ViewLikeOpInterface]>  {
   let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+  ]> {
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     Pure]>,
     Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
 
 def MemRef_ViewOp : MemRef_Op<"view", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     Pure]> {
   let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
 //===----------------------------------------------------------------------===//
 
 def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 63410b8bea747..bbf55f5d507e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..93e9bfc78ea75 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
 
 def Vector_LoadOp : Vector_Op<"load", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
 
 def Vector_StoreOp : Vector_Op<"store", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload">,
+  Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore">,
+  Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
 def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
 }
 
 def Vector_ScatterOp :
-  Vector_Op<"scatter">,
+  Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$offsets,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
 }
 
 def Vector_ExpandLoadOp :
-  Vector_Op<"expandload">,
+  Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
 }
 
 def Vector_CompressStoreOp :
-  Vector_Op<"compressstore">,
+  Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 2add220fdfb7c..a5feb592045c0 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cc9f4c6b3882e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,37 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Tries to fuse in place a `MemorySpaceCastOpInterface` operation referenced
+/// by `operand`. On success, it returns `results` and sets `modifiedInPlace` to
+/// true. It returns failure if `operand` doesn't reference a
+/// `MemorySpaceCastOpInterface` op.
+FailureOr<SmallVector<Value>>
+fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
+                               bool &modifiedInPlace);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..0b8ba19171fb7
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,114 @@
+//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def FuseMemorySpaceCastConsumerOpInterface :
+    OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+  let description = [{
+    An interface to fuse memory-space cast operands into a consumer operation.
+    It is the responsibility of the interface to determine which casts can be
+    fused into the operation.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Attempts to fuse the incoming cast-like operands. Returns the new
+        results on fusion success; otherwise, it returns failure.
+        If new results are produced, these must be compatible with the original
+        operation results.
+
+        The `modifiedInPlace` parameter indicates whether the operation was
+        modified in place. If `false` and the fusion succeeded, then the
+        interface guarantees it is valid to erase the original operation.
+        If `true`, then the interface must guarantee no operations were created
+        by the method, and that no further IR modification is necessary. It is
+        considered an error if `modifiedInPlace` is true and the fusion failed.
+
+        Implementations of this method must not erase or replace the original
+        operation; instead, it is the caller's responsibility to erase or
+        replace the op with the results provided by the method.
+
+        Finally, any implementation of this method must guarantee that the
+        IR remains valid at all times.
+      }],
+      "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
+      (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+    >,
+  ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+  let description = [{
+    An interface for operations that perform memory-space casts. This
+    interface assumes that the cast operation is `pure`.
+
+    These operations are expected to have a well-defined ptr-like operand and
+    a well-defined target ptr-like result.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Returns the source ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>",  "getSourcePtr"
+    >,
+    InterfaceMethod<[{
+        Returns the target ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+    >,
+    InterfaceMethod<[{
+        Returns whether the memory space cast specified by `tgt` and `src`
+        is supported.
+      }],
+      "bool", "isValidMemorySpaceCast",
+      (ins "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::PtrLikeTypeInterface":$src)
+    >,
+    InterfaceMethod<[{
+        Clones the memory space cast op with the given source and target type.
+      }],
+      "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+      (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+           "::mlir::Value":$src)
+    >,
+    InterfaceMethod<[{
+        Returns whether the cast can be fused.
+      }],
+      "bool", "isFusableMemorySpaceCast"
+    >
+  ];
+  let verify = [{
+    return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+  }];
+  let dependentTraits = [Pure];
+  let extraClassDeclaration = [{
+    /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
+    /// is produced by a `MemorySpaceCastOpInterface` op, and
+    /// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
+    static ::mlir::MemorySpaceCastOpInterface
+    getIfFusableCast(::mlir::Value value) {
+      auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+        value.getDefiningOp());
+      if (!op || !op.isFusableMemorySpaceCast())
+        return nullptr;
+      return op;
+    }
+  }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
new file mode 100644
index 0000000000000..9333f92a10289
--- /dev/null
+++ b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
@@ -0,0 +1,20 @@
+//===-- FuseMemorySpaceCastsIntoConsumers.h - Cast fusion patterns -C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+
+namespace mlir {
+class RewritePatternSet;
+/// Collect a set of patterns to fuse memory-space cast operations into
+/// consumers.
+void populateFuseMemorySpaceCastIntoConsumersPatterns(
+    RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..610a9671fede8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index beb59784947c5..69280e3d443ea 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,4 +585,44 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def FuseMemorySpaceCastsIntoConsumers :
+    Pass<"fuse-memory-space-casts-into-consumers"> {
+  let summary = "Fuses memory-space cast operations into consumers.";
+  let description = [{
+    This pass tries to fuse all eligible memory-space cast operations into
+    their consumers. It does this by looking for
+    `FuseMemorySpaceCastConsumerOpInterface` operations, and invoking the
+    interface methods to perform the fusion.
+
+    Example:
+
+    ```mlir
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+      %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+      %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+      %added = arith.addf %loaded, %arg2 : f32
+      memref.store %added, %collapsed[%c0] : memref<16xf32>
+      %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+      return %collapsed : memref<16xf32>
+    }
+    // mlir-opt --fuse-memory-space-casts-into-consumers
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %c4 = arith.constant 4 : index
+      %c0 = arith.constant 0 : index
+      %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+      %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+      %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
+      %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
+      %1 = arith.addf %0, %arg2 : f32
+      memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
+      %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
+      return %memspacecast : memref<16xf32>
+    }
+    ```
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 734294bd014c6..e25a0121a3359 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d15d5f6e3de4..0ddb2b0ca1645 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,6 +111,65 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
   }
 }
 
+/// Helper function to retrieve a fusable memory-space cast and the
+/// corresponding new result memref type.
+static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
+getFuseCastInfo(BaseMemRefType resultTy, Value src) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfFusableCast(src);
+
+  // Bail if the cast is not fusable.
+  if (!castOp)
+    return {};
+
+  // Transform the source and target type of `castOp` to have the same metadata
+  // as `resultTy`. Bail if not possible.
+  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
+      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(srcTy))
+    return {};
+
+  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
+      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(tgtTy))
+    return {};
+
+  // Check if this is a valid memory-space cast.
+  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
+    return {};
+
+  return std::make_tuple(castOp, *tgtTy, *srcTy);
+}
+
+/// Implementation of `fuseCastOperands` method for memref operations that
+/// return a single memref result.
+template <typename ConcreteOpTy>
+static FailureOr<SmallVector<Value>>
+fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+                                  bool &modifiedInPlace, OpOperand &src) {
+  auto [castOp, tgtTy, resTy] = getFuseCastInfo(op.getType(), src.get());
+  // Bail if we cannot cast.
+  if (!castOp)
+    return failure();
+
+  modifiedInPlace = false;
+
+  // Create the new operands.
+  SmallVector<Value> operands;
+  llvm::append_range(operands, op->getOperands());
+  operands[src.getOperandNumber()] = castOp.getSourcePtr();
+
+  // Create the fused op and results.
+  auto newOp = ConcreteOpTy::create(
+      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
+      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
+
+  // Insert a memory-space cast to the original memory space of the op.
+  MemorySpaceCastOpInterface result =
+      castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
+  return SmallVector<Value>({result.getTargetPtr()});
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -542,6 +601,12 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
+FailureOr<SmallVector<Value>>
+AssumeAlignmentOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getMemrefMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -710,6 +775,12 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+FailureOr<SmallVector<Value>> CastOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
@@ -1601,6 +1672,12 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // MemorySpaceCastOp
 //===----------------------------------------------------------------------===//
@@ -1645,6 +1722,33 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
   return Value{};
 }
 
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
+  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+}
+
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
+  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+}
+
+bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+                                               PtrLikeTypeInterface src) {
+  return isa<MemRefType>(tgt) &&
+         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
+}
+
+MemorySpaceCastOpInterface
+MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b, Type tgt, Value src) {
+  assert(isValidMemorySpaceCast(cast<PtrLikeTypeInterface>(tgt),
+                                cast<PtrLikeTypeInterface>(src.getType())) &&
+         "invalid arguments");
+  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
+}
+
+bool MemorySpaceCastOp::isFusableMemorySpaceCast() {
+  // Only allow fusion when the cast discards information, i.e., when it
+  // targets the default memory space.
+  return getDest().getType().getMemorySpace() == nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
@@ -2041,6 +2145,12 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ReinterpretCastOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2348,6 +2458,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ExpandShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSrcMutable());
+}
+
 /// Compute the layout map after collapsing a given source MemRef type with the
 /// specified reassociation indices.
 ///
@@ -2569,6 +2685,12 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());
 }
 
+FailureOr<SmallVector<Value>>
+CollapseShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSrcMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -2609,6 +2731,12 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+FailureOr<SmallVector<Value>>
+ReshapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -2626,6 +2754,12 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this, getValueToStore());
 }
 
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+                                                        bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getMemrefMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // SubViewOp
 //===----------------------------------------------------------------------===//
@@ -3282,6 +3416,12 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+FailureOr<SmallVector<Value>>
+SubViewOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
@@ -3382,6 +3522,12 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+FailureOr<SmallVector<Value>>
+TransposeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getInMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
@@ -3525,6 +3671,12 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
 }
 
+FailureOr<SmallVector<Value>> ViewOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // AtomicRMWOp
 //===----------------------------------------------------------------------===//
@@ -3570,6 +3722,12 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+FailureOr<SmallVector<Value>>
+AtomicRMWOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d6e263934fb4..806e6c1c070aa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,6 +5087,14 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
+FailureOr<SmallVector<Value>>
+TransferReadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  if (!hasPureBufferSemantics())
+    return failure();
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -5574,6 +5582,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
+FailureOr<SmallVector<Value>>
+TransferWriteOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  if (!hasPureBufferSemantics())
+    return failure();
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -5628,6 +5644,12 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -5667,6 +5689,12 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+                                                        bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -5721,6 +5749,12 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
+FailureOr<SmallVector<Value>>
+MaskedLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -5771,6 +5805,12 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
+FailureOr<SmallVector<Value>>
+MaskedStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//
@@ -5874,6 +5914,12 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
+FailureOr<SmallVector<Value>>
+GatherOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -5936,6 +5982,12 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ScatterOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandLoadOp
 //===----------------------------------------------------------------------===//
@@ -5984,6 +6036,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExpandLoadFolder>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ExpandLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // CompressStoreOp
 //===----------------------------------------------------------------------===//
@@ -6030,6 +6088,12 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CompressStoreFolder>(context);
 }
 
+FailureOr<SmallVector<Value>>
+CompressStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index fdc19844702bc..388de1c3e5abf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -11,6 +11,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
+  MemOpInterfaces.cpp
   MemorySlotInterfaces.cpp
   ParallelCombiningOpInterface.cpp
   RuntimeVerifiableOpInterface.cpp
@@ -79,6 +80,7 @@ add_mlir_library(MLIRLoopLikeInterface
   MLIRFunctionInterfaces
 )
 
+add_mlir_interface_library(MemOpInterfaces)
 add_mlir_interface_library(MemorySlotInterfaces)
 add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
new file mode 100644
index 0000000000000..013d828da1d66
--- /dev/null
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -0,0 +1,73 @@
+//===- MemOpInterfaces.cpp - Memory operation interfaces --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
+  auto memCastOp = cast<MemorySpaceCastOpInterface>(op);
+
+  // Verify that the source and target pointers are valid.
+  Value sourcePtr = memCastOp.getSourcePtr();
+  Value targetPtr = memCastOp.getTargetPtr();
+
+  if (!sourcePtr || !targetPtr) {
+    return op->emitError()
+           << "memory space cast op must have valid source and target pointers";
+  }
+
+  if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) {
+    return op->emitError()
+           << "expected source and target types of the same kind";
+  }
+
+  // Verify that the types implement `PtrLikeTypeInterface`.
+  auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType());
+  if (!sourceType) {
+    return op->emitError()
+           << "source type must implement `PtrLikeTypeInterface`, but got: "
+           << sourcePtr.getType();
+  }
+
+  auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType());
+  if (!targetType) {
+    return op->emitError()
+           << "target type must implement `PtrLikeTypeInterface`, but got: "
+           << targetPtr.getType();
+  }
+
+  // Verify that the operation has exactly one result.
+  if (op->getNumResults() != 1) {
+    return op->emitError()
+           << "memory space cast op must have exactly one result";
+  }
+
+  return success();
+}
+
+FailureOr<SmallVector<Value>> mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+    OpOperand &operand, ValueRange results, bool &modifiedInPlace) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfFusableCast(operand.get());
+
+  // Bail if the operand is not produced by a fusable memory-space cast.
+  if (!castOp)
+    return failure();
+
+  // Modify the op.
+  modifiedInPlace = true;
+  operand.set(castOp.getSourcePtr());
+  return llvm::to_vector_of<Value>(results);
+}
+
+#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e..e9a7d3e4abe99 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransforms
   ControlFlowSink.cpp
   CSE.cpp
+  FuseMemorySpaceCastsIntoConsumers.cpp
   GenerateRuntimeVerification.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -31,6 +32,7 @@ add_mlir_library(MLIRTransforms
   MLIRAnalysis
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRPass
   MLIRRuntimeVerifiableOpInterface
diff --git a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
new file mode 100644
index 0000000000000..010b88ac12de2
--- /dev/null
+++ b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
@@ -0,0 +1,73 @@
+//===- FuseMemorySpaceCastsIntoConsumers.cpp - Fuse casts transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FuseCastsPattern pattern
+//===----------------------------------------------------------------------===//
+/// Pattern to fuse casts into consumer operations.
+struct FuseCastsPattern
+    : public OpInterfaceRewritePattern<FuseMemorySpaceCastConsumerOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(FuseMemorySpaceCastConsumerOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    bool modifiedInPlace = false;
+    FailureOr<SmallVector<Value>> results =
+        op.fuseCastOperands(rewriter, modifiedInPlace);
+    assert((!failed(results) || !modifiedInPlace) &&
+           "expected `modifiedInPlace` to be false on fusion failure");
+    if (failed(results))
+      return failure();
+    if (modifiedInPlace) {
+      rewriter.modifyOpInPlace(op, []() {});
+      return success();
+    }
+    rewriter.replaceOp(op, *results);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// FuseMemorySpaceCastsIntoConsumers pass
+//===----------------------------------------------------------------------===//
+
+struct FuseMemorySpaceCastsIntoConsumers
+    : public impl::FuseMemorySpaceCastsIntoConsumersBase<
+          FuseMemorySpaceCastsIntoConsumers> {
+  using impl::FuseMemorySpaceCastsIntoConsumersBase<
+      FuseMemorySpaceCastsIntoConsumers>::FuseMemorySpaceCastsIntoConsumersBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::populateFuseMemorySpaceCastIntoConsumersPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FuseCastsPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
new file mode 100644
index 0000000000000..69a15f429cec2
--- /dev/null
+++ b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
@@ -0,0 +1,281 @@
+// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// CHECK-LABEL:   func.func @load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK:           memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK:           return
+// CHECK:         }
+func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
+  memref.store %0, %memspacecast[%arg1] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @load_store_unfoldable(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32, 1> to memref<?xf32, 2>
+// CHECK:           %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK:           memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK:           return
+// CHECK:         }
+func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32, 2>
+  %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 2>
+  memref.store %0, %memspacecast[%arg1] : memref<?xf32, 2>
+  return
+}
+
+// CHECK-LABEL:   func.func @view(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<?xi8, 1>,
+// CHECK-SAME:                    %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xi8, 1>
+// CHECK:           %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xi8, 1> to memref<?x?xi8>
+// CHECK:           return %[[VAL_2]] : memref<?x?xi8>
+// CHECK:         }
+func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xi8> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
+  %c100 = arith.constant 100 : index
+  %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xi8>
+  return %view : memref<?x?xi8>
+}
+
+// CHECK-LABEL:   func.func @subview(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+// CHECK:           %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK:           return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK:         }
+func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+  return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
+}
+
+// CHECK-LABEL:   func.func @reinterpret_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:           return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         }
+func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+  return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
+}
+
+// CHECK-LABEL:   func.func @reshape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK:           return %[[VAL_1]] : memref<?xf32>
+// CHECK:         }
+func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
+  return %reshape : memref<?xf32>
+}
+
+// CHECK-LABEL:   func.func @expand_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
+// CHECK:           return %[[VAL_1]] : memref<3x4xf32>
+// CHECK:         }
+func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
+  %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
+  return %expand_shape : memref<3x4xf32>
+}
+
+// CHECK-LABEL:   func.func @collapse_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
+// CHECK:           return %[[VAL_1]] : memref<12xf32>
+// CHECK:         }
+func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
+  %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
+  return %collapse_shape : memref<12xf32>
+}
+
+// CHECK-LABEL:   func.func @transpose(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
+// CHECK:           %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK:           return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK:         }
+func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
+  return %transpose : memref<?x?xf32, #map>
+}
+
+// CHECK-LABEL:   func.func @atomic_rmw(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
+// CHECK:           return %[[VAL_0]] : f32
+// CHECK:         }
+func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @assume_alignment(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK:           return %[[VAL_1]] : memref<?xf32>
+// CHECK:         }
+func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// CHECK-LABEL:   func.func @op_with_cast_sequence(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<4x4xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: f32) -> memref<16xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK:           %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32
+// CHECK:           memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK:           %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32
+// CHECK:           return %[[VAL_4]] : memref<16xf32>
+// CHECK:         }
+func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+  %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+  %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+  %added = arith.addf %loaded, %arg2 : f32
+  memref.store %added, %collapsed[%c0] : memref<16xf32>
+  %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+  return %collapsed : memref<16xf32>
+}
+
+// CHECK-LABEL:   func.func @transfer_read_write(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref<?xf32, 1>
+// CHECK:           return
+// CHECK:         }
+func.func @transfer_read_write(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref<?xf32>, vector<4xf32>
+  vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref<?xf32>
+  return
+}
+
+// NOTE: The operations disappear because they can get folded.
+// CHECK-LABEL:   func.func @transfer_read_write_tensor(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> tensor<?xf32> {
+// CHECK:           return %[[ARG0]] : tensor<?xf32>
+// CHECK:         }
+func.func @transfer_read_write_tensor(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor<?xf32>, vector<4xf32>
+  %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func.func @vector_load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = vector.load %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+  vector.store %0, %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @masked_load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @gather_scatter(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %mask = arith.constant dense<true> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @expandload_compressstore(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @expandload_compressstore(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.compressstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+  return
+}

>From 3709f67bd0fbe1d0b8b86121c20fb0f5c9a933d2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 19 Sep 2025 12:00:22 +0000
Subject: [PATCH 02/10] address comments 1/2

---
 mlir/include/mlir/Transforms/Passes.td        |  7 ++--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  2 +-
 .../test-fuse-casts-into-consumers.mlir       | 41 +++++++++++++------
 3 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 69280e3d443ea..3204e80919456 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -589,9 +589,10 @@ def FuseMemorySpaceCastsIntoConsumers :
     Pass<"fuse-memory-space-casts-into-consumers"> {
   let summary = "Fuses memory-space cast operations into consumers.";
   let description = [{
-    This pass tries to fuse all possible memory-space cast operations into their consumers.
-    It does this by looking for `FuseMemorySpaceCastConsumerOpInterface`
-    operations, and invoking the interface methods to perform the fusion.
+    This pass iteratively fuses all possible memory-space cast operations
+    into their consumers. It does this by looking for
+    `FuseMemorySpaceCastConsumerOpInterface` operations and invoking the
+    interface methods to perform the fusion.
 
     Example:
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0ddb2b0ca1645..11fd43ff54575 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1732,7 +1732,7 @@ TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
 
 bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                                PtrLikeTypeInterface src) {
-  return isa<MemRefType>(tgt) &&
+  return isa<BaseMemRefType>(tgt) &&
          tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
 }
 
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
index 69a15f429cec2..7534332b3663a 100644
--- a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
+++ b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
@@ -32,6 +32,23 @@ func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
   return
 }
 
+// CHECK-LABEL:   func.func @cast(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<2xf32, 1>,
+// CHECK-SAME:                    %[[ARG1:.*]]: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
+// CHECK:           %[[VAL_0:.*]] = memref.cast %[[ARG0]] : memref<2xf32, 1> to memref<*xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<*xf32, 1> to memref<*xf32>
+// CHECK:           %[[VAL_2:.*]] = memref.cast %[[ARG1]] : memref<*xf32, 1> to memref<3x2xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<3x2xf32, 1> to memref<3x2xf32>
+// CHECK:           return %[[VAL_1]], %[[VAL_3]] : memref<*xf32>, memref<3x2xf32>
+// CHECK:         }
+func.func @cast(%arg0: memref<2xf32, 1>, %arg1: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<2xf32, 1> to memref<2xf32>
+  %1 = memref.cast %memspacecast : memref<2xf32> to memref<*xf32>
+  %memspacecast_1 = memref.memory_space_cast %arg1 : memref<*xf32, 1> to memref<*xf32>
+  %2 = memref.cast %memspacecast_1 : memref<*xf32> to memref<3x2xf32>
+  return %1, %2 : memref<*xf32>, memref<3x2xf32>
+}
+
 // CHECK-LABEL:   func.func @view(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?xi8, 1>,
 // CHECK-SAME:                    %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
@@ -63,8 +80,8 @@ func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, s
 // CHECK-LABEL:   func.func @reinterpret_cast(
 // CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
 // CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
-// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
-// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
 // CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
 // CHECK:           return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
@@ -155,8 +172,8 @@ func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
 // CHECK-SAME:      %[[ARG0:.*]]: memref<4x4xf32, 1>,
 // CHECK-SAME:      %[[ARG1:.*]]: index,
 // CHECK-SAME:      %[[ARG2:.*]]: f32) -> memref<16xf32> {
-// CHECK:           %[[VAL_0:.*]] = arith.constant 4 : index
-// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
 // CHECK:           %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
@@ -225,8 +242,8 @@ func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
 // CHECK-LABEL:   func.func @masked_load_store(
 // CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
 // CHECK-SAME:      %[[ARG1:.*]]: index) {
-// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
 // CHECK:           %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
 // CHECK:           vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
 // CHECK:           return
@@ -243,10 +260,10 @@ func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
 // CHECK-LABEL:   func.func @gather_scatter(
 // CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
 // CHECK-SAME:      %[[ARG1:.*]]: index) {
-// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
 // CHECK:           vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
 // CHECK:           return
@@ -265,8 +282,8 @@ func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
 // CHECK-LABEL:   func.func @expandload_compressstore(
 // CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
 // CHECK-SAME:      %[[ARG1:.*]]: index) {
-// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
 // CHECK:           %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
 // CHECK:           vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
 // CHECK:           return

>From 4789d11a5916bdcb3b2cf060954a26b8d57fb190 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 19 Sep 2025 12:56:13 +0000
Subject: [PATCH 03/10] address comments 2/2

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  24 ++--
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  20 +--
 .../include/mlir/Interfaces/MemOpInterfaces.h |  11 +-
 .../mlir/Interfaces/MemOpInterfaces.td        |  45 ++++---
 .../Transforms/BubbleDownMemorySpaceCasts.h   |  20 +++
 .../FuseMemorySpaceCastsIntoConsumers.h       |  20 ---
 mlir/include/mlir/Transforms/Passes.h         |   2 +-
 mlir/include/mlir/Transforms/Passes.td        |  15 +--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 124 ++++++++----------
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  80 +++++------
 mlir/lib/Interfaces/MemOpInterfaces.cpp       |  12 +-
 ...ers.cpp => BubbleDownMemorySpaceCasts.cpp} |  46 +++----
 mlir/lib/Transforms/CMakeLists.txt            |   2 +-
 ... test-bubble-down-memory-space-casts.mlir} |   2 +-
 14 files changed, 205 insertions(+), 218 deletions(-)
 create mode 100644 mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
 delete mode 100644 mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
 rename mlir/lib/Transforms/{FuseMemorySpaceCastsIntoConsumers.cpp => BubbleDownMemorySpaceCasts.cpp} (53%)
 rename mlir/test/Transforms/{test-fuse-casts-into-consumers.mlir => test-bubble-down-memory-space-casts.mlir} (99%)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 238a767ac8b73..c708d7f3d884a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -147,7 +147,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       Pure,
       ViewLikeOpInterface,
       SameOperandsAndResultType,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
@@ -458,7 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
 def MemRef_CastOp : MemRef_Op<"cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultShape,
@@ -1197,7 +1197,7 @@ def LoadOp : MemRef_Op<"load",
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
@@ -1381,7 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
 def MemRef_ReinterpretCastOp
   : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       MemRefsNormalizable,
       Pure,
@@ -1609,7 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
 
 def MemRef_ReshapeOp: MemRef_Op<"reshape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     Pure,
     ViewLikeOpInterface]>  {
   let summary = "memref reshape operation";
@@ -1708,7 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
@@ -1831,7 +1831,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
@@ -1939,7 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
@@ -2017,7 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
@@ -2293,7 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     Pure]>,
     Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
@@ -2329,7 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
 
 def MemRef_ViewOp : MemRef_Op<"view", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     Pure]> {
   let summary = "memref view operation";
@@ -2406,7 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
 //===----------------------------------------------------------------------===//
 
 def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 93e9bfc78ea75..252c0b72456df 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1247,7 +1247,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
@@ -1495,7 +1495,7 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
-      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
@@ -1652,7 +1652,7 @@ def Vector_TransferWriteOp :
 
 def Vector_LoadOp : Vector_Op<"load", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
@@ -1769,7 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
 
 def Vector_StoreOp : Vector_Op<"store", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
@@ -1874,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+  Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1966,7 +1966,7 @@ def Vector_MaskedLoadOp :
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+  Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2046,7 +2046,7 @@ def Vector_MaskedStoreOp :
 def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
-    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2150,7 +2150,7 @@ def Vector_GatherOp :
 }
 
 def Vector_ScatterOp :
-  Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+  Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$offsets,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2235,7 +2235,7 @@ def Vector_ScatterOp :
 }
 
 def Vector_ExpandLoadOp :
-  Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+  Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2323,7 +2323,7 @@ def Vector_ExpandLoadOp :
 }
 
 def Vector_CompressStoreOp :
-  Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+  Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
index cc9f4c6b3882e..d4ed71e38f4ff 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -21,13 +21,12 @@ namespace detail {
 /// Attempt to verify the given memory space cast operation.
 LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
 
-/// Tries to fuse inplace a `MemorySpaceCastOpInterface` operation referenced by
-/// `operand`. On success, it returns `results`, and sets `modifiedInPlace` to
-/// true. It returns failure if `operand` doesn't reference a
+/// Tries to bubble down, in place, a `MemorySpaceCastOpInterface` operation
+/// referenced by `operand`. On success, it returns `results` and true. It
+/// returns failure if `operand` doesn't reference a
 /// `MemorySpaceCastOpInterface` op.
-FailureOr<SmallVector<Value>>
-fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
-                               bool &modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
 } // namespace detail
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 0b8ba19171fb7..d097b00c8e80c 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -16,27 +16,26 @@
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-def FuseMemorySpaceCastConsumerOpInterface :
-    OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+def MemorySpaceCastConsumerOpInterface :
+    OpInterface<"MemorySpaceCastConsumerOpInterface"> {
   let description = [{
-    An interface to fuse memory-space cast operands into a consumer operation.
-    It is the responsibility of the interface to determine which casts can be
-    fused into the operation.
+    An interface for operations that can consume memory-space cast-like
+    operations.
   }];
   let cppNamespace = "::mlir";
   let methods = [
     InterfaceMethod<[{
-        Attempt to fuse the incoming cast-like operands. Returns `success`
-        and any new results on fusion success, otherwise it returns failure.
+        Attempt to bubble down the incoming cast-like operands. On success,
+        it returns any new results and whether the operation was modified in
+        place; otherwise, it returns failure.
         If new results are produced, these must be compatible with the original
         operation results.
 
-        The `modifiedInPlace` parameter indicates whether the operation was
-        modified in place. If `false` and the fusion succeeded, then the
-        interface guarantees it is valid to erase the original operation.
-        If `true`, then the interface must guarantee no operations were created
-        by the method, and that no further IR modification is necessary. It is
-        considered an error if `modifiedInPlace` is true and the fusion failed.
+        If the operation was not modified in place, then the interface
+        guarantees it is valid to erase the original operation.
+        If the operation was modified in place, then the interface must
+        guarantee no operations were created by the method, and that no further
+        IR modification is necessary.
 
         Any implementations of this method must not erase/replace the original
         operation, instead it is the caller responsibility to erase or replace
@@ -45,8 +44,9 @@ def FuseMemorySpaceCastConsumerOpInterface :
         Finally, any implementations of this method have to guarantee that the
         IR remains valid at all times.
       }],
-      "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
-      (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+      "::llvm::FailureOr<std::pair<::llvm::SmallVector<::mlir::Value>, bool>>",
+      "bubbleDownCasts",
+      (ins "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -83,13 +83,16 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
         Clones the memory space cast op with the given source and target type.
       }],
       "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
-      (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+      (ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
            "::mlir::Value":$src)
     >,
     InterfaceMethod<[{
-        Returns whether the cast allows to be fused.
+        Returns whether the memory-space cast is lossless. A lossless
+        memory-space cast must not lose any information encoded in the memory
+        space. An example of such a cast is any conversion to the generic memory
+        space. 
       }],
-      "bool", "isFusableMemorySpaceCast"
+      "bool", "isLosslessCast"
     >
   ];
   let verify = [{
@@ -99,12 +102,12 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
   let extraClassDeclaration = [{
     /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
     /// is produced by a `MemorySpaceCastOpInterface` op, and
-    /// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
+    /// `isLosslessCast` returns true, otherwise it returns null.
     static ::mlir::MemorySpaceCastOpInterface
-    getIfFusableCast(::mlir::Value value) {
+    getIfLosslessCast(::mlir::Value value) {
       auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
         value.getDefiningOp());
-      if (!op || !op.isFusableMemorySpaceCast())
+      if (!op || !op.isLosslessCast())
         return nullptr;
       return op;
     }
diff --git a/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
new file mode 100644
index 0000000000000..99db092879a90
--- /dev/null
+++ b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
@@ -0,0 +1,20 @@
+//===- BubbleDownMemorySpaceCasts.h - Bubble down cast patterns -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+#define MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+
+namespace mlir {
+class PatternBenefit;
+class RewritePatternSet;
+/// Collect a set of patterns to bubble-down memory-space cast operations.
+void populateBubbleDownMemorySpaceCastPatterns(RewritePatternSet &patterns,
+                                               PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
deleted file mode 100644
index 9333f92a10289..0000000000000
--- a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- FuseMemorySpaceCastsIntoConsumers.h - Cast fusion patterns -C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
-#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
-
-namespace mlir {
-class RewritePatternSet;
-/// Collect a set of patterns to fuse memory-space cast operations into
-/// consumers.
-void populateFuseMemorySpaceCastIntoConsumersPatterns(
-    RewritePatternSet &patterns);
-} // namespace mlir
-
-#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 610a9671fede8..1c035f2a843ff 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,7 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
-#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 3204e80919456..8f0b80c5e511b 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,14 +585,13 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
-def FuseMemorySpaceCastsIntoConsumers :
-    Pass<"fuse-memory-space-casts-into-consumers"> {
-  let summary = "Fuses memory-space cast operations into consumers.";
+def BubbleDownMemorySpaceCasts :
+    Pass<"bubble-down-memory-space-casts"> {
+  let summary = "Bubbles down memory-space cast operations.";
   let description = [{
-    This pass tries to iteratively fuse all possible memory-space cast
-    operations into their consumers. It does this by looking for
-    `FuseMemorySpaceCastConsumerOpInterface` operations, and invoking the
-    interface methods to perform the fusion.
+    This pass iteratively bubbles down all possible memory-space cast
+    operations. It does this by looking for `MemorySpaceCastConsumerOpInterface`
+    operations and invoking the interface methods to bubble the casts down.
 
     Example:
 
@@ -609,7 +608,7 @@ def FuseMemorySpaceCastsIntoConsumers :
       %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
       return %collapsed : memref<16xf32>
     }
-    // mlir-opt --fuse-casts-into-consumers
+    // mlir-opt --bubble-down-memory-space-casts
     func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
       %c4 = arith.constant 4 : index
       %c0 = arith.constant 0 : index
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11fd43ff54575..6f276efb84c1c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,14 +111,14 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
   }
 }
 
-/// Helper function to retrieve a fusable memory-space cast, and the
+/// Helper function to retrieve a lossless memory-space cast, and the
 /// corresponding new result memref type.
 static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
-getFuseCastInfo(BaseMemRefType resultTy, Value src) {
+getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
   MemorySpaceCastOpInterface castOp =
-      MemorySpaceCastOpInterface::getIfFusableCast(src);
+      MemorySpaceCastOpInterface::getIfLosslessCast(src);
 
-  // Bail if the cast is not fusable.
+  // Bail if the cast is not lossless.
   if (!castOp)
     return {};
 
@@ -141,25 +141,23 @@ getFuseCastInfo(BaseMemRefType resultTy, Value src) {
   return std::make_tuple(castOp, *tgtTy, *srcTy);
 }
 
-/// Implementation of `fuseCastOperands` method for memref operations that
+/// Implementation of `bubbleDownCasts` method for memref operations that
 /// return a single memref result.
 template <typename ConcreteOpTy>
-static FailureOr<SmallVector<Value>>
-fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
-                                  bool &modifiedInPlace, OpOperand &src) {
-  auto [castOp, tgtTy, resTy] = getFuseCastInfo(op.getType(), src.get());
+static FailureOr<std::pair<SmallVector<Value>, bool>>
+bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+                                 OpOperand &src) {
+  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
   // Bail if we cannot cast.
   if (!castOp)
     return failure();
 
-  modifiedInPlace = false;
-
   // Create the new operands.
   SmallVector<Value> operands;
   llvm::append_range(operands, op->getOperands());
   operands[src.getOperandNumber()] = castOp.getSourcePtr();
 
-  // Create the fused op and results.
+  // Create the new op and results.
   auto newOp = ConcreteOpTy::create(
       builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
       llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
@@ -167,7 +165,7 @@ fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
   // Insert a memory-space cast to the original memory space of the op.
   MemorySpaceCastOpInterface result =
       castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
-  return SmallVector<Value>({result.getTargetPtr()});
+  return std::make_pair(SmallVector<Value>({result.getTargetPtr()}), false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,10 +599,9 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
-FailureOr<SmallVector<Value>>
-AssumeAlignmentOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getMemrefMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -775,10 +772,9 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
-FailureOr<SmallVector<Value>> CastOp::fuseCastOperands(OpBuilder &builder,
-                                                       bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+CastOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1672,10 +1668,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
-FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
-                                                       bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getMemrefMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1737,15 +1733,16 @@ bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
 }
 
 MemorySpaceCastOpInterface
-MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b, Type tgt, Value src) {
-  assert(isValidMemorySpaceCast(cast<PtrLikeTypeInterface>(tgt),
-                                cast<PtrLikeTypeInterface>(src.getType())) &&
-         "invalid arguments");
+MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b,
+                                          PtrLikeTypeInterface tgt, Value src) {
+  assert(
+      isValidMemorySpaceCast(tgt, cast<PtrLikeTypeInterface>(src.getType())) &&
+      "invalid arguments");
   return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
 }
 
-bool MemorySpaceCastOp::isFusableMemorySpaceCast() {
-  // Only allow fusion when this is discarding information.
+bool MemorySpaceCastOp::isLosslessCast() {
+  // The only cast recognized as lossless is a cast to the generic memory space.
   return getDest().getType().getMemorySpace() == nullptr;
 }
 
@@ -2145,10 +2142,9 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
 }
 
-FailureOr<SmallVector<Value>>
-ReinterpretCastOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2458,10 +2454,9 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
-FailureOr<SmallVector<Value>>
-ExpandShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSrcMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -2685,10 +2680,9 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());
 }
 
-FailureOr<SmallVector<Value>>
-CollapseShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSrcMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2731,10 +2725,9 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
-FailureOr<SmallVector<Value>>
-ReshapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2754,10 +2747,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this, getValueToStore());
 }
 
-FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
-                                                        bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getMemrefMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+                                                            ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3416,10 +3409,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-FailureOr<SmallVector<Value>>
-SubViewOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+SubViewOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3522,10 +3514,9 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
-FailureOr<SmallVector<Value>>
-TransposeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getInMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+TransposeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3671,10 +3662,9 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
 }
 
-FailureOr<SmallVector<Value>> ViewOp::fuseCastOperands(OpBuilder &builder,
-                                                       bool &modifiedInPlace) {
-  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
-                                           getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ViewOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3722,10 +3712,10 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
-FailureOr<SmallVector<Value>>
-AtomicRMWOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getMemrefMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 806e6c1c070aa..77dcb1fc6220e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,12 +5087,12 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
-FailureOr<SmallVector<Value>>
-TransferReadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+FailureOr<std::pair<SmallVector<Value>, bool>>
+TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
   if (!hasPureBufferSemantics())
     return failure();
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), getResult(), modifiedInPlace);
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5582,12 +5582,12 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
-FailureOr<SmallVector<Value>>
-TransferWriteOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+FailureOr<std::pair<SmallVector<Value>, bool>>
+TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
   if (!hasPureBufferSemantics())
     return failure();
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), ValueRange(), modifiedInPlace);
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5644,10 +5644,10 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
-                                                       bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5689,10 +5689,10 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
-                                                        bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5749,10 +5749,10 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
-FailureOr<SmallVector<Value>>
-MaskedLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5805,10 +5805,10 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
-FailureOr<SmallVector<Value>>
-MaskedStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5914,10 +5914,10 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
-FailureOr<SmallVector<Value>>
-GatherOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+GatherOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5982,10 +5982,10 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
-FailureOr<SmallVector<Value>>
-ScatterOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ScatterOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -6036,10 +6036,10 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExpandLoadFolder>(context);
 }
 
-FailureOr<SmallVector<Value>>
-ExpandLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
 }
 
 //===----------------------------------------------------------------------===//
@@ -6088,10 +6088,10 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CompressStoreFolder>(context);
 }
 
-FailureOr<SmallVector<Value>>
-CompressStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
-  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-      getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
index 013d828da1d66..10303185ad833 100644
--- a/mlir/lib/Interfaces/MemOpInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -55,19 +55,19 @@ LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
   return success();
 }
 
-FailureOr<SmallVector<Value>> mlir::detail::fuseInPlaceMemorySpaceCastImpl(
-    OpOperand &operand, ValueRange results, bool &modifiedInPlace) {
+FailureOr<std::pair<SmallVector<Value>, bool>>
+mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
+                                                   ValueRange results) {
   MemorySpaceCastOpInterface castOp =
-      MemorySpaceCastOpInterface::getIfFusableCast(operand.get());
+      MemorySpaceCastOpInterface::getIfLosslessCast(operand.get());
 
-  // Bail if the src is not produced by a `MemorySpaceCastOpInterface`.
+  // Bail if the src is not produced by a lossless memory-space cast.
   if (!castOp)
     return failure();
 
   // Modify the op.
-  modifiedInPlace = true;
   operand.set(castOp.getSourcePtr());
-  return llvm::to_vector_of<Value>(results);
+  return std::make_pair(llvm::to_vector_of<Value>(results), true);
 }
 
 #include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
similarity index 53%
rename from mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
rename to mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
index 010b88ac12de2..96e0e8d584ea7 100644
--- a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -1,4 +1,4 @@
-//===- FuseMemorySpaceCastsIntoConsumers.cpp - Fuse casts transform -------===//
+//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h"
+#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Pass/Pass.h"
@@ -17,57 +17,53 @@
 using namespace mlir;
 
 namespace mlir {
-#define GEN_PASS_DEF_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
 namespace {
 //===----------------------------------------------------------------------===//
-// FuseCastsPattern pattern
+// BubbleDownCastsPattern pattern
 //===----------------------------------------------------------------------===//
-/// Pattern to fuse casts into consumer operations.
-struct FuseCastsPattern
-    : public OpInterfaceRewritePattern<FuseMemorySpaceCastConsumerOpInterface> {
+/// Pattern to bubble down casts into consumer operations.
+struct BubbleDownCastsPattern
+    : public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(FuseMemorySpaceCastConsumerOpInterface op,
+  LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    bool modifiedInPlace = false;
-    FailureOr<SmallVector<Value>> results =
-        op.fuseCastOperands(rewriter, modifiedInPlace);
-    assert((!failed(results) || !modifiedInPlace) &&
-           "expected `modifiedInPlace` to be false on fusion failure");
+    FailureOr<std::pair<SmallVector<Value>, bool>> results =
+        op.bubbleDownCasts(rewriter);
     if (failed(results))
       return failure();
-    if (modifiedInPlace) {
+    if (results->second) {
       rewriter.modifyOpInPlace(op, []() {});
       return success();
     }
-    rewriter.replaceOp(op, *results);
+    rewriter.replaceOp(op, results->first);
     return success();
   }
 };
 
 //===----------------------------------------------------------------------===//
-// FuseMemorySpaceCastsIntoConsumers pass
+// BubbleDownMemorySpaceCasts pass
 //===----------------------------------------------------------------------===//
 
-struct FuseMemorySpaceCastsIntoConsumers
-    : public impl::FuseMemorySpaceCastsIntoConsumersBase<
-          FuseMemorySpaceCastsIntoConsumers> {
-  using impl::FuseMemorySpaceCastsIntoConsumersBase<
-      FuseMemorySpaceCastsIntoConsumers>::FuseMemorySpaceCastsIntoConsumersBase;
+struct BubbleDownMemorySpaceCasts
+    : public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
+  using impl::BubbleDownMemorySpaceCastsBase<
+      BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
+    populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit());
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
   }
 };
 } // namespace
 
-void mlir::populateFuseMemorySpaceCastIntoConsumersPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<FuseCastsPattern>(patterns.getContext());
+void mlir::populateBubbleDownMemorySpaceCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<BubbleDownCastsPattern>(patterns.getContext(), benefit);
 }
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index e9a7d3e4abe99..54b67f5c7a91e 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,7 +6,7 @@ add_mlir_library(MLIRTransforms
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
-  FuseMemorySpaceCastsIntoConsumers.cpp
+  BubbleDownMemorySpaceCasts.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
similarity index 99%
rename from mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
rename to mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
index 7534332b3663a..e4fce89cffb45 100644
--- a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
+++ b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+// RUN: mlir-opt %s --bubble-down-memory-space-casts | FileCheck %s
 
 #map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
 

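[Editorial note: the in-place path implemented by `bubbleDownInPlaceMemorySpaceCastImpl` in the patch above can be illustrated with a small hypothetical IR snippet; the names and types below are chosen for illustration and are not taken from the test file.]

```mlir
// Before: the load consumes the result of a lossless cast into the
// generic memory space.
%cast = memref.memory_space_cast %arg0 : memref<8xf32, 1> to memref<8xf32>
%v = memref.load %cast[%c0] : memref<8xf32>

// After bubbling down: the operand is rewritten in place to use the
// original memory space. No new operations are created, so the interface
// returns the existing results paired with `true` (modified in place).
%v = memref.load %arg0[%c0] : memref<8xf32, 1>
```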
>From 7a881b6cf29e2f7d1c7a8fbcf3d0c9edb63cfd8f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Fri, 19 Sep 2025 09:04:19 -0400
Subject: [PATCH 04/10] Update mlir/include/mlir/Interfaces/MemOpInterfaces.td

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/include/mlir/Interfaces/MemOpInterfaces.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index d097b00c8e80c..575fd0af7e020 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -90,7 +90,7 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
         Returns whether the memory-space cast is lossless. A lossless
         memory-space cast must not lose any information encoded in the memory
         space. An example of such cast, is any conversion to the generic memory
-        space. 
+        space.
       }],
       "bool", "isLosslessCast"
     >

>From 7f4a7f95e4f9040eec8cc47f2270e7aeb039fea4 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 19 Sep 2025 13:30:58 +0000
Subject: [PATCH 05/10] fix benefit

---
 mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
index 96e0e8d584ea7..b9f00d4d4e23e 100644
--- a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -56,7 +56,7 @@ struct BubbleDownMemorySpaceCasts
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit());
+    populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit(1));
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
   }

>From 150b79198701b51d4d7a17fc41ad8eb9f530c256 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 22 Sep 2025 06:37:44 -0400
Subject: [PATCH 06/10] Update mlir/include/mlir/Interfaces/MemOpInterfaces.td

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/include/mlir/Interfaces/MemOpInterfaces.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 575fd0af7e020..bdecac2b3512f 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -84,7 +84,7 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
       }],
       "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
       (ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
-           "::mlir::Value":$src)
+           "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>>":$src)
     >,
     InterfaceMethod<[{
         Returns whether the memory-space cast is lossless. A lossless

>From 231ae132f79208e5bf6ba6888a73d7ba89a15183 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 22 Sep 2025 10:44:47 +0000
Subject: [PATCH 07/10] fix build

Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
 mlir/include/mlir/Interfaces/MemOpInterfaces.td |  2 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp        | 15 +++++++--------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index bdecac2b3512f..3a5affb55ebbc 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -84,7 +84,7 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
       }],
       "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
       (ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
-           "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>>":$src)
+           "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
     >,
     InterfaceMethod<[{
         Returns whether the memory-space cast is lossless. A lossless
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f276efb84c1c..b600d0d32293c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -163,8 +163,9 @@ bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
       llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
 
   // Insert a memory-space cast to the original memory space of the op.
-  MemorySpaceCastOpInterface result =
-      castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
+  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
+      builder, tgtTy,
+      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
   return std::make_pair(SmallVector<Value>({result.getTargetPtr()}), false);
 }
 
@@ -1732,12 +1733,10 @@ bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
          tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
 }
 
-MemorySpaceCastOpInterface
-MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b,
-                                          PtrLikeTypeInterface tgt, Value src) {
-  assert(
-      isValidMemorySpaceCast(tgt, cast<PtrLikeTypeInterface>(src.getType())) &&
-      "invalid arguments");
+MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
+    OpBuilder &b, PtrLikeTypeInterface tgt,
+    TypedValue<PtrLikeTypeInterface> src) {
+  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
   return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
 }
 

>From 13214545f1457e64485cb1adcfe38ae4c11710ba Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 22 Sep 2025 11:24:17 +0000
Subject: [PATCH 08/10] rename isLosslessCast method

---
 mlir/include/mlir/Interfaces/MemOpInterfaces.td | 15 +++++++--------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp        |  6 +++---
 mlir/lib/Interfaces/MemOpInterfaces.cpp         |  2 +-
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 3a5affb55ebbc..02e01d81912b2 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -87,12 +87,11 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
            "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
     >,
     InterfaceMethod<[{
-        Returns whether the memory-space cast is lossless. A lossless
-        memory-space cast must not lose any information encoded in the memory
-        space. An example of such cast, is any conversion to the generic memory
-        space.
+        Returns whether the source pointer of the memory-space cast can be used
+        by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method to
+        promote the source pointer and bubble down the cast.
       }],
-      "bool", "isLosslessCast"
+      "bool", "isSourcePromotable"
     >
   ];
   let verify = [{
@@ -102,12 +101,12 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
   let extraClassDeclaration = [{
     /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
     /// is produced by a `MemorySpaceCastOpInterface` op, and
-    /// `isLosslessCast` returns true, otherwise it returns null.
+    /// `isSourcePromotable` returns true, otherwise it returns null.
     static ::mlir::MemorySpaceCastOpInterface
-    getIfLosslessCast(::mlir::Value value) {
+    getIfPromotableCast(::mlir::Value value) {
       auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
         value.getDefiningOp());
-      if (!op || !op.isLosslessCast())
+      if (!op || !op.isSourcePromotable())
         return nullptr;
       return op;
     }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b600d0d32293c..cc82602239d48 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -116,7 +116,7 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
 static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
 getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
   MemorySpaceCastOpInterface castOp =
-      MemorySpaceCastOpInterface::getIfLosslessCast(src);
+      MemorySpaceCastOpInterface::getIfPromotableCast(src);
 
   // Bail if the cast is not lossless.
   if (!castOp)
@@ -1740,8 +1740,8 @@ MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
   return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
 }
 
-bool MemorySpaceCastOp::isLosslessCast() {
-  // The only cast we recognize as lossless is to the generic space.
+bool MemorySpaceCastOp::isSourcePromotable() {
+  // The only cast we recognize as promotable is to the generic space.
   return getDest().getType().getMemorySpace() == nullptr;
 }
 
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
index 10303185ad833..c29c7a9244651 100644
--- a/mlir/lib/Interfaces/MemOpInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -59,7 +59,7 @@ FailureOr<std::pair<SmallVector<Value>, bool>>
 mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
                                                    ValueRange results) {
   MemorySpaceCastOpInterface castOp =
-      MemorySpaceCastOpInterface::getIfLosslessCast(operand.get());
+      MemorySpaceCastOpInterface::getIfPromotableCast(operand.get());
 
   // Bail if the src is not valid.
   if (!castOp)

>From d753be0129f123c711ccd341a4176cc799414afa Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 23 Sep 2025 11:15:11 +0000
Subject: [PATCH 09/10] use std::optional in bubbleDownCasts

Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
 .../include/mlir/Interfaces/MemOpInterfaces.h |  4 +--
 .../mlir/Interfaces/MemOpInterfaces.td        |  7 +++--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 31 ++++++++++---------
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 20 ++++++------
 mlir/lib/Interfaces/MemOpInterfaces.cpp       |  4 +--
 .../Transforms/BubbleDownMemorySpaceCasts.cpp |  6 ++--
 6 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
index d4ed71e38f4ff..cdc423f5da1a5 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -22,10 +22,10 @@ namespace detail {
 LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
 
 /// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
-/// referenced by `operand`. On success, it returns `results` and true. It
+/// referenced by `operand`. On success, it returns `std::nullopt`. It
 /// returns failure if `operand` doesn't reference a
 /// `MemorySpaceCastOpInterface` op.
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
 } // namespace detail
 } // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 02e01d81912b2..0c7aff8cd7ff3 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -26,8 +26,9 @@ def MemorySpaceCastConsumerOpInterface :
   let methods = [
     InterfaceMethod<[{
         Attempt to bubble-down the incoming cast-like operands. On success
-        returns any new results, and whether the operation was modified in
-        place, otherwise it returns failure.
+        returns a `std::optional<SmallVector<Value>>`, otherwise it returns
+        failure. If the optional is `std::nullopt`, the operation was modified
+        in place; otherwise the method returns a list of replacement values.
         If new results are produced, these must be compatible with the original
         operation results.
 
@@ -44,7 +45,7 @@ def MemorySpaceCastConsumerOpInterface :
         Finally, any implementations of this method have to guarantee that the
         IR remains valid at all times.
       }],
-      "::llvm::FailureOr<std::pair<::llvm::SmallVector<::mlir::Value>, bool>>",
+      "::llvm::FailureOr<std::optional<::llvm::SmallVector<::mlir::Value>>>",
       "bubbleDownCasts",
       (ins "::mlir::OpBuilder &":$builder)
     >,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index cc82602239d48..349b4deb29023 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -144,7 +144,7 @@ getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
 /// Implementation of `bubbleDownCasts` method for memref operations that
 /// return a single memref result.
 template <typename ConcreteOpTy>
-static FailureOr<std::pair<SmallVector<Value>, bool>>
+static FailureOr<std::optional<SmallVector<Value>>>
 bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
                                  OpOperand &src) {
   auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
@@ -166,7 +166,8 @@ bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
   MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
       builder, tgtTy,
       cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
-  return std::make_pair(SmallVector<Value>({result.getTargetPtr()}), false);
+  return std::optional<SmallVector<Value>>(
+      SmallVector<Value>({result.getTargetPtr()}));
 }
 
 //===----------------------------------------------------------------------===//
@@ -600,7 +601,7 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
 }
@@ -773,7 +774,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 CastOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
@@ -1669,7 +1670,7 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 LoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
                                                             getResult());
@@ -1740,8 +1741,8 @@ MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
   return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
 }
 
+/// The only cast we recognize as promotable is to the generic space.
 bool MemorySpaceCastOp::isSourcePromotable() {
-  // The only cast we recognize as promotable is to the generic space.
   return getDest().getType().getMemorySpace() == nullptr;
 }
 
@@ -2141,7 +2142,7 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
@@ -2453,7 +2454,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
 }
@@ -2679,7 +2680,7 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
 }
@@ -2724,7 +2725,7 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
@@ -2746,7 +2747,7 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this, getValueToStore());
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 StoreOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
                                                             ValueRange());
@@ -3408,7 +3409,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 SubViewOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
@@ -3513,7 +3514,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 TransposeOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
 }
@@ -3661,7 +3662,7 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 ViewOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
@@ -3711,7 +3712,7 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
                                                             getResult());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 77dcb1fc6220e..b2e5a5b1e36cc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,7 +5087,7 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
   if (!hasPureBufferSemantics())
     return failure();
@@ -5582,7 +5582,7 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
   if (!hasPureBufferSemantics())
     return failure();
@@ -5644,7 +5644,7 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 LoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             getResult());
@@ -5689,7 +5689,7 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 StoreOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             ValueRange());
@@ -5749,7 +5749,7 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             getResult());
@@ -5805,7 +5805,7 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             ValueRange());
@@ -5914,7 +5914,7 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 GatherOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             getResult());
@@ -5982,7 +5982,7 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 ScatterOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             ValueRange());
@@ -6036,7 +6036,7 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExpandLoadFolder>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             getResult());
@@ -6088,7 +6088,7 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CompressStoreFolder>(context);
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                                             ValueRange());
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
index c29c7a9244651..fe5c717f67bc4 100644
--- a/mlir/lib/Interfaces/MemOpInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -55,7 +55,7 @@ LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
   return success();
 }
 
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
 mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
                                                    ValueRange results) {
   MemorySpaceCastOpInterface castOp =
@@ -67,7 +67,7 @@ mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
 
   // Modify the op.
   operand.set(castOp.getSourcePtr());
-  return std::make_pair(llvm::to_vector_of<Value>(results), true);
+  return std::optional<SmallVector<Value>>();
 }
 
 #include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
index b9f00d4d4e23e..00dac19e37171 100644
--- a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -32,15 +32,15 @@ struct BubbleDownCastsPattern
 
   LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<std::pair<SmallVector<Value>, bool>> results =
+    FailureOr<std::optional<SmallVector<Value>>> results =
         op.bubbleDownCasts(rewriter);
     if (failed(results))
       return failure();
-    if (results->second) {
+    if (!results->has_value()) {
       rewriter.modifyOpInPlace(op, []() {});
       return success();
     }
-    rewriter.replaceOp(op, results->first);
+    rewriter.replaceOp(op, **results);
     return success();
   }
 };

>From e70b0f2b460c6b19152e5a13333c1fd5a97a082e Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 23 Sep 2025 17:48:06 +0000
Subject: [PATCH 10/10] improve docs

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 23 ++++++++++++++++++-
 .../mlir/Interfaces/MemOpInterfaces.td        | 12 ++++++++--
 mlir/include/mlir/Transforms/Passes.td        |  8 +++++--
 3 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c708d7f3d884a..bddf766d8eb21 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1288,7 +1288,7 @@ def LoadOp : MemRef_Op<"load",
 def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-      DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
+      MemorySpaceCastOpInterface,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultElementType,
@@ -1326,6 +1326,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
 
   let extraClassDeclaration = [{
     Value getViewSource() { return getSource(); }
+
+    //===------------------------------------------------------------------===//
+    // MemorySpaceCastOpInterface
+    //===------------------------------------------------------------------===//
+    /// Returns the `source` memref.
+    TypedValue<PtrLikeTypeInterface> getSourcePtr();
+    /// Returns the `dest` memref.
+    TypedValue<PtrLikeTypeInterface> getTargetPtr();
+    /// Returns whether the memory-space cast is valid. Only casts between
+    /// memrefs are considered valid. Further, the `tgt` and `src` should only
+    /// differ on the memory-space parameter of the memref type.
+    bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+                                PtrLikeTypeInterface src);
+    /// Clones the operation using a new target type and source value.
+    MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
+        OpBuilder &b, PtrLikeTypeInterface tgt,
+        TypedValue<PtrLikeTypeInterface> src);
+    /// Returns whether the `source` value can be promoted by the
+    /// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
+    /// casts the op recognizes as promotable are to the generic memory-space.
+    bool isSourcePromotable();
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 0c7aff8cd7ff3..1a64e97c3412d 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -14,13 +14,15 @@
 #define MLIR_INTERFACES_MEMOPINTERFACES_TD
 
 include "mlir/IR/OpBase.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def MemorySpaceCastConsumerOpInterface :
     OpInterface<"MemorySpaceCastConsumerOpInterface"> {
   let description = [{
     An interface for operations that can consume memory-space cast-like
     operations.
+
+    This interface can be used to bubble down memory-space cast operations;
+    see the `bubble-down-memory-space-casts` pass for an example.
   }];
   let cppNamespace = "::mlir";
   let methods = [
@@ -59,6 +61,10 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
 
     These operations expect to have a well-defined ptr-like operand, and
     a well-defined target ptr-like result.
+
+    This interface also makes it possible to determine whether a cast can be
+    bubbled down by the `MemorySpaceCastConsumerOpInterface`, allowing control
+    over which casts are bubbled down.
   }];
   let cppNamespace = "::mlir";
   let methods = [
@@ -91,6 +97,9 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
         Returns whether the source pointer of the memory-space cast can be used
         by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method to
         promote the source pointer and bubble down the cast.
+
+        For example, a cast operation might decide that all casts to the generic
+        memory-space can be promoted.
       }],
       "bool", "isSourcePromotable"
     >
@@ -98,7 +107,6 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
   let verify = [{
     return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
   }];
-  let dependentTraits = [Pure];
   let extraClassDeclaration = [{
     /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
     /// is produced by a `MemorySpaceCastOpInterface` op, and
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 8f0b80c5e511b..b2b7f20a497e3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -590,8 +590,12 @@ def BubbleDownMemorySpaceCasts :
   let summary = "Bubbles down memory-space cast operations.";
   let description = [{
     This pass tries to iteratively bubble down all possible memory-space cast
-    operations. It does this by looking for `MemorySpaceCastConsumerOpInterface`
-    operations, and invoking the interface methods to perform the bubbling down.
+    operations. Note that the determination of which casts are bubbled down
+    is made by the `MemorySpaceCastConsumerOpInterface` and
+    `MemorySpaceCastOpInterface` interfaces, not by the pass itself. The pass
+    only looks for operations implementing the
+    `MemorySpaceCastConsumerOpInterface` interface, and invokes the interface
+    methods to perform the bubbling down.
 
     Example:
 


