[Mlir-commits] [mlir] [mlir] Implement a memory-space cast bubbling-down transform (PR #159454)
Fabian Mora
llvmlistbot at llvm.org
Tue Sep 23 10:49:00 PDT 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/159454
>From a15b8ca9bd287d4ad6af320cae7ae01afb4234bc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Wed, 17 Sep 2025 21:07:38 +0000
Subject: [PATCH 01/10] [mlir] Implement memory-space cast operand fusion into
consumers
This commit adds functionality to fuse memory-space casts into consumer operations,
allowing operations to be performed directly on the original memory-space rather
than first casting to a different memory space.
Key changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations
- Create a `FuseMemorySpaceCastsIntoConsumers` pass that identifies and fuses eligible casts
- Add implementation for memref and vector operations to handle memory-space cast fusion
- Add fuseCastOperands method to relevant operations to support the fusion
In particular, in the current implementation only memory-space casts into the default
memory-space can be fused.
Example:
```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
%collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
%loaded = memref.load %collapsed[%c0] : memref<16xf32>
%added = arith.addf %loaded, %arg2 : f32
memref.store %added, %collapsed[%c0] : memref<16xf32>
%atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
return %collapsed : memref<16xf32>
}
// mlir-opt --fuse-memory-space-casts-into-consumers
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
%collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
%memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
%0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
%1 = arith.addf %0, %arg2 : f32
memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
%2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
return %memspacecast : memref<16xf32>
}
```
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h | 1 +
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 19 +-
.../mlir/Dialect/Vector/IR/VectorOps.h | 1 +
.../mlir/Dialect/Vector/IR/VectorOps.td | 16 +-
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 +
.../include/mlir/Interfaces/MemOpInterfaces.h | 37 +++
.../mlir/Interfaces/MemOpInterfaces.td | 114 +++++++
.../FuseMemorySpaceCastsIntoConsumers.h | 20 ++
mlir/include/mlir/Transforms/Passes.h | 1 +
mlir/include/mlir/Transforms/Passes.td | 40 +++
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 158 ++++++++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 64 ++++
mlir/lib/Interfaces/CMakeLists.txt | 2 +
mlir/lib/Interfaces/MemOpInterfaces.cpp | 73 +++++
mlir/lib/Transforms/CMakeLists.txt | 2 +
.../FuseMemorySpaceCastsIntoConsumers.cpp | 73 +++++
.../test-fuse-casts-into-consumers.mlir | 281 ++++++++++++++++++
18 files changed, 897 insertions(+), 7 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.h
create mode 100644 mlir/include/mlir/Interfaces/MemOpInterfaces.td
create mode 100644 mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
create mode 100644 mlir/lib/Interfaces/MemOpInterfaces.cpp
create mode 100644 mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
create mode 100644 mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bdec699eb4ce4..30f33ed2fd1d6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -18,6 +18,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 671cc05e963b4..238a767ac8b73 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
- SameOperandsAndResultType
+ SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary =
"assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
def MemRef_CastOp : MemRef_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
MemRefsNormalizable,
Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
Pure,
ViewLikeOpInterface]> {
let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+ ]> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
Pure]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
def MemRef_ViewOp : MemRef_Op<"view", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Pure]> {
let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
//===----------------------------------------------------------------------===//
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 63410b8bea747..bbf55f5d507e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..93e9bfc78ea75 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
}
def Vector_MaskedLoadOp :
- Vector_Op<"maskedload">,
+ Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
}
def Vector_MaskedStoreOp :
- Vector_Op<"maskedstore">,
+ Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
}
def Vector_ScatterOp :
- Vector_Op<"scatter">,
+ Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
}
def Vector_ExpandLoadOp :
- Vector_Op<"expandload">,
+ Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
}
def Vector_CompressStoreOp :
- Vector_Op<"compressstore">,
+ Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 2add220fdfb7c..a5feb592045c0 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cc9f4c6b3882e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,37 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Tries to fuse, in place, a `MemorySpaceCastOpInterface` operation referenced
+/// by `operand`. On success, it returns `results` and sets `modifiedInPlace` to
+/// true. It fails if `operand` doesn't reference a
+/// `MemorySpaceCastOpInterface` op.
+FailureOr<SmallVector<Value>>
+fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
+ bool &modifiedInPlace);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..0b8ba19171fb7
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,114 @@
+//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def FuseMemorySpaceCastConsumerOpInterface :
+ OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+ let description = [{
+ An interface to fuse memory-space cast operands into a consumer operation.
+ It is the responsibility of the interface to determine which casts can be
+ fused into the operation.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<[{
+ Attempt to fuse the incoming cast-like operands. Returns `success`
+ and any new results on fusion success, otherwise it returns failure.
+ If new results are produced, these must be compatible with the original
+ operation results.
+
+ The `modifiedInPlace` parameter indicates whether the operation was
+ modified in place. If `false` and the fusion succeeded, then the
+ interface guarantees it is valid to erase the original operation.
+ If `true`, then the interface must guarantee no operations were created
+ by the method, and that no further IR modification is necessary. It is
+ considered an error if `modifiedInPlace` is true and the fusion failed.
+
+ Any implementations of this method must not erase/replace the original
+ operation, instead it is the caller responsibility to erase or replace
+ the op with the results provided by the method.
+
+ Finally, any implementations of this method have to guarantee that the
+ IR remains valid at all times.
+ }],
+ "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
+ (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+ >,
+ ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+ let description = [{
+ An interface for operations that perform memory-space casts. This
+ interface assumes that the cast operation is `pure`.
+
+ These operations expect to have a well-defined ptr-like operand, and
+ a well-defined target ptr-like result.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<[{
+ Returns the source ptr-like value.
+ }],
+ "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getSourcePtr"
+ >,
+ InterfaceMethod<[{
+ Returns the target ptr-like value.
+ }],
+ "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+ >,
+ InterfaceMethod<[{
+ Returns whether the memory space cast specified by `tgt` and `src`
+ is supported.
+ }],
+ "bool", "isValidMemorySpaceCast",
+ (ins "::mlir::PtrLikeTypeInterface":$tgt,
+ "::mlir::PtrLikeTypeInterface":$src)
+ >,
+ InterfaceMethod<[{
+ Clones the memory space cast op with the given source and target type.
+ }],
+ "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+ (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+ "::mlir::Value":$src)
+ >,
+ InterfaceMethod<[{
+ Returns whether the cast can be fused.
+ }],
+ "bool", "isFusableMemorySpaceCast"
+ >
+ ];
+ let verify = [{
+ return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+ }];
+ let dependentTraits = [Pure];
+ let extraClassDeclaration = [{
+ /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
+ /// is produced by a `MemorySpaceCastOpInterface` op, and
+ /// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
+ static ::mlir::MemorySpaceCastOpInterface
+ getIfFusableCast(::mlir::Value value) {
+ auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+ value.getDefiningOp());
+ if (!op || !op.isFusableMemorySpaceCast())
+ return nullptr;
+ return op;
+ }
+ }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
new file mode 100644
index 0000000000000..9333f92a10289
--- /dev/null
+++ b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
@@ -0,0 +1,20 @@
+//===- FuseMemorySpaceCastsIntoConsumers.h - Cast fusion patterns -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+
+namespace mlir {
+class RewritePatternSet;
+/// Collect a set of patterns to fuse memory-space cast operations into
+/// consumers.
+void populateFuseMemorySpaceCastIntoConsumersPatterns(
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..610a9671fede8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index beb59784947c5..69280e3d443ea 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,4 +585,44 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
+def FuseMemorySpaceCastsIntoConsumers :
+ Pass<"fuse-memory-space-casts-into-consumers"> {
+ let summary = "Fuses memory-space cast operations into consumers.";
+ let description = [{
+ This pass tries to fuse all possible memory-space cast operations into their consumers.
+ It does this by looking for `FuseMemorySpaceCastConsumerOpInterface`
+ operations, and invoking the interface methods to perform the fusion.
+
+ Example:
+
+ ```mlir
+ func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+ %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+ %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+ %added = arith.addf %loaded, %arg2 : f32
+ memref.store %added, %collapsed[%c0] : memref<16xf32>
+ %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+ return %collapsed : memref<16xf32>
+ }
+ // mlir-opt --fuse-memory-space-casts-into-consumers
+ func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+ %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+ %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
+ %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
+ %1 = arith.addf %0, %arg2 : f32
+ memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
+ %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
+ return %memspacecast : memref<16xf32>
+ }
+ ```
+ }];
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 734294bd014c6..e25a0121a3359 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
+ MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d15d5f6e3de4..0ddb2b0ca1645 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,6 +111,65 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
}
}
+/// Helper function to retrieve a fusable memory-space cast, together with the
+/// new target type and the new source (result) type derived from `resultTy`.
+static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
+getFuseCastInfo(BaseMemRefType resultTy, Value src) {
+ MemorySpaceCastOpInterface castOp =
+ MemorySpaceCastOpInterface::getIfFusableCast(src);
+
+ // Bail if the cast is not fusable.
+ if (!castOp)
+ return {};
+
+ // Transform the source and target type of `castOp` to have the same metadata
+ // as `resultTy`. Bail if not possible.
+ FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
+ castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
+ if (failed(srcTy))
+ return {};
+
+ FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
+ castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
+ if (failed(tgtTy))
+ return {};
+
+ // Check if this is a valid memory-space cast.
+ if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
+ return {};
+
+ return std::make_tuple(castOp, *tgtTy, *srcTy);
+}
+
+/// Implementation of `fuseCastOperands` method for memref operations that
+/// return a single memref result.
+template <typename ConcreteOpTy>
+static FailureOr<SmallVector<Value>>
+fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+ bool &modifiedInPlace, OpOperand &src) {
+ auto [castOp, tgtTy, resTy] = getFuseCastInfo(op.getType(), src.get());
+ // Bail if we cannot cast.
+ if (!castOp)
+ return failure();
+
+ modifiedInPlace = false;
+
+ // Create the new operands.
+ SmallVector<Value> operands;
+ llvm::append_range(operands, op->getOperands());
+ operands[src.getOperandNumber()] = castOp.getSourcePtr();
+
+ // Create the fused op and results.
+ auto newOp = ConcreteOpTy::create(
+ builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
+ llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
+
+ // Insert a memory-space cast to the original memory space of the op.
+ MemorySpaceCastOpInterface result =
+ castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
+ return SmallVector<Value>({result.getTargetPtr()});
+}
+
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@@ -542,6 +601,12 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
return getMemref();
}
+FailureOr<SmallVector<Value>>
+AssumeAlignmentOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getMemrefMutable());
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
@@ -710,6 +775,12 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
+FailureOr<SmallVector<Value>> CastOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
@@ -1601,6 +1672,12 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//
@@ -1645,6 +1722,33 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
return Value{};
}
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
+ return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+}
+
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
+ return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+}
+
+bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+ PtrLikeTypeInterface src) {
+ return isa<MemRefType>(tgt) &&
+ tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
+}
+
+MemorySpaceCastOpInterface
+MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b, Type tgt, Value src) {
+ assert(isValidMemorySpaceCast(cast<PtrLikeTypeInterface>(tgt),
+ cast<PtrLikeTypeInterface>(src.getType())) &&
+ "invalid arguments");
+ return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
+}
+
+bool MemorySpaceCastOp::isFusableMemorySpaceCast() {
+ // Only allow fusion when this is discarding information.
+ return getDest().getType().getMemorySpace() == nullptr;
+}
+
//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//
@@ -2041,6 +2145,12 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
+FailureOr<SmallVector<Value>>
+ReinterpretCastOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
@@ -2348,6 +2458,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
+FailureOr<SmallVector<Value>>
+ExpandShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSrcMutable());
+}
+
/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
@@ -2569,6 +2685,12 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands());
}
+FailureOr<SmallVector<Value>>
+CollapseShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSrcMutable());
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@@ -2609,6 +2731,12 @@ LogicalResult ReshapeOp::verify() {
return success();
}
+FailureOr<SmallVector<Value>>
+ReshapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -2626,6 +2754,12 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
return foldMemRefCast(*this, getValueToStore());
}
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getMemrefMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//
@@ -3282,6 +3416,12 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
return {};
}
+FailureOr<SmallVector<Value>>
+SubViewOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
@@ -3382,6 +3522,12 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
return {};
}
+FailureOr<SmallVector<Value>>
+TransposeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getInMutable());
+}
+
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
@@ -3525,6 +3671,12 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
+FailureOr<SmallVector<Value>> ViewOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
+ getSourceMutable());
+}
+
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
@@ -3570,6 +3722,12 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
+FailureOr<SmallVector<Value>>
+AtomicRMWOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getMemrefMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d6e263934fb4..806e6c1c070aa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,6 +5087,14 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<TransferReadAfterWriteToBroadcast>(context);
}
+FailureOr<SmallVector<Value>>
+TransferReadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ if (!hasPureBufferSemantics())
+ return failure();
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
@@ -5574,6 +5582,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
+FailureOr<SmallVector<Value>>
+TransferWriteOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ if (!hasPureBufferSemantics())
+ return failure();
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
@@ -5628,6 +5644,12 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -5667,6 +5689,12 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
+ bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//
@@ -5721,6 +5749,12 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
+FailureOr<SmallVector<Value>>
+MaskedLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//
@@ -5771,6 +5805,12 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
return memref::foldMemRefCast(*this);
}
+FailureOr<SmallVector<Value>>
+MaskedStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
@@ -5874,6 +5914,12 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<GatherFolder, FoldContiguousGather>(context);
}
+FailureOr<SmallVector<Value>>
+GatherOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
@@ -5936,6 +5982,12 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ScatterFolder, FoldContiguousScatter>(context);
}
+FailureOr<SmallVector<Value>>
+ScatterOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//
@@ -5984,6 +6036,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExpandLoadFolder>(context);
}
+FailureOr<SmallVector<Value>>
+ExpandLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), getResult(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//
@@ -6030,6 +6088,12 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<CompressStoreFolder>(context);
}
+FailureOr<SmallVector<Value>>
+CompressStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+ return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ getBaseMutable(), ValueRange(), modifiedInPlace);
+}
+
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index fdc19844702bc..388de1c3e5abf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -11,6 +11,7 @@ set(LLVM_OPTIONAL_SOURCES
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
+ MemOpInterfaces.cpp
MemorySlotInterfaces.cpp
ParallelCombiningOpInterface.cpp
RuntimeVerifiableOpInterface.cpp
@@ -79,6 +80,7 @@ add_mlir_library(MLIRLoopLikeInterface
MLIRFunctionInterfaces
)
+add_mlir_interface_library(MemOpInterfaces)
add_mlir_interface_library(MemorySlotInterfaces)
add_mlir_interface_library(ParallelCombiningOpInterface)
add_mlir_interface_library(RuntimeVerifiableOpInterface)
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
new file mode 100644
index 0000000000000..013d828da1d66
--- /dev/null
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -0,0 +1,73 @@
+//===- MemOpInterfaces.cpp - Memory operation interfaces --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
+ auto memCastOp = cast<MemorySpaceCastOpInterface>(op);
+
+ // Verify that the source and target pointers are valid.
+ Value sourcePtr = memCastOp.getSourcePtr();
+ Value targetPtr = memCastOp.getTargetPtr();
+
+ if (!sourcePtr || !targetPtr) {
+ return op->emitError()
+ << "memory space cast op must have valid source and target pointers";
+ }
+
+ if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) {
+ return op->emitError()
+ << "expected source and target types of the same kind";
+ }
+
+ // Verify that both types implement `PtrLikeTypeInterface`.
+ auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType());
+ if (!sourceType) {
+ return op->emitError()
+ << "source type must implement `PtrLikeTypeInterface`, but got: "
+ << sourcePtr.getType();
+ }
+
+ auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType());
+ if (!targetType) {
+ return op->emitError()
+ << "target type must implement `PtrLikeTypeInterface`, but got: "
+ << targetPtr.getType();
+ }
+
+ // Verify that the operation has exactly one result.
+ if (op->getNumResults() != 1) {
+ return op->emitError()
+ << "memory space cast op must have exactly one result";
+ }
+
+ return success();
+}
+
+FailureOr<SmallVector<Value>> mlir::detail::fuseInPlaceMemorySpaceCastImpl(
+ OpOperand &operand, ValueRange results, bool &modifiedInPlace) {
+ MemorySpaceCastOpInterface castOp =
+ MemorySpaceCastOpInterface::getIfFusableCast(operand.get());
+
+ // Bail if the source is not produced by a fusable `MemorySpaceCastOpInterface`.
+ if (!castOp)
+ return failure();
+
+ // Modify the op.
+ modifiedInPlace = true;
+ operand.set(castOp.getSourcePtr());
+ return llvm::to_vector_of<Value>(results);
+}
+
+#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e..e9a7d3e4abe99 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransforms
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
+ FuseMemorySpaceCastsIntoConsumers.cpp
InlinerPass.cpp
LocationSnapshot.cpp
LoopInvariantCodeMotion.cpp
@@ -31,6 +32,7 @@ add_mlir_library(MLIRTransforms
MLIRAnalysis
MLIRFunctionInterfaces
MLIRLoopLikeInterface
+ MLIRMemOpInterfaces
MLIRMemorySlotInterfaces
MLIRPass
MLIRRuntimeVerifiableOpInterface
diff --git a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
new file mode 100644
index 0000000000000..010b88ac12de2
--- /dev/null
+++ b/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
@@ -0,0 +1,73 @@
+//===- FuseMemorySpaceCastsIntoConsumers.cpp - Fuse casts transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FuseCastsPattern
+//===----------------------------------------------------------------------===//
+/// Pattern to fuse casts into consumer operations.
+struct FuseCastsPattern
+ : public OpInterfaceRewritePattern<FuseMemorySpaceCastConsumerOpInterface> {
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(FuseMemorySpaceCastConsumerOpInterface op,
+ PatternRewriter &rewriter) const override {
+ bool modifiedInPlace = false;
+ FailureOr<SmallVector<Value>> results =
+ op.fuseCastOperands(rewriter, modifiedInPlace);
+ assert((!failed(results) || !modifiedInPlace) &&
+ "expected `modifiedInPlace` to be false on fusion failure");
+ if (failed(results))
+ return failure();
+ if (modifiedInPlace) {
+ rewriter.modifyOpInPlace(op, []() {});
+ return success();
+ }
+ rewriter.replaceOp(op, *results);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// FuseMemorySpaceCastsIntoConsumers pass
+//===----------------------------------------------------------------------===//
+
+struct FuseMemorySpaceCastsIntoConsumers
+ : public impl::FuseMemorySpaceCastsIntoConsumersBase<
+ FuseMemorySpaceCastsIntoConsumers> {
+ using impl::FuseMemorySpaceCastsIntoConsumersBase<
+ FuseMemorySpaceCastsIntoConsumers>::FuseMemorySpaceCastsIntoConsumersBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::populateFuseMemorySpaceCastIntoConsumersPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FuseCastsPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
new file mode 100644
index 0000000000000..69a15f429cec2
--- /dev/null
+++ b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
@@ -0,0 +1,281 @@
+// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// CHECK-LABEL: func.func @load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK: return
+// CHECK: }
+func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
+ memref.store %0, %memspacecast[%arg1] : memref<?xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @load_store_unfoldable(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32, 1> to memref<?xf32, 2>
+// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK: memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK: return
+// CHECK: }
+func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32, 2>
+ %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 2>
+ memref.store %0, %memspacecast[%arg1] : memref<?xf32, 2>
+ return
+}
+
+// CHECK-LABEL: func.func @view(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xi8, 1>
+// CHECK: %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xi8, 1> to memref<?x?xi8>
+// CHECK: return %[[VAL_2]] : memref<?x?xi8>
+// CHECK: }
+func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xi8> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
+ %c100 = arith.constant 100 : index
+ %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xi8>
+ return %view : memref<?x?xi8>
+}
+
+// CHECK-LABEL: func.func @subview(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+// CHECK: %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK: return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK: }
+func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+ return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
+}
+
+// CHECK-LABEL: func.func @reinterpret_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
+// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: }
+func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+ return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
+}
+
+// CHECK-LABEL: func.func @reshape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK: return %[[VAL_1]] : memref<?xf32>
+// CHECK: }
+func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
+ return %reshape : memref<?xf32>
+}
+
+// CHECK-LABEL: func.func @expand_shape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
+// CHECK: return %[[VAL_1]] : memref<3x4xf32>
+// CHECK: }
+func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
+ %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
+ return %expand_shape : memref<3x4xf32>
+}
+
+// CHECK-LABEL: func.func @collapse_shape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
+// CHECK: return %[[VAL_1]] : memref<12xf32>
+// CHECK: }
+func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
+ %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
+ return %collapse_shape : memref<12xf32>
+}
+
+// CHECK-LABEL: func.func @transpose(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
+// CHECK: %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK: return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK: }
+func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
+ return %transpose : memref<?x?xf32, #map>
+}
+
+// CHECK-LABEL: func.func @atomic_rmw(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
+// CHECK: return %[[VAL_0]] : f32
+// CHECK: }
+func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @assume_alignment(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK: return %[[VAL_1]] : memref<?xf32>
+// CHECK: }
+func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
+// CHECK-LABEL: func.func @op_with_cast_sequence(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: f32) -> memref<16xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+// CHECK: %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
+// CHECK: %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32
+// CHECK: memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1>
+// CHECK: %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32
+// CHECK: return %[[VAL_4]] : memref<16xf32>
+// CHECK: }
+func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+ %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+ %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+ %added = arith.addf %loaded, %arg2 : f32
+ memref.store %added, %collapsed[%c0] : memref<16xf32>
+ %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+ return %collapsed : memref<16xf32>
+}
+
+// CHECK-LABEL: func.func @transfer_read_write(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK: vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref<?xf32, 1>
+// CHECK: return
+// CHECK: }
+func.func @transfer_read_write(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref<?xf32>, vector<4xf32>
+ vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref<?xf32>
+ return
+}
+
+// NOTE: The operations disappear because they can get folded.
+// CHECK-LABEL: func.func @transfer_read_write_tensor(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> tensor<?xf32> {
+// CHECK: return %[[ARG0]] : tensor<?xf32>
+// CHECK: }
+func.func @transfer_read_write_tensor(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor<?xf32>, vector<4xf32>
+ %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor<?xf32>
+ return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vector_load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK: vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = vector.load %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+ vector.store %0, %memspacecast[%arg1] : memref<?xf32>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @masked_load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+ %passthrough = arith.constant dense<0.0> : vector<4xf32>
+ %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @gather_scatter(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+ %mask = arith.constant dense<true> : vector<4xi1>
+ %passthrough = arith.constant dense<0.0> : vector<4xf32>
+ %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @expandload_compressstore(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
+// CHECK: return
+// CHECK: }
+func.func @expandload_compressstore(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+ %passthrough = arith.constant dense<0.0> : vector<4xf32>
+ %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ vector.compressstore %memspacecast[%arg1], %mask, %0 : memref<?xf32>, vector<4xi1>, vector<4xf32>
+ return
+}
>From 3709f67bd0fbe1d0b8b86121c20fb0f5c9a933d2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 19 Sep 2025 12:00:22 +0000
Subject: [PATCH 02/10] address comments 1/2
---
mlir/include/mlir/Transforms/Passes.td | 7 ++--
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
.../test-fuse-casts-into-consumers.mlir | 41 +++++++++++++------
3 files changed, 34 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 69280e3d443ea..3204e80919456 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -589,9 +589,10 @@ def FuseMemorySpaceCastsIntoConsumers :
Pass<"fuse-memory-space-casts-into-consumers"> {
let summary = "Fuses memory-space cast operations into consumers.";
let description = [{
- This pass tries to fuse all possible memory-space cast operations into their consumers.
- It does this by looking for `FuseMemorySpaceCastConsumerOpInterface`
- operations, and invoking the interface methods to perform the fusion.
+ This pass tries to iteratively fuse all possible memory-space cast
+ operations into their consumers. It does this by looking for
+ `FuseMemorySpaceCastConsumerOpInterface` operations, and invoking the
+ interface methods to perform the fusion.
Example:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0ddb2b0ca1645..11fd43ff54575 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1732,7 +1732,7 @@ TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
PtrLikeTypeInterface src) {
- return isa<MemRefType>(tgt) &&
+ return isa<BaseMemRefType>(tgt) &&
tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
}
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
index 69a15f429cec2..7534332b3663a 100644
--- a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
+++ b/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
@@ -32,6 +32,23 @@ func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
return
}
+// CHECK-LABEL: func.func @cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<2xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
+// CHECK: %[[VAL_0:.*]] = memref.cast %[[ARG0]] : memref<2xf32, 1> to memref<*xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<*xf32, 1> to memref<*xf32>
+// CHECK: %[[VAL_2:.*]] = memref.cast %[[ARG1]] : memref<*xf32, 1> to memref<3x2xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<3x2xf32, 1> to memref<3x2xf32>
+// CHECK: return %[[VAL_1]], %[[VAL_3]] : memref<*xf32>, memref<3x2xf32>
+// CHECK: }
+func.func @cast(%arg0: memref<2xf32, 1>, %arg1: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<2xf32, 1> to memref<2xf32>
+ %1 = memref.cast %memspacecast : memref<2xf32> to memref<*xf32>
+ %memspacecast_1 = memref.memory_space_cast %arg1 : memref<*xf32, 1> to memref<*xf32>
+ %2 = memref.cast %memspacecast_1 : memref<*xf32> to memref<3x2xf32>
+ return %1, %2 : memref<*xf32>, memref<3x2xf32>
+}
+
// CHECK-LABEL: func.func @view(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8, 1>,
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
@@ -63,8 +80,8 @@ func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, s
// CHECK-LABEL: func.func @reinterpret_cast(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
-// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
// CHECK: return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
@@ -155,8 +172,8 @@ func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf32, 1>,
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: f32) -> memref<16xf32> {
-// CHECK: %[[VAL_0:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
// CHECK: %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
@@ -225,8 +242,8 @@ func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
// CHECK-LABEL: func.func @masked_load_store(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME: %[[ARG1:.*]]: index) {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK: %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK: vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
// CHECK: return
@@ -243,10 +260,10 @@ func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
// CHECK-LABEL: func.func @gather_scatter(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME: %[[ARG1:.*]]: index) {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
-// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK: vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
// CHECK: return
@@ -265,8 +282,8 @@ func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
// CHECK-LABEL: func.func @expandload_compressstore(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
// CHECK-SAME: %[[ARG1:.*]]: index) {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK: %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK: vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
// CHECK: return
>From 4789d11a5916bdcb3b2cf060954a26b8d57fb190 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 19 Sep 2025 12:56:13 +0000
Subject: [PATCH 03/10] address comments 2/2
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 24 ++--
.../mlir/Dialect/Vector/IR/VectorOps.td | 20 +--
.../include/mlir/Interfaces/MemOpInterfaces.h | 11 +-
.../mlir/Interfaces/MemOpInterfaces.td | 45 ++++---
.../Transforms/BubbleDownMemorySpaceCasts.h | 20 +++
.../FuseMemorySpaceCastsIntoConsumers.h | 20 ---
mlir/include/mlir/Transforms/Passes.h | 2 +-
mlir/include/mlir/Transforms/Passes.td | 15 +--
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 124 ++++++++----------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 80 +++++------
mlir/lib/Interfaces/MemOpInterfaces.cpp | 12 +-
...ers.cpp => BubbleDownMemorySpaceCasts.cpp} | 46 +++----
mlir/lib/Transforms/CMakeLists.txt | 2 +-
... test-bubble-down-memory-space-casts.mlir} | 2 +-
14 files changed, 205 insertions(+), 218 deletions(-)
create mode 100644 mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
delete mode 100644 mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
rename mlir/lib/Transforms/{FuseMemorySpaceCastsIntoConsumers.cpp => BubbleDownMemorySpaceCasts.cpp} (53%)
rename mlir/test/Transforms/{test-fuse-casts-into-consumers.mlir => test-bubble-down-memory-space-casts.mlir} (99%)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 238a767ac8b73..c708d7f3d884a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -147,7 +147,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary =
"assumption that gives alignment information to the input memref";
@@ -458,7 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
def MemRef_CastOp : MemRef_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultShape,
@@ -1197,7 +1197,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
@@ -1381,7 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
MemRefsNormalizable,
Pure,
@@ -1609,7 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
Pure,
ViewLikeOpInterface]> {
let summary = "memref reshape operation";
@@ -1708,7 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
@@ -1831,7 +1831,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
@@ -1939,7 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
@@ -2017,7 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
@@ -2293,7 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
Pure]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
@@ -2329,7 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
def MemRef_ViewOp : MemRef_Op<"view", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Pure]> {
let summary = "memref view operation";
@@ -2406,7 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
//===----------------------------------------------------------------------===//
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 93e9bfc78ea75..252c0b72456df 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1247,7 +1247,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
@@ -1495,7 +1495,7 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
@@ -1652,7 +1652,7 @@ def Vector_TransferWriteOp :
def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
@@ -1769,7 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
@@ -1874,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
}
def Vector_MaskedLoadOp :
- Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+ Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1966,7 +1966,7 @@ def Vector_MaskedLoadOp :
}
def Vector_MaskedStoreOp :
- Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+ Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2046,7 +2046,7 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
- DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2150,7 +2150,7 @@ def Vector_GatherOp :
}
def Vector_ScatterOp :
- Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+ Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2235,7 +2235,7 @@ def Vector_ScatterOp :
}
def Vector_ExpandLoadOp :
- Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+ Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2323,7 +2323,7 @@ def Vector_ExpandLoadOp :
}
def Vector_CompressStoreOp :
- Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
+ Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
index cc9f4c6b3882e..d4ed71e38f4ff 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -21,13 +21,12 @@ namespace detail {
/// Attempt to verify the given memory space cast operation.
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
-/// Tries to fuse inplace a `MemorySpaceCastOpInterface` operation referenced by
-/// `operand`. On success, it returns `results`, and sets `modifiedInPlace` to
-/// true. It returns failure if `operand` doesn't reference a
+/// Tries to bubble down, in place, a `MemorySpaceCastOpInterface` operation
+/// referenced by `operand`. On success, it returns `results` and true. It
+/// returns failure if `operand` doesn't reference a
/// `MemorySpaceCastOpInterface` op.
-FailureOr<SmallVector<Value>>
-fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
- bool &modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 0b8ba19171fb7..d097b00c8e80c 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -16,27 +16,26 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-def FuseMemorySpaceCastConsumerOpInterface :
- OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+def MemorySpaceCastConsumerOpInterface :
+ OpInterface<"MemorySpaceCastConsumerOpInterface"> {
let description = [{
- An interface to fuse memory-space cast operands into a consumer operation.
- It is the responsibility of the interface to determine which casts can be
- fused into the operation.
+ An interface for operations that can consume memory-space cast-like
+ operations.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
- Attempt to fuse the incoming cast-like operands. Returns `success`
- and any new results on fusion success, otherwise it returns failure.
+ Attempt to bubble-down the incoming cast-like operands. On success
+ returns any new results, and whether the operation was modified in
+ place, otherwise it returns failure.
If new results are produced, these must be compatible with the original
operation results.
- The `modifiedInPlace` parameter indicates whether the operation was
- modified in place. If `false` and the fusion succeeded, then the
- interface guarantees it is valid to erase the original operation.
- If `true`, then the interface must guarantee no operations were created
- by the method, and that no further IR modification is necessary. It is
- considered an error if `modifiedInPlace` is true and the fusion failed.
+ If the operation was not modified in place, then the interface
+ guarantees it is valid to erase the original operation.
+ If the operation was modified in place, then the interface must
+ guarantee no operations were created by the method, and that no further
+ IR modification is necessary.
Any implementations of this method must not erase/replace the original
operation, instead it is the caller responsibility to erase or replace
@@ -45,8 +44,9 @@ def FuseMemorySpaceCastConsumerOpInterface :
Finally, any implementations of this method have to guarantee that the
IR remains valid at all times.
}],
- "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
- (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+ "::llvm::FailureOr<std::pair<::llvm::SmallVector<::mlir::Value>, bool>>",
+ "bubbleDownCasts",
+ (ins "::mlir::OpBuilder &":$builder)
>,
];
}
@@ -83,13 +83,16 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
Clones the memory space cast op with the given source and target type.
}],
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
- (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+ (ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
"::mlir::Value":$src)
>,
InterfaceMethod<[{
- Returns whether the cast allows to be fused.
+ Returns whether the memory-space cast is lossless. A lossless
+ memory-space cast must not lose any information encoded in the memory
+      space. An example of such a cast is any conversion to the generic memory
+ space.
}],
- "bool", "isFusableMemorySpaceCast"
+ "bool", "isLosslessCast"
>
];
let verify = [{
@@ -99,12 +102,12 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
let extraClassDeclaration = [{
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
/// is produced by a `MemorySpaceCastOpInterface` op, and
- /// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
+ /// `isLosslessCast` returns true, otherwise it returns null.
static ::mlir::MemorySpaceCastOpInterface
- getIfFusableCast(::mlir::Value value) {
+ getIfLosslessCast(::mlir::Value value) {
auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
value.getDefiningOp());
- if (!op || !op.isFusableMemorySpaceCast())
+ if (!op || !op.isLosslessCast())
return nullptr;
return op;
}
diff --git a/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
new file mode 100644
index 0000000000000..99db092879a90
--- /dev/null
+++ b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
@@ -0,0 +1,20 @@
+//===- BubbleDownMemorySpaceCasts.h - Bubble down cast patterns -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+#define MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+
+namespace mlir {
+class PatternBenefit;
+class RewritePatternSet;
+/// Collect a set of patterns to bubble-down memory-space cast operations.
+void populateBubbleDownMemorySpaceCastPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
deleted file mode 100644
index 9333f92a10289..0000000000000
--- a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- FuseMemorySpaceCastsIntoConsumers.h - Cast fusion patterns -C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
-#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
-
-namespace mlir {
-class RewritePatternSet;
-/// Collect a set of patterns to fuse memory-space cast operations into
-/// consumers.
-void populateFuseMemorySpaceCastIntoConsumersPatterns(
- RewritePatternSet &patterns);
-} // namespace mlir
-
-#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 610a9671fede8..1c035f2a843ff 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,7 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
-#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 3204e80919456..8f0b80c5e511b 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,14 +585,13 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
-def FuseMemorySpaceCastsIntoConsumers :
- Pass<"fuse-memory-space-casts-into-consumers"> {
- let summary = "Fuses memory-space cast operations into consumers.";
+def BubbleDownMemorySpaceCasts :
+ Pass<"bubble-down-memory-space-casts"> {
+ let summary = "Bubbles down memory-space cast operations.";
let description = [{
- This pass tries to iteratively fuse all possible memory-space cast
- operations into their consumers. It does this by looking for
- `FuseMemorySpaceCastConsumerOpInterface` operations, and invoking the
- interface methods to perform the fusion.
+ This pass tries to iteratively bubble down all possible memory-space cast
+ operations. It does this by looking for `MemorySpaceCastConsumerOpInterface`
+ operations, and invoking the interface methods to perform the bubbling down.
Example:
@@ -609,7 +608,7 @@ def FuseMemorySpaceCastsIntoConsumers :
%atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
return %collapsed : memref<16xf32>
}
- // mlir-opt --fuse-casts-into-consumers
+ // mlir-opt --bubble-down-memory-space-casts
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11fd43ff54575..6f276efb84c1c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,14 +111,14 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
}
}
-/// Helper function to retrieve a fusable memory-space cast, and the
+/// Helper function to retrieve a lossless memory-space cast, and the
/// corresponding new result memref type.
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
-getFuseCastInfo(BaseMemRefType resultTy, Value src) {
+getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
MemorySpaceCastOpInterface castOp =
- MemorySpaceCastOpInterface::getIfFusableCast(src);
+ MemorySpaceCastOpInterface::getIfLosslessCast(src);
- // Bail if the cast is not fusable.
+ // Bail if the cast is not lossless.
if (!castOp)
return {};
@@ -141,25 +141,23 @@ getFuseCastInfo(BaseMemRefType resultTy, Value src) {
return std::make_tuple(castOp, *tgtTy, *srcTy);
}
-/// Implementation of `fuseCastOperands` method for memref operations that
+/// Implementation of `bubbleDownCasts` method for memref operations that
/// return a single memref result.
template <typename ConcreteOpTy>
-static FailureOr<SmallVector<Value>>
-fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
- bool &modifiedInPlace, OpOperand &src) {
- auto [castOp, tgtTy, resTy] = getFuseCastInfo(op.getType(), src.get());
+static FailureOr<std::pair<SmallVector<Value>, bool>>
+bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+ OpOperand &src) {
+ auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
// Bail if we cannot cast.
if (!castOp)
return failure();
- modifiedInPlace = false;
-
// Create the new operands.
SmallVector<Value> operands;
llvm::append_range(operands, op->getOperands());
operands[src.getOperandNumber()] = castOp.getSourcePtr();
- // Create the fused op and results.
+ // Create the new op and results.
auto newOp = ConcreteOpTy::create(
builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
@@ -167,7 +165,7 @@ fuseCastOperandsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
// Insert a memory-space cast to the original memory space of the op.
MemorySpaceCastOpInterface result =
castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
- return SmallVector<Value>({result.getTargetPtr()});
+ return std::make_pair(SmallVector<Value>({result.getTargetPtr()}), false);
}
//===----------------------------------------------------------------------===//
@@ -601,10 +599,9 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
return getMemref();
}
-FailureOr<SmallVector<Value>>
-AssumeAlignmentOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getMemrefMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}
//===----------------------------------------------------------------------===//
@@ -775,10 +772,9 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
-FailureOr<SmallVector<Value>> CastOp::fuseCastOperands(OpBuilder &builder,
- bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+CastOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
//===----------------------------------------------------------------------===//
@@ -1672,10 +1668,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
-FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
- bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getMemrefMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
@@ -1737,15 +1733,16 @@ bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
}
MemorySpaceCastOpInterface
-MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b, Type tgt, Value src) {
- assert(isValidMemorySpaceCast(cast<PtrLikeTypeInterface>(tgt),
- cast<PtrLikeTypeInterface>(src.getType())) &&
- "invalid arguments");
+MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b,
+ PtrLikeTypeInterface tgt, Value src) {
+ assert(
+ isValidMemorySpaceCast(tgt, cast<PtrLikeTypeInterface>(src.getType())) &&
+ "invalid arguments");
return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}
-bool MemorySpaceCastOp::isFusableMemorySpaceCast() {
- // Only allow fusion when this is discarding information.
+bool MemorySpaceCastOp::isLosslessCast() {
+ // The only cast we recognize as lossless is to the generic space.
return getDest().getType().getMemorySpace() == nullptr;
}
@@ -2145,10 +2142,9 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
-FailureOr<SmallVector<Value>>
-ReinterpretCastOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
//===----------------------------------------------------------------------===//
@@ -2458,10 +2454,9 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
-FailureOr<SmallVector<Value>>
-ExpandShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSrcMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
/// Compute the layout map after collapsing a given source MemRef type with the
@@ -2685,10 +2680,9 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands());
}
-FailureOr<SmallVector<Value>>
-CollapseShapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSrcMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
//===----------------------------------------------------------------------===//
@@ -2731,10 +2725,9 @@ LogicalResult ReshapeOp::verify() {
return success();
}
-FailureOr<SmallVector<Value>>
-ReshapeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
//===----------------------------------------------------------------------===//
@@ -2754,10 +2747,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
return foldMemRefCast(*this, getValueToStore());
}
-FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
- bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getMemrefMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+ ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -3416,10 +3409,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
return {};
}
-FailureOr<SmallVector<Value>>
-SubViewOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+SubViewOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
//===----------------------------------------------------------------------===//
@@ -3522,10 +3514,9 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
return {};
}
-FailureOr<SmallVector<Value>>
-TransposeOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getInMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+TransposeOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
//===----------------------------------------------------------------------===//
@@ -3671,10 +3662,9 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
-FailureOr<SmallVector<Value>> ViewOp::fuseCastOperands(OpBuilder &builder,
- bool &modifiedInPlace) {
- return fuseCastOperandsPassthroughOpImpl(*this, builder, modifiedInPlace,
- getSourceMutable());
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ViewOp::bubbleDownCasts(OpBuilder &builder) {
+ return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
//===----------------------------------------------------------------------===//
@@ -3722,10 +3712,10 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
-FailureOr<SmallVector<Value>>
-AtomicRMWOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getMemrefMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 806e6c1c070aa..77dcb1fc6220e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,12 +5087,12 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<TransferReadAfterWriteToBroadcast>(context);
}
-FailureOr<SmallVector<Value>>
-TransferReadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+FailureOr<std::pair<SmallVector<Value>, bool>>
+TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
if (!hasPureBufferSemantics())
return failure();
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), getResult(), modifiedInPlace);
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
@@ -5582,12 +5582,12 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
-FailureOr<SmallVector<Value>>
-TransferWriteOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
+FailureOr<std::pair<SmallVector<Value>, bool>>
+TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
if (!hasPureBufferSemantics())
return failure();
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), ValueRange(), modifiedInPlace);
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -5644,10 +5644,10 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
-FailureOr<SmallVector<Value>> LoadOp::fuseCastOperands(OpBuilder &builder,
- bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
@@ -5689,10 +5689,10 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
-FailureOr<SmallVector<Value>> StoreOp::fuseCastOperands(OpBuilder &builder,
- bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -5749,10 +5749,10 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
-FailureOr<SmallVector<Value>>
-MaskedLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
@@ -5805,10 +5805,10 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
return memref::foldMemRefCast(*this);
}
-FailureOr<SmallVector<Value>>
-MaskedStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -5914,10 +5914,10 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<GatherFolder, FoldContiguousGather>(context);
}
-FailureOr<SmallVector<Value>>
-GatherOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+GatherOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
@@ -5982,10 +5982,10 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ScatterFolder, FoldContiguousScatter>(context);
}
-FailureOr<SmallVector<Value>>
-ScatterOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ScatterOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -6036,10 +6036,10 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExpandLoadFolder>(context);
}
-FailureOr<SmallVector<Value>>
-ExpandLoadOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), getResult(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ getResult());
}
//===----------------------------------------------------------------------===//
@@ -6088,10 +6088,10 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<CompressStoreFolder>(context);
}
-FailureOr<SmallVector<Value>>
-CompressStoreOp::fuseCastOperands(OpBuilder &builder, bool &modifiedInPlace) {
- return mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- getBaseMutable(), ValueRange(), modifiedInPlace);
+FailureOr<std::pair<SmallVector<Value>, bool>>
+CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
+ return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+ ValueRange());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
index 013d828da1d66..10303185ad833 100644
--- a/mlir/lib/Interfaces/MemOpInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -55,19 +55,19 @@ LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
return success();
}
-FailureOr<SmallVector<Value>> mlir::detail::fuseInPlaceMemorySpaceCastImpl(
- OpOperand &operand, ValueRange results, bool &modifiedInPlace) {
+FailureOr<std::pair<SmallVector<Value>, bool>>
+mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
+ ValueRange results) {
MemorySpaceCastOpInterface castOp =
- MemorySpaceCastOpInterface::getIfFusableCast(operand.get());
+ MemorySpaceCastOpInterface::getIfLosslessCast(operand.get());
- // Bail if the src is not produced by a `MemorySpaceCastOpInterface`.
+ // Bail if the src is not valid.
if (!castOp)
return failure();
// Modify the op.
- modifiedInPlace = true;
operand.set(castOp.getSourcePtr());
- return llvm::to_vector_of<Value>(results);
+ return std::make_pair(llvm::to_vector_of<Value>(results), true);
}
#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
similarity index 53%
rename from mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
rename to mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
index 010b88ac12de2..96e0e8d584ea7 100644
--- a/mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -1,4 +1,4 @@
-//===- FuseMemorySpaceCastsIntoConsumers.cpp - Fuse casts transform -------===//
+//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h"
+#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Pass/Pass.h"
@@ -17,57 +17,53 @@
using namespace mlir;
namespace mlir {
-#define GEN_PASS_DEF_FUSEMEMORYSPACECASTSINTOCONSUMERS
+#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
namespace {
//===----------------------------------------------------------------------===//
-// FuseCastsPattern pattern
+// BubbleDownCastsPattern pattern
//===----------------------------------------------------------------------===//
-/// Pattern to fuse casts into consumer operations.
-struct FuseCastsPattern
- : public OpInterfaceRewritePattern<FuseMemorySpaceCastConsumerOpInterface> {
+/// Pattern to bubble down casts into consumer operations.
+struct BubbleDownCastsPattern
+ : public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(FuseMemorySpaceCastConsumerOpInterface op,
+ LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
PatternRewriter &rewriter) const override {
- bool modifiedInPlace = false;
- FailureOr<SmallVector<Value>> results =
- op.fuseCastOperands(rewriter, modifiedInPlace);
- assert((!failed(results) || !modifiedInPlace) &&
- "expected `modifiedInPlace` to be false on fusion failure");
+ FailureOr<std::pair<SmallVector<Value>, bool>> results =
+ op.bubbleDownCasts(rewriter);
if (failed(results))
return failure();
- if (modifiedInPlace) {
+ if (results->second) {
rewriter.modifyOpInPlace(op, []() {});
return success();
}
- rewriter.replaceOp(op, *results);
+ rewriter.replaceOp(op, results->first);
return success();
}
};
//===----------------------------------------------------------------------===//
-// FuseMemorySpaceCastsIntoConsumers pass
+// BubbleDownMemorySpaceCasts pass
//===----------------------------------------------------------------------===//
-struct FuseMemorySpaceCastsIntoConsumers
- : public impl::FuseMemorySpaceCastsIntoConsumersBase<
- FuseMemorySpaceCastsIntoConsumers> {
- using impl::FuseMemorySpaceCastsIntoConsumersBase<
- FuseMemorySpaceCastsIntoConsumers>::FuseMemorySpaceCastsIntoConsumersBase;
+struct BubbleDownMemorySpaceCasts
+ : public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
+ using impl::BubbleDownMemorySpaceCastsBase<
+ BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateFuseMemorySpaceCastIntoConsumersPatterns(patterns);
+ populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace
-void mlir::populateFuseMemorySpaceCastIntoConsumersPatterns(
- RewritePatternSet &patterns) {
- patterns.add<FuseCastsPattern>(patterns.getContext());
+void mlir::populateBubbleDownMemorySpaceCastPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<BubbleDownCastsPattern>(patterns.getContext(), benefit);
}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index e9a7d3e4abe99..54b67f5c7a91e 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,7 +6,7 @@ add_mlir_library(MLIRTransforms
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
- FuseMemorySpaceCastsIntoConsumers.cpp
+ BubbleDownMemorySpaceCasts.cpp
InlinerPass.cpp
LocationSnapshot.cpp
LoopInvariantCodeMotion.cpp
diff --git a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
similarity index 99%
rename from mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
rename to mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
index 7534332b3663a..e4fce89cffb45 100644
--- a/mlir/test/Transforms/test-fuse-casts-into-consumers.mlir
+++ b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+// RUN: mlir-opt %s --bubble-down-memory-space-casts | FileCheck %s
#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
>From 7a881b6cf29e2f7d1c7a8fbcf3d0c9edb63cfd8f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Fri, 19 Sep 2025 09:04:19 -0400
Subject: [PATCH 04/10] Update mlir/include/mlir/Interfaces/MemOpInterfaces.td
Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
mlir/include/mlir/Interfaces/MemOpInterfaces.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index d097b00c8e80c..575fd0af7e020 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -90,7 +90,7 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
Returns whether the memory-space cast is lossless. A lossless
memory-space cast must not lose any information encoded in the memory
space. An example of such cast, is any conversion to the generic memory
- space.
+ space.
}],
"bool", "isLosslessCast"
>
>From 7f4a7f95e4f9040eec8cc47f2270e7aeb039fea4 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 19 Sep 2025 13:30:58 +0000
Subject: [PATCH 05/10] fix benefit
---
mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
index 96e0e8d584ea7..b9f00d4d4e23e 100644
--- a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -56,7 +56,7 @@ struct BubbleDownMemorySpaceCasts
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit());
+ populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit(1));
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
>From 150b79198701b51d4d7a17fc41ad8eb9f530c256 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 22 Sep 2025 06:37:44 -0400
Subject: [PATCH 06/10] Update mlir/include/mlir/Interfaces/MemOpInterfaces.td
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/include/mlir/Interfaces/MemOpInterfaces.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 575fd0af7e020..bdecac2b3512f 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -84,7 +84,7 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
}],
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
(ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
- "::mlir::Value":$src)
+ "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>>":$src)
>,
InterfaceMethod<[{
Returns whether the memory-space cast is lossless. A lossless
>From 231ae132f79208e5bf6ba6888a73d7ba89a15183 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 22 Sep 2025 10:44:47 +0000
Subject: [PATCH 07/10] fix build
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
mlir/include/mlir/Interfaces/MemOpInterfaces.td | 2 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 15 +++++++--------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index bdecac2b3512f..3a5affb55ebbc 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -84,7 +84,7 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
}],
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
(ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
- "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>>":$src)
+ "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
>,
InterfaceMethod<[{
Returns whether the memory-space cast is lossless. A lossless
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f276efb84c1c..b600d0d32293c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -163,8 +163,9 @@ bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
// Insert a memory-space cast to the original memory space of the op.
- MemorySpaceCastOpInterface result =
- castOp.cloneMemorySpaceCastOp(builder, tgtTy, newOp.getResult());
+ MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
+ builder, tgtTy,
+ cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
return std::make_pair(SmallVector<Value>({result.getTargetPtr()}), false);
}
@@ -1732,12 +1733,10 @@ bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
}
-MemorySpaceCastOpInterface
-MemorySpaceCastOp::cloneMemorySpaceCastOp(OpBuilder &b,
- PtrLikeTypeInterface tgt, Value src) {
- assert(
- isValidMemorySpaceCast(tgt, cast<PtrLikeTypeInterface>(src.getType())) &&
- "invalid arguments");
+MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
+ OpBuilder &b, PtrLikeTypeInterface tgt,
+ TypedValue<PtrLikeTypeInterface> src) {
+ assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}
>From 13214545f1457e64485cb1adcfe38ae4c11710ba Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 22 Sep 2025 11:24:17 +0000
Subject: [PATCH 08/10] rename isLosslessCast method
---
mlir/include/mlir/Interfaces/MemOpInterfaces.td | 15 +++++++--------
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 +++---
mlir/lib/Interfaces/MemOpInterfaces.cpp | 2 +-
3 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 3a5affb55ebbc..02e01d81912b2 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -87,12 +87,11 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
>,
InterfaceMethod<[{
- Returns whether the memory-space cast is lossless. A lossless
- memory-space cast must not lose any information encoded in the memory
- space. An example of such cast, is any conversion to the generic memory
- space.
+ Returns whether the source pointer of the memory-space cast can be used
+ by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method to
+ promote the source pointer and bubble down the cast.
}],
- "bool", "isLosslessCast"
+ "bool", "isSourcePromotable"
>
];
let verify = [{
@@ -102,12 +101,12 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
let extraClassDeclaration = [{
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
/// is produced by a `MemorySpaceCastOpInterface` op, and
- /// `isLosslessCast` returns true, otherwise it returns null.
+ /// `isSourcePromotable` returns true, otherwise it returns null.
static ::mlir::MemorySpaceCastOpInterface
- getIfLosslessCast(::mlir::Value value) {
+ getIfPromotableCast(::mlir::Value value) {
auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
value.getDefiningOp());
- if (!op || !op.isLosslessCast())
+ if (!op || !op.isSourcePromotable())
return nullptr;
return op;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b600d0d32293c..cc82602239d48 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -116,7 +116,7 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
MemorySpaceCastOpInterface castOp =
- MemorySpaceCastOpInterface::getIfLosslessCast(src);
+ MemorySpaceCastOpInterface::getIfPromotableCast(src);
// Bail if the cast is not lossless.
if (!castOp)
@@ -1740,8 +1740,8 @@ MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}
-bool MemorySpaceCastOp::isLosslessCast() {
- // The only cast we recognize as lossless is to the generic space.
+bool MemorySpaceCastOp::isSourcePromotable() {
+ // The only cast we recognize as promotable is to the generic space.
return getDest().getType().getMemorySpace() == nullptr;
}
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
index 10303185ad833..c29c7a9244651 100644
--- a/mlir/lib/Interfaces/MemOpInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -59,7 +59,7 @@ FailureOr<std::pair<SmallVector<Value>, bool>>
mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
ValueRange results) {
MemorySpaceCastOpInterface castOp =
- MemorySpaceCastOpInterface::getIfLosslessCast(operand.get());
+ MemorySpaceCastOpInterface::getIfPromotableCast(operand.get());
// Bail if the src is not valid.
if (!castOp)
>From d753be0129f123c711ccd341a4176cc799414afa Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 23 Sep 2025 11:15:11 +0000
Subject: [PATCH 09/10] use std::optional in bubbleDownCasts
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
.../include/mlir/Interfaces/MemOpInterfaces.h | 4 +--
.../mlir/Interfaces/MemOpInterfaces.td | 7 +++--
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 31 ++++++++++---------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 ++++++------
mlir/lib/Interfaces/MemOpInterfaces.cpp | 4 +--
.../Transforms/BubbleDownMemorySpaceCasts.cpp | 6 ++--
6 files changed, 37 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
index d4ed71e38f4ff..cdc423f5da1a5 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -22,10 +22,10 @@ namespace detail {
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
/// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
-/// referenced by `operand`. On success, it returns `results` and true. It
+/// referenced by `operand`. On success, it returns `std::nullopt`. It
/// returns failure if `operand` doesn't reference a
/// `MemorySpaceCastOpInterface` op.
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 02e01d81912b2..0c7aff8cd7ff3 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -26,8 +26,9 @@ def MemorySpaceCastConsumerOpInterface :
let methods = [
InterfaceMethod<[{
Attempt to bubble-down the incoming cast-like operands. On success
- returns any new results, and whether the operation was modified in
- place, otherwise it returns failure.
+ returns a `std::optional<SmallVector<Value>>`, otherwise it returns
+ failure. If the optional is `std::nullopt` then the cast was performed
+ in place, otherwise the method returns a list of replacement values.
If new results are produced, these must be compatible with the original
operation results.
@@ -44,7 +45,7 @@ def MemorySpaceCastConsumerOpInterface :
Finally, any implementations of this method have to guarantee that the
IR remains valid at all times.
}],
- "::llvm::FailureOr<std::pair<::llvm::SmallVector<::mlir::Value>, bool>>",
+ "::llvm::FailureOr<std::optional<::llvm::SmallVector<::mlir::Value>>>",
"bubbleDownCasts",
(ins "::mlir::OpBuilder &":$builder)
>,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index cc82602239d48..349b4deb29023 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -144,7 +144,7 @@ getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
/// Implementation of `bubbleDownCasts` method for memref operations that
/// return a single memref result.
template <typename ConcreteOpTy>
-static FailureOr<std::pair<SmallVector<Value>, bool>>
+static FailureOr<std::optional<SmallVector<Value>>>
bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
OpOperand &src) {
auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
@@ -166,7 +166,8 @@ bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
builder, tgtTy,
cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
- return std::make_pair(SmallVector<Value>({result.getTargetPtr()}), false);
+ return std::optional<SmallVector<Value>>(
+ SmallVector<Value>({result.getTargetPtr()}));
}
//===----------------------------------------------------------------------===//
@@ -600,7 +601,7 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
return getMemref();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}
@@ -773,7 +774,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
@@ -1669,7 +1670,7 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
getResult());
@@ -1740,8 +1741,8 @@ MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}
+/// The only cast we recognize as promotable is to the generic space.
bool MemorySpaceCastOp::isSourcePromotable() {
- // The only cast we recognize as promotable is to the generic space.
return getDest().getType().getMemorySpace() == nullptr;
}
@@ -2141,7 +2142,7 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
@@ -2453,7 +2454,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
@@ -2679,7 +2680,7 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands());
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
@@ -2724,7 +2725,7 @@ LogicalResult ReshapeOp::verify() {
return success();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
@@ -2746,7 +2747,7 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
return foldMemRefCast(*this, getValueToStore());
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
ValueRange());
@@ -3408,7 +3409,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
return {};
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
@@ -3513,7 +3514,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
return {};
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
@@ -3661,7 +3662,7 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
@@ -3711,7 +3712,7 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
getResult());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 77dcb1fc6220e..b2e5a5b1e36cc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,7 +5087,7 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<TransferReadAfterWriteToBroadcast>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
if (!hasPureBufferSemantics())
return failure();
@@ -5582,7 +5582,7 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
if (!hasPureBufferSemantics())
return failure();
@@ -5644,7 +5644,7 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
getResult());
@@ -5689,7 +5689,7 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
ValueRange());
@@ -5749,7 +5749,7 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
getResult());
@@ -5805,7 +5805,7 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
return memref::foldMemRefCast(*this);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
ValueRange());
@@ -5914,7 +5914,7 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<GatherFolder, FoldContiguousGather>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
getResult());
@@ -5982,7 +5982,7 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ScatterFolder, FoldContiguousScatter>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
ValueRange());
@@ -6036,7 +6036,7 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExpandLoadFolder>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
getResult());
@@ -6088,7 +6088,7 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<CompressStoreFolder>(context);
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
ValueRange());
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
index c29c7a9244651..fe5c717f67bc4 100644
--- a/mlir/lib/Interfaces/MemOpInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -55,7 +55,7 @@ LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
return success();
}
-FailureOr<std::pair<SmallVector<Value>, bool>>
+FailureOr<std::optional<SmallVector<Value>>>
mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
ValueRange results) {
MemorySpaceCastOpInterface castOp =
@@ -67,7 +67,7 @@ mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
// Modify the op.
operand.set(castOp.getSourcePtr());
- return std::make_pair(llvm::to_vector_of<Value>(results), true);
+ return std::optional<SmallVector<Value>>();
}
#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
index b9f00d4d4e23e..00dac19e37171 100644
--- a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -32,15 +32,15 @@ struct BubbleDownCastsPattern
LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
PatternRewriter &rewriter) const override {
- FailureOr<std::pair<SmallVector<Value>, bool>> results =
+ FailureOr<std::optional<SmallVector<Value>>> results =
op.bubbleDownCasts(rewriter);
if (failed(results))
return failure();
- if (results->second) {
+ if (!results->has_value()) {
rewriter.modifyOpInPlace(op, []() {});
return success();
}
- rewriter.replaceOp(op, results->first);
+ rewriter.replaceOp(op, **results);
return success();
}
};
>From e70b0f2b460c6b19152e5a13333c1fd5a97a082e Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 23 Sep 2025 17:48:06 +0000
Subject: [PATCH 10/10] improve docs
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 23 ++++++++++++++++++-
.../mlir/Interfaces/MemOpInterfaces.td | 12 ++++++++--
mlir/include/mlir/Transforms/Passes.td | 8 +++++--
3 files changed, 38 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c708d7f3d884a..bddf766d8eb21 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1288,7 +1288,7 @@ def LoadOp : MemRef_Op<"load",
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
+ MemorySpaceCastOpInterface,
MemRefsNormalizable,
Pure,
SameOperandsAndResultElementType,
@@ -1326,6 +1326,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
let extraClassDeclaration = [{
Value getViewSource() { return getSource(); }
+
+ //===------------------------------------------------------------------===//
+ // MemorySpaceCastOpInterface
+ //===------------------------------------------------------------------===//
+ /// Returns the `source` memref.
+ TypedValue<PtrLikeTypeInterface> getSourcePtr();
+ /// Returns the `dest` memref.
+ TypedValue<PtrLikeTypeInterface> getTargetPtr();
+ /// Returns whether the memory-space cast is valid. Only casts between
+ /// memrefs are considered valid. Further, the `tgt` and `src` should only
+ /// differ on the memory-space parameter of the memref type.
+ bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+ PtrLikeTypeInterface src);
+ /// Clones the operation using a new target type and source value.
+ MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
+ OpBuilder &b, PtrLikeTypeInterface tgt,
+ TypedValue<PtrLikeTypeInterface> src);
+ /// Returns whether the `source` value can be promoted by the
+ /// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
+ /// casts the op recognizes as promotable are to the generic memory-space.
+ bool isSourcePromotable();
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
index 0c7aff8cd7ff3..1a64e97c3412d 100644
--- a/mlir/include/mlir/Interfaces/MemOpInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -14,13 +14,15 @@
#define MLIR_INTERFACES_MEMOPINTERFACES_TD
include "mlir/IR/OpBase.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
def MemorySpaceCastConsumerOpInterface :
OpInterface<"MemorySpaceCastConsumerOpInterface"> {
let description = [{
An interface for operations that can consume memory-space cast-like
operations.
+
+ This interface can be used to bubble-down memory-space cast operations,
+ see the `bubble-down-memory-space-casts` pass for an example.
}];
let cppNamespace = "::mlir";
let methods = [
@@ -59,6 +61,10 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
These operations expect to have a well-defined ptr-like operand, and
a well-defined target ptr-like result.
+
+ This interface also allows determining whether a cast can be bubbled-down
+ by the `MemorySpaceCastConsumerOpInterface`, allowing control over which
+ casts can be bubbled-down or not.
}];
let cppNamespace = "::mlir";
let methods = [
@@ -91,6 +97,9 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
Returns whether the source pointer of the memory-space cast can be used
by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method to
promote the source pointer and bubble down the cast.
+
+ For example, a cast operation might decide that all casts to the generic
+ memory-space can be promoted.
}],
"bool", "isSourcePromotable"
>
@@ -98,7 +107,6 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
let verify = [{
return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
}];
- let dependentTraits = [Pure];
let extraClassDeclaration = [{
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
/// is produced by a `MemorySpaceCastOpInterface` op, and
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 8f0b80c5e511b..b2b7f20a497e3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -590,8 +590,12 @@ def BubbleDownMemorySpaceCasts :
let summary = "Bubbles down memory-space cast operations.";
let description = [{
This pass tries to iteratively bubble down all possible memory-space cast
- operations. It does this by looking for `MemorySpaceCastConsumerOpInterface`
- operations, and invoking the interface methods to perform the bubbling down.
+ operations. It is important to note that the determination of which casts
+ are bubbled down is based on the interfaces
+ `MemorySpaceCastConsumerOpInterface`, and `MemorySpaceCastOpInterface`, and
+ not the pass. The pass only looks for operations implementing the
+ `MemorySpaceCastConsumerOpInterface` interface, and invokes the interface
+ methods to perform the bubbling down.
Example:
More information about the Mlir-commits
mailing list