[Mlir-commits] [mlir] [mlir] Implement memory-space cast operand fusion into consumers (PR #159454)
Fabian Mora
llvmlistbot at llvm.org
Thu Sep 18 04:04:32 PDT 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/159454
From a15b8ca9bd287d4ad6af320cae7ae01afb4234bc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Wed, 17 Sep 2025 21:07:38 +0000
Subject: [PATCH 1/2] [mlir] Implement memory-space cast operand fusion into
consumers
This commit adds functionality to fuse memory-space casts into consumer
operations, allowing consumers to operate directly on the original memory
space rather than on the result of a cast.
Key changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations
- Create a `FuseMemorySpaceCastsIntoConsumers` pass that identifies and fuses eligible casts
- Add implementation for memref and vector operations to handle memory-space cast fusion
- Add a `fuseCastOperands` method to relevant operations to support the fusion
Note that in the current implementation, only memory-space casts into the
default memory space can be fused.
Example:
```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
%collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
%loaded = memref.load %collapsed[%c0] : memref<16xf32>
%added = arith.addf %loaded, %arg2 : f32
memref.store %added, %collapsed[%c0] : memref<16xf32>
%atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
return %collapsed : memref<16xf32>
}
// mlir-opt --fuse-memory-space-casts-into-consumers
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
%collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
%memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
%0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
%1 = arith.addf %0, %arg2 : f32
memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
%2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
return %memspacecast : memref<16xf32>
}
```
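For downstream ops, opting into the fusion typically amounts to declaring
`FuseMemorySpaceCastConsumerOpInterface` and delegating to one of the helpers
added by this patch. A minimal sketch (`MyLoadOp` is a hypothetical op; the
body mirrors the `memref.load` implementation below):

```cpp
// Hypothetical downstream load-like op. The in-place helper rewrites the
// memref operand to the cast's source pointer and returns the original
// results unchanged.
FailureOr<SmallVector<Value>>
MyLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
  return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
      getMemrefMutable(), getResult(), modifiedInPlace);
}
```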
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h | 1 +
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 19 +-
.../mlir/Dialect/Vector/IR/VectorOps.h | 1 +
.../mlir/Dialect/Vector/IR/VectorOps.td | 16 +-
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 +
.../include/mlir/Interfaces/MemOpInterfaces.h | 37 +++
.../mlir/Interfaces/MemOpInterfaces.td | 114 +++++++
.../FuseMemorySpaceCastsIntoConsumers.h | 20 ++
mlir/include/mlir/Transforms/Passes.h | 1 +
mlir/include/mlir/Transforms/Passes.td | 40 +++
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 158 ++++++++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 64 ++++
mlir/lib/Interfaces/CMakeLists.txt | 2 +
mlir/lib/Interfaces/MemOpInterfaces.cpp | 73 +++++
mlir/lib/Transforms/CMakeLists.txt | 2 +
.../FuseMemorySpaceCastsIntoConsumers.cpp | 73 +++++
.../test-fuse-casts-into-consumers.mlir | 281 ++++++++++++++++++
18 files changed, 897 insertions(+), 7 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.h
create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.td
create mode 100644 mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
create mode 100644 mlir/lib/Interfaces/MemOpInterfaces.cpp
create mode 100644 mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
create mode 100644 mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
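Note: besides running the standalone pass, the new rewrite patterns can be
reused from a custom pass via the populate function added in this patch. A
minimal sketch of such a pass body (it mirrors the implementation in
FuseMemorySpaceCastsIntoConsumers.cpp below):

```cpp
// Inside a custom pass's runOnOperation(); reuses the fusion patterns.
RewritePatternSet patterns(&getContext());
populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
  signalPassFailure();
```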
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bdec699eb4ce4..30f33ed2fd1d6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -18,6 +18,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 671cc05e963b4..238a767ac8b73 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
- SameOperandsAndResultType
+ SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary =
"assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
def MemRef_CastOp : MemRef_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
MemRefsNormalizable,
Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
Pure,
ViewLikeOpInterface]> {
let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+ ]> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
Pure]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
def MemRef_ViewOp : MemRef_Op<"view", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Pure]> {
let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
//===----------------------------------------------------------------------===//
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 63410b8bea747..bbf55f5d507e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..93e9bfc78ea75 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
}
def Vector_MaskedLoadOp :
- Vector_Op<"maskedload">,
+ Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
}
def Vector_MaskedStoreOp :
- Vector_Op<"maskedstore">,
+ Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
}
def Vector_ScatterOp :
- Vector_Op<"scatter">,
+ Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
}
def Vector_ExpandLoadOp :
- Vector_Op<"expandload">,
+ Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
}
def Vector_CompressStoreOp :
- Vector_Op<"compressstore">,
+ Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 2add220fdfb7c..a5feb592045c0 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cc9f4c6b3882e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,37 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Attempts to fuse, in place, a `MemorySpaceCastOpInterface` operation
+/// referenced by `operand`. On success, it returns `results` and sets
+/// `modifiedInPlace` to true. It returns failure if `operand` is not produced
+/// by a `MemorySpaceCastOpInterface` op.
+FailureOr<SmallVector<Value>>
+fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
+ bool &modifiedInPlace);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..0b8ba19171fb7
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,114 @@
+//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def FuseMemorySpaceCastConsumerOpInterface :
+ OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+ let description = [{
+ An interface to fuse memory-space cast operands into a consumer operation.
+ It is the responsibility of the interface to determine which casts can be
+ fused into the operation.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<[{
+ Attempts to fuse the incoming cast-like operands. On success, it returns
+ any new results; otherwise, it returns failure.
+ If new results are produced, these must be compatible with the original
+ operation results.
+
+ The `modifiedInPlace` parameter indicates whether the operation was
+ modified in place. If `false` and the fusion succeeded, then the
+ interface guarantees it is valid to erase the original operation.
+ If `true`, then the interface must guarantee no operations were created
+ by the method, and that no further IR modification is necessary. It is
+ considered an error if `modifiedInPlace` is true and the fusion failed.
+
+ Implementations of this method must not erase or replace the original
+ operation; instead, it is the caller's responsibility to erase or replace
+ the op with the results provided by the method.
+
+ Finally, implementations of this method must guarantee that the IR
+ remains valid at all times.
+ }],
+ "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
+ (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+ >,
+ ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+ let description = [{
+ An interface for operations that perform memory-space casts. This
+ interface assumes that the cast operation is `pure`.
+
+ These operations are expected to have a well-defined ptr-like source
+ operand and a well-defined ptr-like target result.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<[{
+ Returns the source ptr-like value.
+ }],
+ "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getSourcePtr"
+ >,
+ InterfaceMethod<[{
+ Returns the target ptr-like value.
+ }],
+ "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+ >,
+ InterfaceMethod<[{
+ Returns whether the memory space cast specified by `tgt` and `src`
+ is supported.
+ }],
+ "bool", "isValidMemorySpaceCast",
+ (ins "::mlir::PtrLikeTypeInterface":$tgt,
+ "::mlir::PtrLikeTypeInterface":$src)
+ >,
+ InterfaceMethod<[{
+ Clones the memory space cast op with the given source and target type.
+ }],
+ "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+ (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+ "::mlir::Value":$src)
+ >,
+ InterfaceMethod<[{
+ Returns whether the cast can be fused.
+ }],
+ "bool", "isFusableMemorySpaceCast"
+ >
+ ];
+ let verify = [{
+ return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+ }];
+ let dependentTraits = [Pure];
+ let extraClassDeclaration = [{
+ /// Returns the underlying `MemorySpaceCastOpInterface` op if `value` is
+ /// produced by a `MemorySpaceCastOpInterface` op and
+ /// `isFusableMemorySpaceCast` returns true; otherwise, returns null.
+ static ::mlir::MemorySpaceCastOpInterface
+ getIfFusableCast(::mlir::Value value) {
+ auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+ value.getDefiningOp());
+ if (!op || !op.isFusableMemorySpaceCast())
+ return nullptr;
+ return op;
+ }
+ }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
new file mode 100644
index 0000000000000..9333f92a10289
--- /dev/null
+++ b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
@@ -0,0 +1,20 @@
+//===- FuseMemorySpaceCastsIntoConsumers.h - Cast fusion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+
+namespace mlir {
+class RewritePatternSet;
+/// Collect a set of patterns to fuse memory-space cast operations into
+/// consumers.
+void populateFuseMemorySpaceCastIntoConsumersPatterns(
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..610a9671fede8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index beb59784947c5..69280e3d443ea 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,4 +585,44 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
+def FuseMemorySpaceCastsIntoConsumers :
+ Pass<"fuse-memory-space-casts-into-consumers"> {
+ let summary = "Fuses memory-space cast operations into consumers.";
+ let description = [{
+ This pass tries to fuse all eligible memory-space cast operations into
+ their consumers. It does this by looking for
+ `FuseMemorySpaceCastConsumerOpInterface` operations and invoking the
+ interface methods to perform the fusion.
+
+ Example:
+
+ ```mlir
+ func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+ %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+ %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+ %added = arith.addf %loaded, %arg2 : f32
+ memref.store %added, %collapsed[%c0] : memref<16xf32>
+ %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+ return %collapsed : memref<16xf32>
+ }
+ // mlir-opt --fuse-memory-space-casts-into-consumers
+ func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+ %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+ %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
+ %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
+ %1 = arith.addf %0, %arg2 : f32
+ memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
+ %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
+ return %memspacecast : memref<16xf32>
+ }
+ ```
+ }];
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 734294bd014c6..e25a0121a3359 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
+ MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d15d5f6e3de4..0ddb2b0ca1645 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,6 +111,65 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
}
}
+/// Helper function to retrieve a fusable memory-space cast and the
+/// corresponding new target and result types.
+static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
+getFuseCastInfo(BaseMemRefType resultTy, Value src) {
+ MemorySpaceCastOpInterface castOp =
+ MemorySpaceCastOpInterface::getIfFusableCast(src);
+
+ // Bail if the cast is not fusable.
+ if (!castOp)
+ return {};
+
+ // Transform the source and target types of `castOp` to have the same
+ // metadata as `resultTy`. Bail if this is not possible.
+ FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
+ castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
+ if (failed(srcTy))
+ return {};
+
+ FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
+ castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
+ if (failed(tgtTy))
+ return {};
+
+ // Check if this is a valid memory-space cast.
+ if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
+ return {};
+
+ return std::make_tuple(castOp, *tgtTy, *srcTy);
+}
+
+/// Implementation of the `fuseCastOperands` method for memref operations
+/// that return a single memref result.
+template <typename ConcreteOpTy>
+static FailureOr<SmallVector<Value>>
+fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+ bool &modifiedInPlace, OpOperand &src) {
+ auto [castOp, tgtTy, resTy] = getFuseCastInfo(op.getType(), src.get());
+ // Bail if we cannot cast.
+ if (!castOp)
+ return failure();
+
+ modifiedInPlace = false;
+
+ // Create the new operands.
+ SmallVector<Value> operands;
+ llvm::append_range(operands, op->getOperands());
+ operands[src.getOperandNumber()] = castOp.getSourcePtr();
+
+ // Create the fused op and results.
+ auto newOp = ConcreteOpTy::create(
+ builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
+ llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
+
+ // Insert a memory-space cast to the original memory space of the op.
+ MemorySpaceCastOpInterface result =
+ castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
+ return SmallVector<Value>({result.getTargetPtr()});
+}
+
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@@ -542,6 +601,12 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
return getMemref();
}
+FailureOr<SmallVector<Value>>
+AssumeAlignmentOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getMemrefMutable());
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
@@ -710,6 +775,12 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
+FailureOr<SmallVector<Value>> CastOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
@@ -1601,6 +1672,12 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//
@@ -1645,6 +1722,33 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
return Value{};
}
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
+ return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+}
+
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
+ return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+}
+
+bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+ PtrLikeTypeInterface src) {
+ return isa<MemRefType>(tgt) &&
+ tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
+}
+
+MemorySpaceCastOpInterface
+MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b, Type tgt, Value src) {
+ assert(isValidMemorySpaceCast(cast<PtrLikeTypeInterface>(tgt),
+ cast<PtrLikeTypeInterface>(src.getType())) &&
+ "invalid arguments");
+ return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
+}
+
+bool MemorySpaceCastOp::isFusableMemorySpaceCast() {
+ // Only allow fusion when this is discarding information.
+ return getDest().getType().getMemorySpace() == nullptr;
+}
+
//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//
@@ -2041,6 +2145,12 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
+FailureOr<SmallVector<Value>>
+ReinterpretCastOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
@@ -2348,6 +2458,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
+FailureOr<SmallVector<Value>>
+ExpandShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSrcMutable());
+}
+
/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
@@ -2569,6 +2685,12 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands());
}
+FailureOr<SmallVector<Value>>
+CollapseShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSrcMutable());
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@@ -2609,6 +2731,12 @@ LogicalResult ReshapeOp::verify() {
return success();
}
+FailureOr<SmallVector<Value>>
+ReshapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -2626,6 +2754,12 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
return foldMemRefCast(*this, getValueToStore());
}
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getMemrefMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//
@@ -3282,6 +3416,12 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
return {};
}
+FailureOr<SmallVector<Value>>
+SubViewOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
@@ -3382,6 +3522,12 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
return {};
}
+FailureOr<SmallVector<Value>>
+TransposeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getInMutable());
+}
+
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
@@ -3525,6 +3671,12 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
+FailureOr<SmallVector<Value>> ViewOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
@@ -3570,6 +3722,12 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
+FailureOr<SmallVector<Value>>
+AtomicRMWOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d6e263934fb4..806e6c1c070aa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,6 +5087,14 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<TransferReadAfterWriteToBroadcast>(context);
}
+FailureOr<SmallVector<Value>>
+TransferReadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ if (!hasPureBufferSemantics())
+ return failure();
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
@@ -5574,6 +5582,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
+FailureOr<SmallVector<Value>>
+TransferWriteOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ if (!hasPureBufferSemantics())
+ return failure();
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
@@ -5628,6 +5644,12 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -5667,6 +5689,12 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//
@@ -5721,6 +5749,12 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
+FailureOr<SmallVector<Value>>
+MaskedLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//
@@ -5771,6 +5805,12 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
return memref::foldMemRefCast(*this);
}
+FailureOr<SmallVector<Value>>
+MaskedStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
@@ -5874,6 +5914,12 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<GatherFolder, FoldContiguousGather>(context);
}
+FailureOr<SmallVector<Value>>
+GatherOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
@@ -5936,6 +5982,12 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ScatterFolder, FoldContiguousScatter>(context);
}
+FailureOr<SmallVector<Value>>
+ScatterOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//
@@ -5984,6 +6036,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExpandLoadFolder>(context);
}
+FailureOr<SmallVector<Value>>
+ExpandLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//
@@ -6030,6 +6088,12 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<CompressStoreFolder>(context);
}
+FailureOr<SmallVector<Value>>
+CompressStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index fdc19844702bc..388de1c3e5abf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -11,6 +11,7 @@ set(LLVM_OPTIONAL_SOURCES
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
+ MemOpInterfaces.cpp
MemorySlotInterfaces.cpp
ParallelCombiningOpInterface.cpp
RuntimeVerifiableOpInterface.cpp
@@ -79,6 +80,7 @@ add_mlir_library(MLIRLoopLikeInterface
MLIRFunctionInterfaces
)
+add_mlir_interface_library(MemOpInterfaces)
add_mlir_interface_library(MemorySlotInterfaces)
add_mlir_interface_library(ParallelCombiningOpInterface)
add_mlir_interface_library(RuntimeVerifiableOpInterface)
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
new file mode 100644
index 0000000000000..013d828da1d66
--- /dev/null
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -0,0 +1,73 @@
+//===- MemOpInterfaces.cpp - Memory operation interfaces --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
+ auto memCastOp = cast<MemorySpaceCastOpInterface>(op);
+
+ // Verify that the source and target pointers are valid.
+ Value sourcePtr = memCastOp.getSourcePtr();
+ Value targetPtr = memCastOp.getTargetPtr();
+
+ if (!sourcePtr || !targetPtr) {
+ return op->emitError()
+ << "memory space cast op must have valid source and target pointers";
+ }
+
+ if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) {
+ return op->emitError()
+ << "expected source and target types of the same kind";
+ }
+
+ // Verify that the types implement `PtrLikeTypeInterface`.
+ auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType());
+ if (!sourceType) {
+ return op->emitError()
+ << "source type must implement `PtrLikeTypeInterface`, but got: "
+ << sourcePtr.getType();
+ }
+
+ auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType());
+ if (!targetType) {
+ return op->emitError()
+ << "target type must implement `PtrLikeTypeInterface`, but got: "
+ << targetPtr.getType();
+ }
+
+ // Verify that the operation has exactly one result.
+ if (op->getNumResults() != 1) {
+ return op->emitError()
+ << "memory space cast op must have exactly one result";
+ }
+
+ return success();
+}
+
+FailureOr<SmallVector<Value>> mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ OpOperand &operand, ValueRange results, bool &modifiedInPlace) {
+ MemorySpaceCastOpInterface castOp =
+ MemorySpaceCastOpInterface::getIfFusableCast(operand.get());
+
+ // Bail if `operand` is not produced by a fusable memory-space cast op.
+ if (!castOp)
+ return failure();
+
+ // Modify the op.
+ modifiedInPlace = true;
+ operand.set(castOp.getSourcePtr());
+ return llvm::to_vector_of<Value>(results);
+}
+
+#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e..e9a7d3e4abe99 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransforms
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
+ FuseMemorySpaceCastsIntoConsumers.cpp
InlinerPass.cpp
LocationSnapshot.cpp
LoopInvariantCodeMotion.cpp
@@ -31,6 +32,7 @@ add_mlir_library(MLIRTransforms
MLIRAnalysis
MLIRFunctionInterfaces
MLIRLoopLikeInterface
+ MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
MLIRPass
MLIRRuntimeVerifiableOpInterface
diff --git a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
new file mode 100644
index 0000000000000..010b88ac12de2
--- /dev/null
+++ b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
@@ -0,0 +1,73 @@
+//===- FuseMemorySpaceCastsIntoConsumers.cpp - Fuse casts transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FuseCastsPattern
+//===----------------------------------------------------------------------===//
+/// Pattern to fuse casts into consumer operations.
+struct FuseCastsPattern
+ : public OpInterfaceRewritePattern<FuseMemorySpaceCastConsumerOpInterface> {
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(FuseMemorySpaceCastConsumerOpInterface op,
+ PatternRewriter &rewriter) const override {
+ bool modifiedInPlace = false;
+ FailureOr<SmallVector<Value>> results =
+ op.fuseCastOperands(rewriter, modifiedInPlace);
+ assert((!failed(results) || !modifiedInPlace) &&
+ "expected `modifiedInPlace` to be false on fusion failure");
+ if (failed(results))
+ return failure();
+ if (modifiedInPlace) {
+ rewriter.modifyOpInPlace(op, []() {});
+ return success();
+ }
+ rewriter.replaceOp(op, *results);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// FuseMemorySpaceCastsIntoConsumers pass
+//===----------------------------------------------------------------------===//
+
+struct FuseMemorySpaceCastsIntoConsumers
+ : public impl::FuseMemorySpaceCastsIntoConsumersBase<
+ FuseMemorySpaceCastsIntoConsumers> {
+ using impl::FuseMemorySpaceCastsIntoConsumersBase<
+ FuseMemorySpaceCastsIntoConsumers>::FuseMemorySpaceCastsIntoConsumersBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::populateFuseMemorySpaceCastIntoConsumersPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FuseCastsPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
new file mode 100644
index 0000000000000..69a15f429cec2
--- /dev/null
+++ b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
@@ -0,0 +1,281 @@
+// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// CHECK-LABEL: func.func @load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK: return
+// CHECK: }
+func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
+ memref.store %0, %memspacecast[%arg1] : memref<?xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @load_store_unfoldable(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32, 1> to memref<?xf32, 2>
+// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK: memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK: return
+// CHECK: }
+func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32, 2>
+ %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 2>
+ memref.store %0, %memspacecast[%arg1] : memref<?xf32, 2>
+ return
+}
+
+// CHECK-LABEL: func.func @view(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xi8, 1>
+// CHECK: %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xi8, 1> to memref<?x?xi8>
+// CHECK: return %[[VAL_2]] : memref<?x?xi8>
+// CHECK: }
+func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xi8> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
+ %c100 = arith.constant 100 : index
+ %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xi8>
+ return %view : memref<?x?xi8>
+}
+
+// CHECK-LABEL: func.func @subview(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+// CHECK: %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK: return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK: }
+func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+ return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
+}
+
+// CHECK-LABEL: func.func @reinterpret_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
+// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: }
+func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+ return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
+}
+
+// CHECK-LABEL: func.func @reshape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK: return %[[VAL_1]] : memref<?xf32>
+// CHECK: }
+func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
+ return %reshape : memref<?xf32>
+}
+
+// CHECK-LABEL: func.func @expand_shape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
+// CHECK: return %[[VAL_1]] : memref<3x4xf32>
+// CHECK: }
+func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
+ %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
+ return %expand_shape : memref<3x4xf32>
+}
+
+// CHECK-LABEL: func.func @collapse_shape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
+// CHECK: return %[[VAL_1]] : memref<12xf32>
+// CHECK: }
+func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
+ %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
+ return %collapse_shape : memref<12xf32>
+}
+
+// CHECK-LABEL: func.func @transpose(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
+// CHECK: %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK: return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK: }
+func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
+ return %transpose : memref<?x?xf32, #map>
+}
+
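+// NOTE: The cast is erased entirely here: atomic_rmw is rewritten to use the
+// original memref and no memref value escapes the function.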
+// CHECK-LABEL: func.func @atomic_rmw(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
+// CHECK: return %[[VAL_0]] : f32
+// CHECK: }
+func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @assume_alignment(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK: return %[[VAL_1]] : memref<?xf32>
+// CHECK: }
+func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
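+// NOTE: With a chain of consumers, each consumer is rewritten to memory
+// space 1 and only the escaping result keeps a residual cast.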
+// CHECK-LABEL: func.func @op_with_cast_sequence(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: f32) -> memref<16xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+// CHECK: %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
+// CHECK: %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32
+// CHECK: memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK: %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32
+// CHECK: return %[[VAL_4]] : memref<16xf32>
+// CHECK: }
+func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+ %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+ %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+ %added = arith.addf %loaded, %arg2 : f32
+ memref.store %added, %collapsed[%c0] : memref<16xf32>
+ %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+ return %collapsed : memref<16xf32>
+}
+
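+// NOTE: The remaining tests exercise cast fusion through vector memory
+// operations.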
+// CHECK-LABEL: func.func @transfer_read_write(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK: vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref<?xf32, 1>
+// CHECK: return
+// CHECK: }
+func.func @transfer_read_write(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref<?xf32>, vector<4xf32>
+ vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref<?xf32>
+ return
+}
+
+// NOTE: The transfer ops fold away (a transfer_write of a matching
+// transfer_read yields the original tensor), so nothing is left to rewrite.
+// CHECK-LABEL: func.func @transfer_read_write_tensor(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> tensor<?xf32> {
+// CHECK: return %[[ARG0]] : tensor<?xf32>
+// CHECK: }
+func.func @transfer_read_write_tensor(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor<?xf32>, vector<4xf32>
+ %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor<?xf32>
+ return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vector_load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK: vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = vector.load %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+ vector.store %0, %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @masked_load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+ %passthrough = arith.constant dense<0.0> : vector<4xf32>
+ %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @gather_scatter(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+ %mask = arith.constant dense<true> : vector<4xi1>
+ %passthrough = arith.constant dense<0.0> : vector<4xf32>
+ %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @expandload_compressstore(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @expandload_compressstore(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+ %passthrough = arith.constant dense<0.0> : vector<4xf32>
+ %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ vector.compressstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+ return
+}
From 1e49e3dff6ca2f4c3b551d164310e2318a0a46cb Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 18 Sep 2025 07:04:24 -0400
Subject: [PATCH 2/2] Update mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0ddb2b0ca1645..2201b237cfdda 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -118,7 +118,7 @@ getFuseCastInfo(BaseMemRefType resultTy, Value src) {
MemorySpaceCastOpInterface castOp =
MemorySpaceCastOpInterface::getIfFusableCast(src);
- // Bail if the cast is not fusable.
+ // Bail if the cast is not fusible.
if (!castOp)
return {};