[Mlir-commits] [mlir] [mlir][mem2reg] Promote memory slots through transparent view operations (PR #196924)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 11 03:59:18 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: jeanPerier
<details>
<summary>Changes</summary>
This patch enables mem2reg to operate on loads/stores that are made through casts/views of the memory slot (i.e., not directly using the SSA value produced by the allocation op).
This is done by adding two APIs to `PromotableOpInterface` that must be implemented by operations that define a view of the slot and through which mem2reg should still happen. The first API returns a new slot (the "slot view") that will be used when visiting load/store operations on the view operation's results. The second API allows inserting value casts (for instance, bitcasts in the fir.convert use case) that the mem2reg framework will generate around loads/stores so that the load/store APIs are provided with reaching definitions consistent with the element type of the slot view they are operating on.
Assisted by: Claude
---
Patch is 23.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/196924.diff
7 Files Affected:
- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.h (+43)
- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+38)
- (modified) mlir/lib/Interfaces/MemorySlotInterfaces.cpp (+109)
- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+50-13)
- (modified) mlir/test/Transforms/mem2reg.mlir (+91)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+56)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+26)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index 7bebfc9a30064..2163593ef823e 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -30,6 +30,15 @@ struct DestructurableMemorySlot : public MemorySlot {
DenseMap<Attribute, Type> subelementTypes;
};
+/// Description of a memory slot view produced by a `PromotableOpInterface`
+/// operation: `slotPointerOperand` is the operand viewed by the op,
+/// `view.ptr` is the result aliasing it, and `view.elemType` is the type
+/// at which `view.ptr` aliases the underlying slot.
+struct PromotableSlotView {
+ Value slotPointerOperand;
+ MemorySlot view;
+};
+
/// Returned by operation promotion logic requesting the deletion of an
/// operation.
enum class DeletionKind {
@@ -44,4 +53,38 @@ enum class DeletionKind {
#include "mlir/Interfaces/MemorySlotOpInterfaces.h.inc"
#include "mlir/Interfaces/MemorySlotTypeInterfaces.h.inc"
+namespace mlir {
+
+/// Returns true if `value` is `rootSlot.ptr` or a transitive view of it,
+/// following `PromotableOpInterface::getPromotableSlotView` chains. The
+/// element type at which `value` aliases the slot is written to
+/// `*outViewElemType` (equal to `rootSlot.elemType` when the chain is empty).
+bool isPromotableSlotView(Value value, const MemorySlot &rootSlot,
+ Type *outViewElemType = nullptr);
+
+/// Returns a MemorySlot whose `ptr` is the operand of `op` that is a
+/// (possibly transitive) view of `rootSlot.ptr`, with `elemType` equal to
+/// the type at which that operand aliases the slot. Mem2Reg uses this to
+/// hand each `PromotableMemOpInterface` op a slot description tailored to
+/// its memref operand. Returns `nullopt` if no operand is a view of
+/// `rootSlot`.
+std::optional<MemorySlot> getOpViewSlot(Operation *op,
+ const MemorySlot &rootSlot);
+
+/// Converts `slotValue` (typed at `rootSlot.elemType`) to the type at which
+/// `viewPtr` aliases `rootSlot`, by chaining
+/// `PromotableOpInterface::convertSlotValue` calls along the view chain
+/// root-to-leaf. Returns `nullptr` if any step's converter fails.
+Value convertSlotValueToViewValue(Value slotValue, Value viewPtr,
+ const MemorySlot &rootSlot,
+ OpBuilder &builder);
+
+/// Inverse of `convertSlotValueToViewValue`: converts `viewValue` back to
+/// `rootSlot.elemType` along the chain leaf-to-root.
+Value convertViewValueToSlotValue(Value viewValue, Value viewPtr,
+ const MemorySlot &rootSlot,
+ OpBuilder &builder);
+
+} // namespace mlir
+
#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 801555fba4947..a8084ab7bf189 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -263,6 +263,44 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
(ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
"::mlir::OpBuilder &":$builder), [{}], [{ return; }]
>,
+ InterfaceMethod<[{
+ Describes this operation as a transparent view of a memory slot
+ reached through one of its operands.
+
+ The returned `view.ptr` must be a result of this operation;
+ `view.elemType` is the type at which `view.ptr` aliases the slot
+ pointed to by `slotPointerOperand`, possibly different from the
+ underlying slot's element type.
+
+ Returning a view here implies `convertSlotValue` can bridge
+ between `slotPointerOperand`'s element type and `view.elemType`
+ in both directions; if no such conversion exists, return
+ `std::nullopt`.
+
+ No IR mutation is allowed in this method.
+ }],
+ "::std::optional<::mlir::PromotableSlotView>",
+ "getPromotableSlotView",
+ (ins), [{}],
+ [{ return std::nullopt; }]
+ >,
+ InterfaceMethod<[{
+ Builds a value of `targetType` from `value`, bridging the
+ underlying slot's element type and the view's element type.
+ Mem2reg calls this in both directions (load: slot → view; store:
+ view → slot).
+ }],
+ "::mlir::Value",
+ "convertSlotValue",
+ (ins "::mlir::Value":$value,
+ "::mlir::Type":$targetType,
+ "::mlir::OpBuilder &":$builder), [{}],
+ [{
+ if (value.getType() == targetType)
+ return value;
+ return ::mlir::Value{};
+ }]
+ >,
];
}
diff --git a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
index 2c9e23250e9ee..1e1961eeca07c 100644
--- a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
@@ -8,5 +8,114 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+
#include "mlir/Interfaces/MemorySlotOpInterfaces.cpp.inc"
#include "mlir/Interfaces/MemorySlotTypeInterfaces.cpp.inc"
+
+using namespace mlir;
+
+namespace {
+/// One step in a view chain, leaf-first. `inputElemType` is the elemType
+/// of the slot one step closer to root; `outputElemType` is the elemType
+/// this step exposes.
+struct ViewStep {
+ PromotableOpInterface view;
+ Type inputElemType;
+ Type outputElemType;
+};
+} // namespace
+
+/// Walks back from `value` to `rootSlot.ptr` along
+/// `getPromotableSlotView` chains. On success, populates `chainOut` with
+/// the view ops leaf-to-root and writes the type at which `value` aliases
+/// the underlying slot to `*outViewElemType`.
+static bool walkPromotableSlotViewChain(Value value, const MemorySlot &rootSlot,
+ SmallVectorImpl<ViewStep> &chainOut,
+ Type *outViewElemType) {
+ if (value == rootSlot.ptr) {
+ if (outViewElemType)
+ *outViewElemType = rootSlot.elemType;
+ return true;
+ }
+
+ Value current = value;
+ Type aliasElemType{};
+ llvm::SmallPtrSet<Value, 4> seen;
+ while (current != rootSlot.ptr) {
+ if (!seen.insert(current).second)
+ return false;
+ auto promotable =
+ dyn_cast_or_null<PromotableOpInterface>(current.getDefiningOp());
+ if (!promotable)
+ return false;
+ std::optional<PromotableSlotView> info = promotable.getPromotableSlotView();
+ if (!info || info->view.ptr != current)
+ return false;
+ if (!aliasElemType)
+ aliasElemType = info->view.elemType;
+ chainOut.push_back(ViewStep{promotable, /*inputElemType=*/Type{},
+ /*outputElemType=*/info->view.elemType});
+ current = info->slotPointerOperand;
+ }
+
+ // Fill in each step's `inputElemType` from the previous step's output
+ // (or `rootSlot.elemType` for the root-most step).
+ Type prevOutput = rootSlot.elemType;
+ for (ViewStep &step : llvm::reverse(chainOut)) {
+ step.inputElemType = prevOutput;
+ prevOutput = step.outputElemType;
+ }
+
+ if (outViewElemType)
+ *outViewElemType = aliasElemType ? aliasElemType : rootSlot.elemType;
+ return true;
+}
+
+bool mlir::isPromotableSlotView(Value value, const MemorySlot &rootSlot,
+ Type *outViewElemType) {
+ SmallVector<ViewStep> chain;
+ return walkPromotableSlotViewChain(value, rootSlot, chain, outViewElemType);
+}
+
+std::optional<MemorySlot> mlir::getOpViewSlot(Operation *op,
+ const MemorySlot &rootSlot) {
+ for (Value operand : op->getOperands()) {
+ Type viewElemType;
+ if (isPromotableSlotView(operand, rootSlot, &viewElemType))
+ return MemorySlot{operand, viewElemType};
+ }
+ return std::nullopt;
+}
+
+Value mlir::convertSlotValueToViewValue(Value slotValue, Value viewPtr,
+ const MemorySlot &rootSlot,
+ OpBuilder &builder) {
+ SmallVector<ViewStep> chain;
+ if (!walkPromotableSlotViewChain(viewPtr, rootSlot, chain, /*out=*/nullptr))
+ return {};
+ Value current = slotValue;
+ // Root-to-leaf walk: reverse the leaf-first chain.
+ for (ViewStep &step : llvm::reverse(chain)) {
+ current = step.view.convertSlotValue(current, step.outputElemType, builder);
+ if (!current)
+ return {};
+ }
+ return current;
+}
+
+Value mlir::convertViewValueToSlotValue(Value viewValue, Value viewPtr,
+ const MemorySlot &rootSlot,
+ OpBuilder &builder) {
+ SmallVector<ViewStep> chain;
+ if (!walkPromotableSlotViewChain(viewPtr, rootSlot, chain, /*out=*/nullptr))
+ return {};
+ Value current = viewValue;
+ for (ViewStep &step : chain) {
+ current = step.view.convertSlotValue(current, step.inputElemType, builder);
+ if (!current)
+ return {};
+ }
+ return current;
+}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 40d08d869a9e2..4f7039bede304 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -400,14 +400,15 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
return failure();
regionsWithDirectUse.insert(user->getParentRegion());
} else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
- if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
+ MemorySlot viewSlot = getOpViewSlot(user, slot).value_or(slot);
+ if (!promotable.canUsesBeRemoved(viewSlot, blockingUses, newBlockingUses,
dataLayout))
return failure();
// Operations that interact with the slot's memory will be promoted using
// a reaching definition. Therefore, the operation must be within a region
// where the reaching definition can be computed.
- if (promotable.storesTo(slot))
+ if (promotable.storesTo(viewSlot))
regionsWithDirectStore.insert(user->getParentRegion());
else
regionsWithDirectUse.insert(user->getParentRegion());
@@ -515,11 +516,17 @@ MemorySlotPromotionAnalyzer::computeInfo() {
// Compute the blocks containing a store for each region, either directly or
// inherited from a nested region. As a side effect, `definingBlocks` contains
// all regions with at least one store.
+ //
+ // Iterating `info.userToBlockingUses` lets this also pick up stores that
+ // reach the slot through chains of views (`getPromotableSlotView`).
DenseMap<Region *, SmallPtrSet<Block *, 16>> definingBlocks;
- for (Operation *user : slot.ptr.getUsers())
- if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
- if (storeOp.storesTo(slot))
- definingBlocks[user->getParentRegion()].insert(user->getBlock());
+ for (auto &[region, opsMap] : info.userToBlockingUses)
+ for (auto &[user, _blockingUses] : opsMap)
+ if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user)) {
+ MemorySlot viewSlot = getOpViewSlot(user, slot).value_or(slot);
+ if (storeOp.storesTo(viewSlot))
+ definingBlocks[region].insert(user->getBlock());
+ }
for (auto &[region, regionInfo] : info.regionsToPromote)
if (regionInfo.hasValueStores)
definingBlocks[region->getParentRegion()].insert(
@@ -550,18 +557,37 @@ Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
if (info.userToBlockingUses[memOp->getParentRegion()].contains(memOp))
reachingDefs.insert({memOp, reachingDef});
- if (memOp.storesTo(slot)) {
+ MemorySlot viewSlot = getOpViewSlot(memOp, slot).value_or(slot);
+ if (memOp.storesTo(viewSlot)) {
builder.setInsertionPointAfter(memOp);
// To not expose default value creation to the interfaces, if we have
// no reaching definition by now, we set it to the default value.
// This is slightly too eager as `getStored` may not need it.
if (!reachingDef)
reachingDef = getOrCreateDefaultValue();
- Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
+ Value reachingDefAtStore = reachingDef;
+ if (slot.ptr != viewSlot.ptr) {
+ // The store sees the slot at `viewSlot.elemType`; convert the
+ // reaching definition (at root elem type) before handing it to
+ // `getStored`.
+ reachingDefAtStore = convertSlotValueToViewValue(
+ reachingDef, viewSlot.ptr, slot, builder);
+ assert(reachingDefAtStore && "convertSlotValue contract violation");
+ }
+ Value stored =
+ memOp.getStored(viewSlot, builder, reachingDefAtStore, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
- reachingDef = stored;
+ // `replacedValuesMap` keeps `stored` at `viewSlot.elemType` for
+ // `visitReplacedValues`; the new reaching definition is tracked at
+ // the root slot's elem type, so convert `stored` back.
replacedValuesMap[memOp] = stored;
+ if (viewSlot.ptr != slot.ptr) {
+ stored =
+ convertViewValueToSlotValue(stored, viewSlot.ptr, slot, builder);
+ assert(stored && "convertSlotValue contract violation");
+ }
+ reachingDef = stored;
}
}
@@ -763,11 +789,22 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
reachingDef = getOrCreateDefaultValue();
builder.setInsertionPointAfter(toPromote);
- if (toPromoteMemOp.removeBlockingUses(slot, blockingUsesMap[toPromote],
- builder, reachingDef,
- dataLayout) == DeletionKind::Delete)
+ MemorySlot viewSlot = getOpViewSlot(toPromote, slot).value_or(slot);
+ Value reachingDefAtBlockingUse = reachingDef;
+ if (viewSlot.ptr != slot.ptr) {
+ // Convert the reaching definition to `viewSlot.elemType` to match
+ // what the impl sees. Skipped when the chain is empty; any cast
+ // unused by the impl will be cleaned up by DCE.
+ reachingDefAtBlockingUse = convertSlotValueToViewValue(
+ reachingDef, viewSlot.ptr, slot, builder);
+ assert(reachingDefAtBlockingUse &&
+ "convertSlotValue contract violation");
+ }
+ if (toPromoteMemOp.removeBlockingUses(
+ viewSlot, blockingUsesMap[toPromote], builder,
+ reachingDefAtBlockingUse, dataLayout) == DeletionKind::Delete)
toErase.insert(toPromote);
- if (toPromoteMemOp.storesTo(slot))
+ if (toPromoteMemOp.storesTo(viewSlot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
replacedValues.push_back({toPromoteMemOp, replacedValue});
continue;
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
index 94b721cf28dcf..edc0a37807b7b 100644
--- a/mlir/test/Transforms/mem2reg.mlir
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -181,3 +181,94 @@ func.func @poison_insertion_point(%val: f64) {
^bb3:
return
}
+
+// -----
+
+// Verifies that mem2reg promotes a memory slot whose stores and loads are
+// reached through a transparent view operation that exposes itself via
+// PromotableOpInterface::getPromotableSlotView. The conditional store on
+// the view in ^bb1 must be discovered as a defining block, otherwise the
+// merge point at ^bb2 would not get a block argument and the promotion
+// would silently drop the conditional update.
+
+// CHECK-LABEL: func.func @promotable_through_view
+// CHECK-SAME: (%[[A:.*]]: i32, %[[COND:.*]]: i1) -> i32
+// CHECK-NOT: test.multi_slot_alloca
+// CHECK-NOT: test.transparent_view
+// CHECK: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]](%[[C42]] : i32)
+// CHECK: ^[[BB1]]:
+// CHECK: cf.br ^[[BB2]](%[[A]] : i32)
+// CHECK: ^[[BB2]](%[[MERGE:.*]]: i32):
+// CHECK: return %[[MERGE]] : i32
+func.func @promotable_through_view(%a: i32, %cond: i1) -> i32 {
+ %c42 = arith.constant 42 : i32
+ %slot = test.multi_slot_alloca : () -> memref<i32>
+ %view = test.transparent_view %slot : (memref<i32>) -> memref<i32>
+ memref.store %c42, %view[] : memref<i32>
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ memref.store %a, %view[] : memref<i32>
+ cf.br ^bb2
+^bb2:
+ %v = memref.load %view[] : memref<i32>
+ return %v : i32
+}
+
+// -----
+
+// Type-changing transparent view: the store and load see the slot at f32
+// while the underlying allocation is at i32. mem2reg materialises an
+// `unrealized_conversion_cast` (the view op's `convertSlotValue`) at the
+// store (f32 → i32 to update the reaching def at the slot's elem type) and
+// at the load (i32 → f32 to feed the load's f32 result type).
+
+// CHECK-LABEL: func.func @promotable_through_cast_view
+// CHECK-SAME: (%[[A:.*]]: f32) -> f32
+// CHECK-NOT: test.multi_slot_alloca
+// CHECK-NOT: test.transparent_cast_view
+// CHECK: %[[I32:.*]] = builtin.unrealized_conversion_cast %[[A]] : f32 to i32
+// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[I32]] : i32 to f32
+// CHECK: return %{{.*}} : f32
+func.func @promotable_through_cast_view(%a: f32) -> f32 {
+ %slot = test.multi_slot_alloca : () -> memref<i32>
+ %view = test.transparent_cast_view %slot : (memref<i32>) -> memref<f32>
+ memref.store %a, %view[] : memref<f32>
+ %v = memref.load %view[] : memref<f32>
+ return %v : f32
+}
+
+// -----
+
+// Same as above with a conditional store across blocks. The merge-point
+// block argument is at the root slot's element type (i32), and the
+// `convertSlotValue` casts are inserted at the store sites (f32 → i32) so
+// the merge argument can carry the conditional update; the load site
+// inserts the inverse cast (i32 → f32) for its result.
+
+// CHECK-LABEL: func.func @promotable_through_cast_view_blocks
+// CHECK-SAME: (%[[A:.*]]: f32, %[[COND:.*]]: i1) -> f32
+// CHECK-NOT: test.multi_slot_alloca
+// CHECK-NOT: test.transparent_cast_view
+// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[CST_I32:.*]] = builtin.unrealized_conversion_cast %[[CST]] : f32 to i32
+// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]](%[[CST_I32]] : i32)
+// CHECK: ^[[BB1]]:
+// CHECK: %[[A_I32:.*]] = builtin.unrealized_conversion_cast %[[A]] : f32 to i32
+// CHECK: cf.br ^[[BB2]](%[[A_I32]] : i32)
+// CHECK: ^[[BB2]](%[[MERGE:.*]]: i32):
+// CHECK: %[[MERGE_F32:.*]] = builtin.unrealized_conversion_cast %[[MERGE]] : i32 to f32
+// CHECK: return %[[MERGE_F32]] : f32
+func.func @promotable_through_cast_view_blocks(%a: f32, %cond: i1) -> f32 {
+ %cst = arith.constant 1.0 : f32
+ %slot = test.multi_slot_alloca : () -> memref<i32>
+ %view = test.transparent_cast_view %slot : (memref<i32>) -> memref<f32>
+ memref.store %cst, %view[] : memref<f32>
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ memref.store %a, %view[] : memref<f32>
+ cf.br ^bb2
+^bb2:
+ %v = memref.load %view[] : memref<f32>
+ return %v : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index a3ff397ac26db..ca7677a9663e7 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1769,6 +1769,62 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
+//===----------------------------------------------------------------------===//
+// TestTransparentView
+//===----------------------------------------------------------------------===//
+
+std::optional<PromotableSlotView> TestTransparentView::getPromotableSlotView() {
+ Type elemType = cast<MemRefType>(getResult().getType()).getElementType();
+ return PromotableSlotView{getSource(), MemorySlot{getResult(), elemType}};
+}
+
+bool TestTransparentView::canUsesBeRemoved(
+ const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
+ for (OpOperand &use : getResult().getUses())
+ newBlockingUses.p...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/196924
More information about the Mlir-commits
mailing list