[Mlir-commits] [mlir] 077a796 - [mlir] Implement a memory-space cast bubbling-down transform (#159454)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 24 06:11:47 PDT 2025


Author: Fabian Mora
Date: 2025-09-24T09:11:43-04:00
New Revision: 077a796c0da6c32d5d43af9945b4c2728f43633f

URL: https://github.com/llvm/llvm-project/commit/077a796c0da6c32d5d43af9945b4c2728f43633f
DIFF: https://github.com/llvm/llvm-project/commit/077a796c0da6c32d5d43af9945b4c2728f43633f.diff

LOG: [mlir] Implement a memory-space cast bubbling-down transform (#159454)

This commit adds functionality to bubble down memory-space cast
operations, allowing consumer operations to use the original
memory-space rather than first casting to a different memory space.

Changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast
operations
- Add `MemorySpaceCastConsumerOpInterface` and a pass that identifies and
bubbles down eligible casts
- Add implementation for memref and vector operations to handle
memory-space cast propagation
- Add `bubbleDownCasts` method to relevant operations to support the
fusion

In particular, in the current implementation only memory-space casts
into the default memory-space can be bubbled-down.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --bubble-down-memory-space-casts
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>

Added: 
    mlir/include/mlir/Interfaces/MemOpInterfaces.h
    mlir/include/mlir/Interfaces/MemOpInterfaces.td
    mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
    mlir/lib/Interfaces/MemOpInterfaces.cpp
    mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
    mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Interfaces/CMakeLists.txt
    mlir/lib/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bdec699eb4ce4..30f33ed2fd1d6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 671cc05e963b4..cd92ca98b2530 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure,
       ViewLikeOpInterface,
-      SameOperandsAndResultType
+      SameOperandsAndResultType,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
 def MemRef_CastOp : MemRef_Op<"cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
 def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      MemorySpaceCastOpInterface,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultElementType,
@@ -1302,6 +1307,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
 
     If the source and target address spaces are the same, this operation is a noop.
 
+    Finally, if the target memory-space is the generic/default memory-space,
+    then it is assumed this cast can be bubbled down safely. See the docs of
+    `MemorySpaceCastOpInterface` interface for more details.
+
     Example:
 
     ```mlir
@@ -1321,6 +1330,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
 
   let extraClassDeclaration = [{
     Value getViewSource() { return getSource(); }
+
+    //===------------------------------------------------------------------===//
+    // MemorySpaceCastOpInterface
+    //===------------------------------------------------------------------===//
+    /// Returns the `source` memref.
+    TypedValue<PtrLikeTypeInterface> getSourcePtr();
+    /// Returns the `dest` memref.
+    TypedValue<PtrLikeTypeInterface> getTargetPtr();
+    /// Returns whether the memory-space cast is valid. Only casts between
+    /// memrefs are considered valid. Further, the `tgt` and `src` should only
+    /// differ on the memory-space parameter of the memref type.
+    bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+                                PtrLikeTypeInterface src);
+    /// Clones the operation using a new target type and source value.
+    MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
+        OpBuilder &b, PtrLikeTypeInterface tgt,
+        TypedValue<PtrLikeTypeInterface> src);
+    /// Returns whether the `source` value can be promoted by the
+    /// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
+    /// casts the op recognizes as promotable are to the generic memory-space.
+    bool isSourcePromotable();
   }];
 
   let hasFolder = 1;
@@ -1376,6 +1406,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
 def MemRef_ReinterpretCastOp
   : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       MemRefsNormalizable,
       Pure,
@@ -1603,6 +1634,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
 
 def MemRef_ReshapeOp: MemRef_Op<"reshape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     Pure,
     ViewLikeOpInterface]>  {
   let summary = "memref reshape operation";
@@ -1701,6 +1733,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
@@ -1822,7 +1855,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
+  ]> {
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1964,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
@@ -2006,6 +2042,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2318,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     Pure]>,
     Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2354,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
 
 def MemRef_ViewOp : MemRef_Op<"view", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     Pure]> {
   let summary = "memref view operation";
@@ -2392,6 +2431,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
 //===----------------------------------------------------------------------===//
 
 def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 63410b8bea747..bbf55f5d507e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..252c0b72456df 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
 
 def Vector_LoadOp : Vector_Op<"load", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
 
 def Vector_StoreOp : Vector_Op<"store", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload">,
+  Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore">,
+  Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
 def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
+    DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
 }
 
 def Vector_ScatterOp :
-  Vector_Op<"scatter">,
+  Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$offsets,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
 }
 
 def Vector_ExpandLoadOp :
-  Vector_Op<"expandload">,
+  Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
 }
 
 def Vector_CompressStoreOp :
-  Vector_Op<"compressstore">,
+  Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,

diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 2add220fdfb7c..a5feb592045c0 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cdc423f5da1a5
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,36 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
+/// referenced by `operand`. On success, it returns `std::nullopt`. It
+/// returns failure if `operand` doesn't reference a
+/// `MemorySpaceCastOpInterface` op.
+FailureOr<std::optional<SmallVector<Value>>>
+bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H

diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..1a64e97c3412d
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,125 @@
+//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+def MemorySpaceCastConsumerOpInterface :
+    OpInterface<"MemorySpaceCastConsumerOpInterface"> {
+  let description = [{
+    An interface for operations that can consume memory-space cast-like
+    operations.
+
+    This interface can be used to bubble-down memory-space cast operations,
+    see the `bubble-down-memory-space-casts` pass for an example.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Attempt to bubble-down the incoming cast-like operands. On success
+        returns a `std::optional<SmallVector<Value>>`, otherwise it returns
+        failure. If the optional is `std::nullopt` then the cast was performed
+        in place, otherwise the method returns a list of replacement values.
+        If new results are produced, these must be compatible with the original
+        operation results.
+
+        If the operation was not modified in place, then the interface
+        guarantees it is valid to erase the original operation.
+        If the operation was modified in place, then the interface must
+        guarantee no operations were created by the method, and that no further
+        IR modification is necessary.
+
+        Implementations of this method must not erase/replace the original
+        operation; instead, it is the caller's responsibility to erase or
+        replace the op with the results provided by the method.
+
+        Finally, any implementations of this method have to guarantee that the
+        IR remains valid at all times.
+      }],
+      "::llvm::FailureOr<std::optional<::llvm::SmallVector<::mlir::Value>>>",
+      "bubbleDownCasts",
+      (ins "::mlir::OpBuilder &":$builder)
+    >,
+  ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+  let description = [{
+    An interface for operations that perform memory-space casts. This
+    interface assumes that the cast operation is `pure`.
+
+    These operations expect to have a well-defined ptr-like operand, and
+    a well-defined target ptr-like result.
+
+    This interface also allows determining whether a cast can be bubbled-down
+    by the `MemorySpaceCastConsumerOpInterface`, allowing control over which
+    casts can be bubbled-down or not.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Returns the source ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>",  "getSourcePtr"
+    >,
+    InterfaceMethod<[{
+        Returns the target ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+    >,
+    InterfaceMethod<[{
+        Returns whether the memory space cast specified by `tgt` and `src`
+        is supported.
+      }],
+      "bool", "isValidMemorySpaceCast",
+      (ins "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::PtrLikeTypeInterface":$src)
+    >,
+    InterfaceMethod<[{
+        Clones the memory space cast op with the given source and target type.
+      }],
+      "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+      (ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
+    >,
+    InterfaceMethod<[{
+        Returns whether the source pointer of the memory-space cast can be used
+        by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method to
+        promote the source pointer and bubble down the cast.
+
+        For example, a cast operation might decide that all casts to the generic
+        memory-space can be promoted. 
+      }],
+      "bool", "isSourcePromotable"
+    >
+  ];
+  let verify = [{
+    return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+  }];
+  let extraClassDeclaration = [{
+    /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
+    /// is produced by a `MemorySpaceCastOpInterface` op, and
+    /// `isSourcePromotable` returns true, otherwise it returns null.
+    static ::mlir::MemorySpaceCastOpInterface
+    getIfPromotableCast(::mlir::Value value) {
+      auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+        value.getDefiningOp());
+      if (!op || !op.isSourcePromotable())
+        return nullptr;
+      return op;
+    }
+  }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD

diff --git a/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
new file mode 100644
index 0000000000000..99db092879a90
--- /dev/null
+++ b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
@@ -0,0 +1,20 @@
+//===- BubbleDownMemorySpaceCasts.h - Bubble down cast patterns -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+#define MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+
+namespace mlir {
+class PatternBenefit;
+class RewritePatternSet;
+/// Collect a set of patterns to bubble-down memory-space cast operations.
+void populateBubbleDownMemorySpaceCastPatterns(RewritePatternSet &patterns,
+                                               PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..1c035f2a843ff 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default

diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index beb59784947c5..b2b7f20a497e3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,4 +585,48 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def BubbleDownMemorySpaceCasts :
+    Pass<"bubble-down-memory-space-casts"> {
+  let summary = "Bubbles down memory-space cast operations.";
+  let description = [{
+    This pass tries to iteratively bubble down all possible memory-space cast
+    operations. It is important to note that the determination of which casts
+    are bubbled down is based on the interfaces
+    `MemorySpaceCastConsumerOpInterface`, and `MemorySpaceCastOpInterface`, and
+    not the pass. The pass only looks for operations implementing the
+    `MemorySpaceCastConsumerOpInterface` interface, and invokes the interface
+    methods to perform the bubbling down.
+
+    Example:
+
+    ```mlir
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+      %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+      %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+      %added = arith.addf %loaded, %arg2 : f32
+      memref.store %added, %collapsed[%c0] : memref<16xf32>
+      %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+      return %collapsed : memref<16xf32>
+    }
+    // mlir-opt --bubble-down-memory-space-casts
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %c4 = arith.constant 4 : index
+      %c0 = arith.constant 0 : index
+      %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+      %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+      %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
+      %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
+      %1 = arith.addf %0, %arg2 : f32
+      memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
+      %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
+      return %memspacecast : memref<16xf32>
+    }
+    ```
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES

diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 734294bd014c6..e25a0121a3359 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d15d5f6e3de4..349b4deb29023 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,6 +111,65 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
   }
 }
 
+/// Returns the promotable cast defining `src` plus `resultTy` cloned into the
+/// cast's target and source memory spaces, or `{}` (null cast) on failure.
+static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
+getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfPromotableCast(src);
+
+  // Bail if `src` is not produced by a promotable memory-space cast.
+  if (!castOp)
+    return {};
+
+  // Clone `resultTy` into the memory space of the cast's source. Bail if the
+  // type cannot be cloned with that memory space.
+  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
+      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(srcTy))
+    return {};
+
+  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
+      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(tgtTy))
+    return {};
+
+  // Let the cast op decide whether casting `*srcTy` to `*tgtTy` is supported.
+  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
+    return {};
+
+  return std::make_tuple(castOp, *tgtTy, *srcTy);
+}
+
+/// Shared `bubbleDownCasts` logic for ops returning a single memref result:
+/// re-creates `op` on the cast's source and casts the new result back.
+template <typename ConcreteOpTy>
+static FailureOr<std::optional<SmallVector<Value>>>
+bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+                                 OpOperand &src) {
+  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
+  // Bail if `src` does not come from a promotable, valid memory-space cast.
+  if (!castOp)
+    return failure();
+
+  // Copy the operands, replacing the cast result with the cast's source.
+  SmallVector<Value> operands;
+  llvm::append_range(operands, op->getOperands());
+  operands[src.getOperandNumber()] = castOp.getSourcePtr();
+
+  // Re-create the op on the new operands so it produces a `resTy` result.
+  auto newOp = ConcreteOpTy::create(
+      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
+      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
+
+  // Cast the new result to `tgtTy` and return it as the replacement value.
+  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
+      builder, tgtTy,
+      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
+  return std::optional<SmallVector<Value>>(
+      SmallVector<Value>({result.getTargetPtr()}));
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -542,6 +601,11 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
+/// Delegates to the pass-through implementation, using the `memref` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &memrefOperand = getMemrefMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, memrefOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -710,6 +774,11 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+/// Delegates to the pass-through implementation, using the `source` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+CastOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &sourceOperand = getSourceMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, sourceOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
@@ -1601,6 +1670,12 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `memref` operand, the operand is rewired to the cast's source.
+FailureOr<std::optional<SmallVector<Value>>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &memrefOperand = getMemrefMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(memrefOperand,
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // MemorySpaceCastOp
 //===----------------------------------------------------------------------===//
@@ -1645,6 +1720,32 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
   return Value{};
 }
 
+/// Returns the `source` operand typed as a `PtrLikeTypeInterface` value.
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
+  Value source = getSource();
+  return cast<TypedValue<PtrLikeTypeInterface>>(source);
+}
+
+/// Returns the `dest` result typed as a `PtrLikeTypeInterface` value.
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
+  Value dest = getDest();
+  return cast<TypedValue<PtrLikeTypeInterface>>(dest);
+}
+
+/// A cast from `src` to `tgt` is accepted when `tgt` is a `BaseMemRefType` and
+/// re-cloning `tgt` with the memory space of `src` yields exactly `src`.
+bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+                                               PtrLikeTypeInterface src) {
+  if (!isa<BaseMemRefType>(tgt))
+    return false;
+  return tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
+}
+
+/// Builds a new `MemorySpaceCastOp` at this op's location that casts `src` to
+/// the type `tgt`. The pair must satisfy `isValidMemorySpaceCast`.
+MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
+    OpBuilder &b, PtrLikeTypeInterface tgt,
+    TypedValue<PtrLikeTypeInterface> src) {
+  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
+  auto clonedCast = MemorySpaceCastOp::create(b, getLoc(), tgt, src);
+  return clonedCast;
+}
+
+/// Only casts whose destination carries no memory-space attribute (i.e. the
+/// generic space) are treated as promotable.
+bool MemorySpaceCastOp::isSourcePromotable() {
+  Attribute destSpace = getDest().getType().getMemorySpace();
+  return !destSpace;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
@@ -2041,6 +2142,11 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
 }
 
+/// Delegates to the pass-through implementation, using the `source` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &sourceOperand = getSourceMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, sourceOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2348,6 +2454,11 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
+/// Delegates to the pass-through implementation, using the `src` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &srcOperand = getSrcMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, srcOperand);
+}
+
 /// Compute the layout map after collapsing a given source MemRef type with the
 /// specified reassociation indices.
 ///
@@ -2569,6 +2680,11 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());
 }
 
+/// Delegates to the pass-through implementation, using the `src` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &srcOperand = getSrcMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, srcOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -2609,6 +2725,11 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+/// Delegates to the pass-through implementation, using the `source` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &sourceOperand = getSourceMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, sourceOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -2626,6 +2747,12 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this, getValueToStore());
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `memref` operand, the operand is rewired to the cast's source.
+/// The op has no results, so an empty range is forwarded.
+FailureOr<std::optional<SmallVector<Value>>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &memrefOperand = getMemrefMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(memrefOperand,
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // SubViewOp
 //===----------------------------------------------------------------------===//
@@ -3282,6 +3409,11 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+/// Delegates to the pass-through implementation, using the `source` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+SubViewOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &sourceOperand = getSourceMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, sourceOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
@@ -3382,6 +3514,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+/// Delegates to the pass-through implementation, using the `in` operand as the
+/// candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+TransposeOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &inOperand = getInMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, inOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
@@ -3525,6 +3662,11 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
 }
 
+/// Delegates to the pass-through implementation, using the `source` operand as
+/// the candidate memory-space-cast consumer.
+FailureOr<std::optional<SmallVector<Value>>>
+ViewOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &sourceOperand = getSourceMutable();
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, sourceOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // AtomicRMWOp
 //===----------------------------------------------------------------------===//
@@ -3570,6 +3712,12 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `memref` operand, the operand is rewired to the cast's source.
+FailureOr<std::optional<SmallVector<Value>>>
+AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &memrefOperand = getMemrefMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(memrefOperand,
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 347141e2773b8..adcfd4af9a991 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5099,6 +5099,14 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
+/// Only applies when the op has pure buffer semantics; in that case the `base`
+/// operand is handed to the in-place implementation. Otherwise fails.
+FailureOr<std::optional<SmallVector<Value>>>
+TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
+  if (hasPureBufferSemantics())
+    return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                              getResult());
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -5586,6 +5594,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
+/// Only applies when the op has pure buffer semantics; in that case the `base`
+/// operand is handed to the in-place implementation with an empty result
+/// range. Otherwise fails.
+FailureOr<std::optional<SmallVector<Value>>>
+TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
+  if (hasPureBufferSemantics())
+    return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                              ValueRange());
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -5640,6 +5656,12 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source.
+FailureOr<std::optional<SmallVector<Value>>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -5679,6 +5701,12 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source. An
+/// empty result range is forwarded.
+FailureOr<std::optional<SmallVector<Value>>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -5733,6 +5761,12 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source.
+FailureOr<std::optional<SmallVector<Value>>>
+MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -5783,6 +5817,12 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source. An
+/// empty result range is forwarded.
+FailureOr<std::optional<SmallVector<Value>>>
+MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//
@@ -5886,6 +5926,12 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source.
+FailureOr<std::optional<SmallVector<Value>>>
+GatherOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -5948,6 +5994,12 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source. An
+/// empty result range is forwarded.
+FailureOr<std::optional<SmallVector<Value>>>
+ScatterOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandLoadOp
 //===----------------------------------------------------------------------===//
@@ -5996,6 +6048,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExpandLoadFolder>(context);
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source.
+FailureOr<std::optional<SmallVector<Value>>>
+ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // CompressStoreOp
 //===----------------------------------------------------------------------===//
@@ -6042,6 +6100,12 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CompressStoreFolder>(context);
 }
 
+/// Delegates to the in-place implementation: if a promotable memory-space cast
+/// feeds the `base` operand, the operand is rewired to the cast's source. An
+/// empty result range is forwarded.
+FailureOr<std::optional<SmallVector<Value>>>
+CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
+  OpOperand &baseOperand = getBaseMutable();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(baseOperand,
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index fdc19844702bc..388de1c3e5abf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -11,6 +11,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
+  MemOpInterfaces.cpp
   MemorySlotInterfaces.cpp
   ParallelCombiningOpInterface.cpp
   RuntimeVerifiableOpInterface.cpp
@@ -79,6 +80,7 @@ add_mlir_library(MLIRLoopLikeInterface
   MLIRFunctionInterfaces
 )
 
+add_mlir_interface_library(MemOpInterfaces)
 add_mlir_interface_library(MemorySlotInterfaces)
 add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)

diff  --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
new file mode 100644
index 0000000000000..fe5c717f67bc4
--- /dev/null
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -0,0 +1,73 @@
+//===- MemOpInterfaces.cpp - Memory operation interfaces ---------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+/// Verifier for ops implementing `MemorySpaceCastOpInterface`.
+///
+/// Checks, in this order, that: the source and target pointers are non-null,
+/// their types share the same TypeID, both types implement
+/// `PtrLikeTypeInterface`, and the op has exactly one result. The first
+/// violated condition is reported as an error on `op`.
+LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
+  auto castIface = cast<MemorySpaceCastOpInterface>(op);
+
+  // Both ends of the cast must be present.
+  Value src = castIface.getSourcePtr();
+  Value tgt = castIface.getTargetPtr();
+  if (!(src && tgt))
+    return op->emitError()
+           << "memory space cast op must have valid source and target pointers";
+
+  // Source and target must be the same kind of type.
+  Type srcTy = src.getType();
+  Type tgtTy = tgt.getType();
+  if (srcTy.getTypeID() != tgtTy.getTypeID())
+    return op->emitError()
+           << "expected source and target types of the same kind";
+
+  // Both types must be pointer-like.
+  if (!isa<PtrLikeTypeInterface>(srcTy))
+    return op->emitError()
+           << "source type must implement `PtrLikeTypeInterface`, but got: "
+           << srcTy;
+
+  if (!isa<PtrLikeTypeInterface>(tgtTy))
+    return op->emitError()
+           << "target type must implement `PtrLikeTypeInterface`, but got: "
+           << tgtTy;
+
+  // A memory-space cast produces a single value.
+  if (op->getNumResults() != 1)
+    return op->emitError()
+           << "memory space cast op must have exactly one result";
+
+  return success();
+}
+
+/// In-place `bubbleDownCasts` helper.
+///
+/// If `operand` is produced by a promotable memory-space cast, the operand is
+/// redirected to the cast's source pointer and an empty optional is returned
+/// (no replacement values). Otherwise returns `failure()`. The `results`
+/// argument is not read by this implementation.
+FailureOr<std::optional<SmallVector<Value>>>
+mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
+                                                   ValueRange results) {
+  if (MemorySpaceCastOpInterface promotableCast =
+          MemorySpaceCastOpInterface::getIfPromotableCast(operand.get())) {
+    // Bypass the cast by consuming its source directly.
+    operand.set(promotableCast.getSourcePtr());
+    return std::optional<SmallVector<Value>>(std::nullopt);
+  }
+
+  // No promotable cast feeds this operand.
+  return failure();
+}
+
+#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"

diff  --git a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
new file mode 100644
index 0000000000000..00dac19e37171
--- /dev/null
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -0,0 +1,69 @@
+//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+//===----------------------------------------------------------------------===//
+// BubbleDownCastsPattern pattern
+//===----------------------------------------------------------------------===//
+/// Rewrite pattern that asks every `MemorySpaceCastConsumerOpInterface` op to
+/// bubble memory-space casts down past itself.
+struct BubbleDownCastsPattern
+    : public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  /// Calls `bubbleDownCasts` on `op`. On failure the pattern does not match.
+  /// If replacement values are returned, `op` is replaced by them; if none
+  /// are returned, the rewriter is notified of an in-place modification.
+  LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<std::optional<SmallVector<Value>>> outcome =
+        op.bubbleDownCasts(rewriter);
+    if (failed(outcome))
+      return failure();
+
+    std::optional<SmallVector<Value>> &replacements = *outcome;
+    if (replacements.has_value()) {
+      rewriter.replaceOp(op, *replacements);
+      return success();
+    }
+
+    // No replacement values: report the op as modified in place.
+    rewriter.modifyOpInPlace(op, []() {});
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// BubbleDownMemorySpaceCasts pass
+//===----------------------------------------------------------------------===//
+
+/// Pass that greedily applies the bubble-down memory-space-cast patterns to
+/// the operation it runs on.
+struct BubbleDownMemorySpaceCasts
+    : public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
+  using impl::BubbleDownMemorySpaceCastsBase<
+      BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;
+
+  /// Populates the pattern set and runs the greedy driver; signals pass
+  /// failure if the driver reports failure.
+  void runOnOperation() override {
+    RewritePatternSet castPatterns(&getContext());
+    populateBubbleDownMemorySpaceCastPatterns(castPatterns, PatternBenefit(1));
+
+    LogicalResult status =
+        applyPatternsGreedily(getOperation(), std::move(castPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Registers `BubbleDownCastsPattern` in `patterns` with the given `benefit`.
+void mlir::populateBubbleDownMemorySpaceCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<BubbleDownCastsPattern>(context, benefit);
+}

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e..54b67f5c7a91e 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransforms
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
+  BubbleDownMemorySpaceCasts.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -31,6 +32,7 @@ add_mlir_library(MLIRTransforms
   MLIRAnalysis
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRPass
   MLIRRuntimeVerifiableOpInterface

diff  --git a/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
new file mode 100644
index 0000000000000..e4fce89cffb45
--- /dev/null
+++ b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
@@ -0,0 +1,298 @@
+// RUN: mlir-opt %s --bubble-down-memory-space-casts | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// A cast into the default memory space is folded into load/store: both ops are
+// expected to access the memory-space-1 argument directly.
+// CHECK-LABEL:   func.func @load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK:           memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK:           return
+// CHECK:         }
+func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
+  memref.store %0, %memspacecast[%arg1] : memref<?xf32>
+  return
+}
+
+// The cast targets memory space 2 rather than the default space, so the IR is
+// expected to be left unchanged.
+// CHECK-LABEL:   func.func @load_store_unfoldable(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32, 1> to memref<?xf32, 2>
+// CHECK:           %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK:           memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK:           return
+// CHECK:         }
+func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32, 2>
+  %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 2>
+  memref.store %0, %memspacecast[%arg1] : memref<?xf32, 2>
+  return
+}
+
+// The memory-space cast is expected to move below memref.cast, for both the
+// ranked-to-unranked and the unranked-to-ranked direction.
+// CHECK-LABEL:   func.func @cast(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<2xf32, 1>,
+// CHECK-SAME:                    %[[ARG1:.*]]: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
+// CHECK:           %[[VAL_0:.*]] = memref.cast %[[ARG0]] : memref<2xf32, 1> to memref<*xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<*xf32, 1> to memref<*xf32>
+// CHECK:           %[[VAL_2:.*]] = memref.cast %[[ARG1]] : memref<*xf32, 1> to memref<3x2xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<3x2xf32, 1> to memref<3x2xf32>
+// CHECK:           return %[[VAL_1]], %[[VAL_3]] : memref<*xf32>, memref<3x2xf32>
+// CHECK:         }
+func.func @cast(%arg0: memref<2xf32, 1>, %arg1: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<2xf32, 1> to memref<2xf32>
+  %1 = memref.cast %memspacecast : memref<2xf32> to memref<*xf32>
+  %memspacecast_1 = memref.memory_space_cast %arg1 : memref<*xf32, 1> to memref<*xf32>
+  %2 = memref.cast %memspacecast_1 : memref<*xf32> to memref<3x2xf32>
+  return %1, %2 : memref<*xf32>, memref<3x2xf32>
+}
+
+// The memory-space cast is expected to move below memref.view, which then
+// operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @view(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<?xi8, 1>,
+// CHECK-SAME:                    %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xi8, 1>
+// CHECK:           %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xi8, 1> to memref<?x?xi8>
+// CHECK:           return %[[VAL_2]] : memref<?x?xi8>
+// CHECK:         }
+func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xi8> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
+  %c100 = arith.constant 100 : index
+  %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xi8>
+  return %view : memref<?x?xi8>
+}
+
+// The memory-space cast is expected to move below memref.subview, which then
+// operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @subview(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+// CHECK:           %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK:           return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK:         }
+func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+  return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
+}
+
+// The memory-space cast is expected to move below memref.reinterpret_cast,
+// which then operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @reinterpret_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:           return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         }
+func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+  return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
+}
+
+// The memory-space cast is expected to move below memref.reshape, which then
+// operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @reshape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK:           return %[[VAL_1]] : memref<?xf32>
+// CHECK:         }
+func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
+  return %reshape : memref<?xf32>
+}
+
+// The memory-space cast is expected to move below memref.expand_shape, which
+// then operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @expand_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
+// CHECK:           return %[[VAL_1]] : memref<3x4xf32>
+// CHECK:         }
+func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
+  %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
+  return %expand_shape : memref<3x4xf32>
+}
+
+// The memory-space cast is expected to move below memref.collapse_shape, which
+// then operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @collapse_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
+// CHECK:           return %[[VAL_1]] : memref<12xf32>
+// CHECK:         }
+func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
+  %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
+  return %collapse_shape : memref<12xf32>
+}
+
+// The memory-space cast is expected to move below memref.transpose; the
+// transposed layout map is kept on both sides of the re-inserted cast.
+// CHECK-LABEL:   func.func @transpose(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
+// CHECK:           %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK:           return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK:         }
+func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
+  return %transpose : memref<?x?xf32, #map>
+}
+
+// The cast is folded into memref.atomic_rmw, which is expected to access the
+// memory-space-1 argument directly.
+// CHECK-LABEL:   func.func @atomic_rmw(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
+// CHECK:           return %[[VAL_0]] : f32
+// CHECK:         }
+func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
+  return %0 : f32
+}
+
+// The memory-space cast is expected to move below memref.assume_alignment,
+// which then operates on the memory-space-1 argument.
+// CHECK-LABEL:   func.func @assume_alignment(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK:           return %[[VAL_1]] : memref<?xf32>
+// CHECK:         }
+func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// The cast is expected to bubble through expand_shape and collapse_shape. The
+// load, store and atomic_rmw then use the pre-cast value, and a single cast
+// remains to feed the return.
+// CHECK-LABEL:   func.func @op_with_cast_sequence(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<4x4xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: f32) -> memref<16xf32> {
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK:           %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32
+// CHECK:           memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK:           %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32
+// CHECK:           return %[[VAL_4]] : memref<16xf32>
+// CHECK:         }
+func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+  %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+  %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+  %added = arith.addf %loaded, %arg2 : f32
+  memref.store %added, %collapsed[%c0] : memref<16xf32>
+  %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+  return %collapsed : memref<16xf32>
+}
+
+// Test case: a memory-space cast used as the memref operand of both
+// `vector.transfer_read` and `vector.transfer_write`. The expected output
+// has both transfer ops access the original `memref<?xf32, 1>` argument.
+// CHECK-LABEL:   func.func @transfer_read_write(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref<?xf32, 1>
+// CHECK:           return
+// CHECK:         }
+func.func @transfer_read_write(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  // Despite its name, `%c0` is an f32 zero (not an index); it is passed as
+  // the trailing operand of the transfer_read below.
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref<?xf32>, vector<4xf32>
+  vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref<?xf32>
+  return
+}
+
+// Test case: `vector.transfer_read` / `vector.transfer_write` on a tensor
+// operand. The input contains no memory-space cast; the expected output
+// simply returns the tensor argument.
+// NOTE(review): presumably this guards the tensor (non-memref) path of the
+// transfer ops when the transform runs -- confirm against the implementation.
+// NOTE: The operations disappear because they can get folded.
+// CHECK-LABEL:   func.func @transfer_read_write_tensor(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> tensor<?xf32> {
+// CHECK:           return %[[ARG0]] : tensor<?xf32>
+// CHECK:         }
+func.func @transfer_read_write_tensor(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor<?xf32>, vector<4xf32>
+  %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// Test case: a memory-space cast used as the base of `vector.load` and
+// `vector.store`. The expected output has both ops access the original
+// `memref<?xf32, 1>` argument at the same index.
+// CHECK-LABEL:   func.func @vector_load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = vector.load %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+  vector.store %0, %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// Test case: a memory-space cast used as the base of `vector.maskedload` and
+// `vector.maskedstore`. The expected output has both ops access the original
+// `memref<?xf32, 1>` argument; the mask and pass-through constants are
+// matched in either order.
+// CHECK-LABEL:   func.func @masked_load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// Test case: a memory-space cast used as the base of `vector.gather` and
+// `vector.scatter`. The expected output has both ops access the original
+// `memref<?xf32, 1>` argument; the four constants (pass-through, mask,
+// index vector, base index) are matched in any order.
+// NOTE(review): `%arg1` is not referenced in the function body below; the
+// base index is the constant `%c0`.
+// CHECK-LABEL:   func.func @gather_scatter(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %mask = arith.constant dense<true> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// Test case: a memory-space cast used as the base of `vector.expandload` and
+// `vector.compressstore`. The expected output has both ops access the
+// original `memref<?xf32, 1>` argument; the mask and pass-through constants
+// are matched in either order.
+// CHECK-LABEL:   func.func @expandload_compressstore(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @expandload_compressstore(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.compressstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+  return
+}


        


More information about the Mlir-commits mailing list