[Mlir-commits] [mlir] [mlir] Implement memory-space cast operand fusion into consumers (PR #159454)

Fabian Mora llvmlistbot at llvm.org
Thu Sep 18 04:04:32 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/159454

>From a15b8ca9bd287d4ad6af320cae7ae01afb4234bc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Wed, 17 Sep 2025 21:07:38 +0000
Subject: [PATCH 1/2] [mlir] Implement memory-space cast operand fusion into
 consumers

This commit adds functionality to fuse memory-space casts into consumer operations,
allowing operations to be performed directly on the original memory-space rather
than first casting to a different memory space.

Key changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations
- Create a `FuseMemorySpaceCastsIntoConsumers` pass that identifies and fuses eligible casts
- Add implementation for memref and vector operations to handle memory-space cast fusion
- Add fuseCastOperands method to relevant operations to support the fusion

In particular, in the current implementation only memory-space casts into the default
memory-space can be fused.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --fuse-memory-space-casts-into-consumers
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```

Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRef.h  |   1 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  19 +-
 .../mlir/Dialect/Vector/IR/VectorOps.h        |   1 +
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  16 +-
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   1 +
 .../include/mlir/Interfaces/MemOpInterfaces.h |  37 +++
 .../mlir/Interfaces/MemOpInterfaces.td        | 114 +++++++
 .../FuseMemorySpaceCastsIntoConsumers.h       |  20 ++
 mlir/include/mlir/Transforms/Passes.h         |   1 +
 mlir/include/mlir/Transforms/Passes.td        |  40 +++
 mlir/lib/Dialect/MemRef/IR/CMakeLists.txt     |   1 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 158 ++++++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  64 ++++
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 mlir/lib/Interfaces/MemOpInterfaces.cpp       |  73 +++++
 mlir/lib/Transforms/CMakeLists.txt            |   2 +
 .../FuseMemorySpaceCastsIntoConsumers.cpp     |  73 +++++
 .../test-fuse-casts-into-consumers.mlir       | 281 ++++++++++++++++++
 18 files changed, 897 insertions(+), 7 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.h
 create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.td
 create mode 100644 mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
 create mode 100644 mlir/lib/Interfaces/MemOpInterfaces.cpp
 create mode 100644 mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
 create mode 100644 mlir/test/Transforms/test-fuse-casts-into-consumers.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bdec699eb4ce4..30f33ed2fd1d6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 671cc05e963b4..238a767ac8b73 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure,
       ViewLikeOpInterface,
-      SameOperandsAndResultType
+      SameOperandsAndResultType,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
 def MemRef_CastOp : MemRef_Op<"cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
 def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
 def MemRef_ReinterpretCastOp
   : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       MemRefsNormalizable,
       Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
 
 def MemRef_ReshapeOp: MemRef_Op<"reshape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     Pure,
     ViewLikeOpInterface]>  {
   let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+  ]> {
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     Pure]>,
     Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
 
 def MemRef_ViewOp : MemRef_Op<"view", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     Pure]> {
   let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
 //===----------------------------------------------------------------------===//
 
 def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 63410b8bea747..bbf55f5d507e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..93e9bfc78ea75 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
 
 def Vector_LoadOp : Vector_Op<"load", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
 
 def Vector_StoreOp : Vector_Op<"store", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload">,
+  Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore">,
+  Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
 def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
 }
 
 def Vector_ScatterOp :
-  Vector_Op<"scatter">,
+  Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$offsets,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
 }
 
 def Vector_ExpandLoadOp :
-  Vector_Op<"expandload">,
+  Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
 }
 
 def Vector_CompressStoreOp :
-  Vector_Op<"compressstore">,
+  Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 2add220fdfb7c..a5feb592045c0 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cc9f4c6b3882e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,37 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Tries to fuse in place a `MemorySpaceCastOpInterface` operation referenced
+/// by `operand`. On success, it returns `results`, and sets `modifiedInPlace`
+/// to true. It returns failure if `operand` doesn't reference a
+/// `MemorySpaceCastOpInterface` op.
+FailureOr<SmallVector<Value>>
+fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
+                               bool &modifiedInPlace);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..0b8ba19171fb7
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,114 @@
+//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def FuseMemorySpaceCastConsumerOpInterface :
+    OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+  let description = [{
+    An interface to fuse memory-space cast operands into a consumer operation.
+    It is the responsibility of the interface to determine which casts can be
+    fused into the operation.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Attempt to fuse the incoming cast-like operands. Returns `success`
+        and any new results on fusion success, otherwise it returns failure.
+        If new results are produced, these must be compatible with the original
+        operation results.
+
+        The `modifiedInPlace` parameter indicates whether the operation was
+        modified in place. If `false` and the fusion succeeded, then the
+        interface guarantees it is valid to erase the original operation.
+        If `true`, then the interface must guarantee no operations were created
+        by the method, and that no further IR modification is necessary. It is
+        considered an error if `modifiedInPlace` is true and the fusion failed.
+
+        Implementations of this method must not erase or replace the original
+        operation; instead, it is the caller's responsibility to erase or
+        replace the op with the results provided by the method.
+
+        Finally, any implementations of this method have to guarantee that the
+        IR remains valid at all times.
+      }],
+      "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
+      (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+    >,
+  ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+  let description = [{
+    An interface for operations that perform memory-space casts. This
+    interface assumes that the cast operation is `pure`.
+
+    These operations expect to have a well-defined ptr-like operand, and
+    a well-defined target ptr-like result.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Returns the source ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>",  "getSourcePtr"
+    >,
+    InterfaceMethod<[{
+        Returns the target ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+    >,
+    InterfaceMethod<[{
+        Returns whether the memory space cast specified by `tgt` and `src`
+        is supported.
+      }],
+      "bool", "isValidMemorySpaceCast",
+      (ins "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::PtrLikeTypeInterface":$src)
+    >,
+    InterfaceMethod<[{
+        Clones the memory space cast op with the given source and target type.
+      }],
+      "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+      (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+           "::mlir::Value":$src)
+    >,
+    InterfaceMethod<[{
+        Returns whether this cast is allowed to be fused.
+      }],
+      "bool", "isFusableMemorySpaceCast"
+    >
+  ];
+  let verify = [{
+    return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+  }];
+  let dependentTraits = [Pure];
+  let extraClassDeclaration = [{
+    /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
+    /// is produced by a `MemorySpaceCastOpInterface` op, and
+    /// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
+    static ::mlir::MemorySpaceCastOpInterface
+    getIfFusableCast(::mlir::Value value) {
+      auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+        value.getDefiningOp());
+      if (!op || !op.isFusableMemorySpaceCast())
+        return nullptr;
+      return op;
+    }
+  }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
new file mode 100644
index 0000000000000..9333f92a10289
--- /dev/null
+++ b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
@@ -0,0 +1,20 @@
+//===- FuseMemorySpaceCastsIntoConsumers.h - Fusion patterns ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+
+namespace mlir {
+class RewritePatternSet;
+/// Collect a set of patterns to fuse memory-space cast operations into
+/// consumers.
+void populateFuseMemorySpaceCastIntoConsumersPatterns(
+    RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..610a9671fede8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index beb59784947c5..69280e3d443ea 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,4 +585,44 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def FuseMemorySpaceCastsIntoConsumers :
+    Pass<"fuse-memory-space-casts-into-consumers"> {
+  let summary = "Fuses memory-space cast operations into consumers.";
+  let description = [{
+    This pass tries to fuse all possible memory-space cast operations into their consumers.
+    It does this by looking for `FuseMemorySpaceCastConsumerOpInterface`
+    operations, and invoking the interface methods to perform the fusion.
+
+    Example:
+
+    ```mlir
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+      %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+      %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+      %added = arith.addf %loaded, %arg2 : f32
+      memref.store %added, %collapsed[%c0] : memref<16xf32>
+      %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+      return %collapsed : memref<16xf32>
+    }
+    // mlir-opt --fuse-memory-space-casts-into-consumers
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %c4 = arith.constant 4 : index
+      %c0 = arith.constant 0 : index
+      %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+      %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+      %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
+      %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
+      %1 = arith.addf %0, %arg2 : f32
+      memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
+      %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
+      return %memspacecast : memref<16xf32>
+    }
+    ```
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 734294bd014c6..e25a0121a3359 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d15d5f6e3de4..0ddb2b0ca1645 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,6 +111,65 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
   }
 }
 
+/// Returns the fusable cast producing `src` (if any), plus `resultTy` cloned
+/// into the cast's target and source memory spaces; all null on failure.
+static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
+getFuseCastInfo(BaseMemRefType resultTy, Value src) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfFusableCast(src);
+
+  // Bail if `src` is not the result of a fusable memory-space cast.
+  if (!castOp)
+    return {};
+
+  // Clone `resultTy` into the source and target memory spaces of `castOp`,
+  // keeping the rest of its metadata. Bail if either clone is not possible.
+  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
+      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(srcTy))
+    return {};
+
+  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
+      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(tgtTy))
+    return {};
+
+  // Bail if the cast op does not support casting `srcTy` to `tgtTy`.
+  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
+    return {};
+
+  return std::make_tuple(castOp, *tgtTy, *srcTy);
+}
+
+/// Implements `fuseCastOperands` for single-result memref ops: re-creates `op`
+/// on the cast's source operand and casts the new result back.
+template <typename ConcreteOpTy>
+static FailureOr<SmallVector<Value>>
+fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+                                  bool &modifiedInPlace, OpOperand &src) {
+  auto [castOp, tgtTy, resTy] = getFuseCastInfo(op.getType(), src.get());
+  // Bail if `src` has no fusable cast or the types cannot be adapted.
+  if (!castOp)
+    return failure();
+
+  modifiedInPlace = false;
+
+  // Copy the operands, replacing `src` with the cast's source value.
+  SmallVector<Value> operands;
+  llvm::append_range(operands, op->getOperands());
+  operands[src.getOperandNumber()] = castOp.getSourcePtr();
+
+  // Create the fused op with the result type in the source memory space.
+  auto newOp = ConcreteOpTy::create(
+      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
+      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
+
+  // Insert a memory-space cast to the original memory space of the op.
+  MemorySpaceCastOpInterface result =
+      castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
+  return SmallVector<Value>({result.getTargetPtr()});
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -542,6 +601,12 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
+FailureOr<SmallVector<Value>>
+AssumeAlignmentOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getMemrefMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -710,6 +775,12 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+FailureOr<SmallVector<Value>> CastOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
@@ -1601,6 +1672,12 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // MemorySpaceCastOp
 //===----------------------------------------------------------------------===//
@@ -1645,6 +1722,33 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
   return Value{};
 }
 
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
+  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+}
+
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
+  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+}
+
+bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+                                               PtrLikeTypeInterface src) {
+  return isa<MemRefType>(tgt) &&
+         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
+}
+
+MemorySpaceCastOpInterface
+MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b, Type tgt, Value src) {
+  assert(isValidMemorySpaceCast(cast<PtrLikeTypeInterface>(tgt),
+                                cast<PtrLikeTypeInterface>(src.getType())) &&
+         "invalid arguments");
+  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
+}
+
+bool MemorySpaceCastOp::isFusableMemorySpaceCast() {
+  // Fusable only if the result is in the default (null) memory space.
+  return getDest().getType().getMemorySpace() == nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
@@ -2041,6 +2145,12 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ReinterpretCastOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2348,6 +2458,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ExpandShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSrcMutable());
+}
+
 /// Compute the layout map after collapsing a given source MemRef type with the
 /// specified reassociation indices.
 ///
@@ -2569,6 +2685,12 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());
 }
 
+FailureOr<SmallVector<Value>>
+CollapseShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSrcMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -2609,6 +2731,12 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+FailureOr<SmallVector<Value>>
+ReshapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -2626,6 +2754,12 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this, getValueToStore());
 }
 
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+                                                        bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getMemrefMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // SubViewOp
 //===----------------------------------------------------------------------===//
@@ -3282,6 +3416,12 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+FailureOr<SmallVector<Value>>
+SubViewOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
@@ -3382,6 +3522,12 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+FailureOr<SmallVector<Value>>
+TransposeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getInMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
@@ -3525,6 +3671,12 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
 }
 
+FailureOr<SmallVector<Value>> ViewOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+                                           getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // AtomicRMWOp
 //===----------------------------------------------------------------------===//
@@ -3570,6 +3722,12 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+FailureOr<SmallVector<Value>>
+AtomicRMWOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d6e263934fb4..806e6c1c070aa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,6 +5087,14 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
+FailureOr<SmallVector<Value>>
+TransferReadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  if (!hasPureBufferSemantics())
+    return failure();
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -5574,6 +5582,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
+FailureOr<SmallVector<Value>>
+TransferWriteOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  if (!hasPureBufferSemantics())
+    return failure();
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -5628,6 +5644,12 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+                                                       bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -5667,6 +5689,12 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+                                                        bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -5721,6 +5749,12 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
+FailureOr<SmallVector<Value>>
+MaskedLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -5771,6 +5805,12 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
+FailureOr<SmallVector<Value>>
+MaskedStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//
@@ -5874,6 +5914,12 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
+FailureOr<SmallVector<Value>>
+GatherOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -5936,6 +5982,12 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ScatterOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandLoadOp
 //===----------------------------------------------------------------------===//
@@ -5984,6 +6036,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExpandLoadFolder>(context);
 }
 
+FailureOr<SmallVector<Value>>
+ExpandLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), getResult(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // CompressStoreOp
 //===----------------------------------------------------------------------===//
@@ -6030,6 +6088,12 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CompressStoreFolder>(context);
 }
 
+FailureOr<SmallVector<Value>>
+CompressStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+      getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index fdc19844702bc..388de1c3e5abf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -11,6 +11,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
+  MemOpInterfaces.cpp
   MemorySlotInterfaces.cpp
   ParallelCombiningOpInterface.cpp
   RuntimeVerifiableOpInterface.cpp
@@ -79,6 +80,7 @@ add_mlir_library(MLIRLoopLikeInterface
   MLIRFunctionInterfaces
 )
 
+add_mlir_interface_library(MemOpInterfaces)
 add_mlir_interface_library(MemorySlotInterfaces)
 add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
new file mode 100644
index 0000000000000..013d828da1d66
--- /dev/null
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -0,0 +1,73 @@
+//===- MemOpInterfaces.cpp - Memory operation interfaces ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
+  auto memCastOp = cast<MemorySpaceCastOpInterface>(op);
+
+  // Verify that the source and target pointers are non-null.
+  Value sourcePtr = memCastOp.getSourcePtr();
+  Value targetPtr = memCastOp.getTargetPtr();
+
+  if (!sourcePtr || !targetPtr) {
+    return op->emitError()
+           << "memory space cast op must have valid source and target pointers";
+  }
+
+  if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) {
+    return op->emitError()
+           << "expected source and target types of the same kind";
+  }
+
+  // Verify that both types implement `PtrLikeTypeInterface`.
+  auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType());
+  if (!sourceType) {
+    return op->emitError()
+           << "source type must implement `PtrLikeTypeInterface`, but got: "
+           << sourcePtr.getType();
+  }
+
+  auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType());
+  if (!targetType) {
+    return op->emitError()
+           << "target type must implement `PtrLikeTypeInterface`, but got: "
+           << targetPtr.getType();
+  }
+
+  // Verify that the operation has exactly one result.
+  if (op->getNumResults() != 1) {
+    return op->emitError()
+           << "memory space cast op must have exactly one result";
+  }
+
+  return success();
+}
+
+FailureOr<SmallVector<Value>> mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+    OpOperand &operand, ValueRange results, bool &modifiedInPlace) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfFusableCast(operand.get());
+
+  // Bail if the operand is not produced by a fusable memory-space cast.
+  if (!castOp)
+    return failure();
+
+  // Point the operand at the cast source in place; `results` is returned as-is.
+  modifiedInPlace = true;
+  operand.set(castOp.getSourcePtr());
+  return llvm::to_vector_of<Value>(results);
+}
+
+#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e..e9a7d3e4abe99 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransforms
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
+  FuseMemorySpaceCastsIntoConsumers.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -31,6 +32,7 @@ add_mlir_library(MLIRTransforms
   MLIRAnalysis
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRPass
   MLIRRuntimeVerifiableOpInterface
diff --git a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
new file mode 100644
index 0000000000000..010b88ac12de2
--- /dev/null
+++ b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
@@ -0,0 +1,73 @@
+//===- FuseMemorySpaceCastsIntoConsumers.cpp - Fuse casts transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FuseCastsPattern pattern
+//===----------------------------------------------------------------------===//
+/// Pattern fusing memory-space casts into consumers via `fuseCastOperands`.
+struct FuseCastsPattern
+    : public OpInterfaceRewritePattern<FuseMemorySpaceCastConsumerOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(FuseMemorySpaceCastConsumerOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    bool modifiedInPlace = false;
+    FailureOr<SmallVector<Value>> results =
+        op.fuseCastOperands(rewriter, modifiedInPlace);
+    assert((!failed(results) || !modifiedInPlace) &&
+           "expected `modifiedInPlace` to be false on fusion failure");
+    if (failed(results))
+      return failure();
+    if (modifiedInPlace) {
+      rewriter.modifyOpInPlace(op, []() {});
+      return success();
+    }
+    rewriter.replaceOp(op, *results);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// FuseMemorySpaceCastsIntoConsumers pass
+//===----------------------------------------------------------------------===//
+
+struct FuseMemorySpaceCastsIntoConsumers
+    : public impl::FuseMemorySpaceCastsIntoConsumersBase<
+          FuseMemorySpaceCastsIntoConsumers> {
+  using impl::FuseMemorySpaceCastsIntoConsumersBase<
+      FuseMemorySpaceCastsIntoConsumers>::FuseMemorySpaceCastsIntoConsumersBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::populateFuseMemorySpaceCastIntoConsumersPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FuseCastsPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
new file mode 100644
index 0000000000000..69a15f429cec2
--- /dev/null
+++ b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
@@ -0,0 +1,281 @@
+// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// CHECK-LABEL:   func.func @load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK:           memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK:           return
+// CHECK:         }
+func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
+  memref.store %0, %memspacecast[%arg1] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @load_store_unfoldable(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32, 1> to memref<?xf32, 2>
+// CHECK:           %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK:           memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK:           return
+// CHECK:         }
+func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32, 2>
+  %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 2>
+  memref.store %0, %memspacecast[%arg1] : memref<?xf32, 2>
+  return
+}
+
+// CHECK-LABEL:   func.func @view(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<?xi8, 1>,
+// CHECK-SAME:                    %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xi8, 1>
+// CHECK:           %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xi8, 1> to memref<?x?xi8>
+// CHECK:           return %[[VAL_2]] : memref<?x?xi8>
+// CHECK:         }
+func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xi8> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
+  %c100 = arith.constant 100 : index
+  %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xi8>
+  return %view : memref<?x?xi8>
+}
+
+// CHECK-LABEL:   func.func @subview(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+// CHECK:           %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK:           return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK:         }
+func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+  return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
+}
+
+// CHECK-LABEL:   func.func @reinterpret_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:           return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         }
+func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+  return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
+}
+
+// CHECK-LABEL:   func.func @reshape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK:           return %[[VAL_1]] : memref<?xf32>
+// CHECK:         }
+func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
+  return %reshape : memref<?xf32>
+}
+
+// CHECK-LABEL:   func.func @expand_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
+// CHECK:           return %[[VAL_1]] : memref<3x4xf32>
+// CHECK:         }
+func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
+  %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
+  return %expand_shape : memref<3x4xf32>
+}
+
+// CHECK-LABEL:   func.func @collapse_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
+// CHECK:           return %[[VAL_1]] : memref<12xf32>
+// CHECK:         }
+func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
+  %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
+  return %collapse_shape : memref<12xf32>
+}
+
+// CHECK-LABEL:   func.func @transpose(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
+// CHECK:           %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK:           return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK:         }
+func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+  %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
+  return %transpose : memref<?x?xf32, #map>
+}
+
+// CHECK-LABEL:   func.func @atomic_rmw(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
+// CHECK:           return %[[VAL_0]] : f32
+// CHECK:         }
+func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @assume_alignment(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
+// CHECK:           %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
+// CHECK:           %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK:           return %[[VAL_1]] : memref<?xf32>
+// CHECK:         }
+func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// CHECK-LABEL:   func.func @op_with_cast_sequence(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<4x4xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: f32) -> memref<16xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK:           %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32
+// CHECK:           memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK:           %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32
+// CHECK:           return %[[VAL_4]] : memref<16xf32>
+// CHECK:         }
+func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+  %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+  %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+  %added = arith.addf %loaded, %arg2 : f32
+  memref.store %added, %collapsed[%c0] : memref<16xf32>
+  %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+  return %collapsed : memref<16xf32>
+}
+
+// CHECK-LABEL:   func.func @transfer_read_write(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref<?xf32, 1>
+// CHECK:           return
+// CHECK:         }
+func.func @transfer_read_write(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref<?xf32>, vector<4xf32>
+  vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref<?xf32>
+  return
+}
+
+// NOTE: The transfer ops disappear from the output because they get folded away.
+// CHECK-LABEL:   func.func @transfer_read_write_tensor(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> tensor<?xf32> {
+// CHECK:           return %[[ARG0]] : tensor<?xf32>
+// CHECK:         }
+func.func @transfer_read_write_tensor(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor<?xf32>, vector<4xf32>
+  %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func.func @vector_load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %0 = vector.load %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+  vector.store %0, %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @masked_load_store(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @gather_scatter(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %mask = arith.constant dense<true> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:   func.func @expandload_compressstore(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK:           %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
+func.func @expandload_compressstore(%arg0: memref<?xf32, 1>, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+  %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+  %passthrough = arith.constant dense<0.0> : vector<4xf32>
+  %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  vector.compressstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+  return
+}

>From 1e49e3dff6ca2f4c3b551d164310e2318a0a46cb Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 18 Sep 2025 07:04:24 -0400
Subject: [PATCH 2/2] Update mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0ddb2b0ca1645..2201b237cfdda 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -118,7 +118,7 @@ getFuseCastInfo(BaseMemRefType resultTy, Value src) {
   MemorySpaceCastOpInterface castOp =
       MemorySpaceCastOpInterface::getIfFusableCast(src);
 
-  // Bail if the cast is not fusable.
+  // Bail if the cast is not fusible.
   if (!castOp)
     return {};
 



More information about the Mlir-commits mailing list