[Mlir-commits] [mlir] [mlir][memref] Rename `memref.load`/`store` members to align with vector dialect (PR #173185)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 22 04:08:20 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
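Rename the memref load/store `memref`/`value` operands to `base`/`valueToStore`, matching the vector load/store ops. Add backward-compatibility accessors (`getMemref()`, `getMemrefMutable()`, `getValue()`) to the ops via `extraClassDeclaration`. Unfortunately, `extraClassDeclaration` cannot be applied to op adaptors, so this is still a breaking change for code that uses the adaptor accessors (e.g. `adaptor.getMemref()` and `adaptor.getValue()` must become `adaptor.getBase()` and `adaptor.getValueToStore()`).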
Discussion: https://discourse.llvm.org/t/rfc-aligning-member-names-between-vector-and-memref-load-store-ops/89175
---
Full diff: https://github.com/llvm/llvm-project/pull/173185.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+15-10)
- (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+3-3)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+7-7)
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+9-8)
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+8-8)
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp (+4-4)
- (modified) mlir/test/python/dialects/openacc.py (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..cf5c93fc136d9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1233,8 +1233,8 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
//===----------------------------------------------------------------------===//
def LoadOp : MemRef_Op<"load",
- [TypesMatchWith<"result type matches element type of 'memref'",
- "memref", "result",
+ [TypesMatchWith<"result type matches element type of 'base'",
+ "base", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
@@ -1273,7 +1273,7 @@ def LoadOp : MemRef_Op<"load",
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
- [MemRead]>:$memref,
+ [MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
@@ -1310,6 +1310,8 @@ def LoadOp : MemRef_Op<"load",
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
+ ::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); }
+ ::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); }
Value getMemRef() { return getOperand(0); }
void setMemRef(Value value) { setOperand(0, value); }
MemRefType getMemRefType() {
@@ -1320,7 +1322,7 @@ def LoadOp : MemRef_Op<"load",
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
+ let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base)";
}
//===----------------------------------------------------------------------===//
@@ -2010,8 +2012,8 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
//===----------------------------------------------------------------------===//
def MemRef_StoreOp : MemRef_Op<"store",
- [TypesMatchWith<"type of 'value' matches element type of 'memref'",
- "memref", "value",
+ [TypesMatchWith<"type of 'valueToStore' matches element type of 'base'",
+ "base", "valueToStore",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
@@ -2046,9 +2048,9 @@ def MemRef_StoreOp : MemRef_Op<"store",
```
}];
- let arguments = (ins AnyType:$value,
+ let arguments = (ins AnyType:$valueToStore,
Arg<AnyMemRef, "the reference to store to",
- [MemWrite]>:$memref,
+ [MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
@@ -2070,7 +2072,10 @@ def MemRef_StoreOp : MemRef_Op<"store",
];
let extraClassDeclaration = [{
- Value getValueToStore() { return getOperand(0); }
+ ::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); }
+ ::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); }
+
+ Value getValue() { return getOperand(0); }
Value getMemRef() { return getOperand(1); }
void setMemRef(Value value) { setOperand(1, value); }
@@ -2083,7 +2088,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let hasVerifier = 1;
let assemblyFormat = [{
- $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
+ $valueToStore `,` $base `[` $indices `]` attr-dict `:` type($base)
}];
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0a382d812f362..caefe0bde3cff 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -342,7 +342,7 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
}
auto arrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getBase());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
@@ -362,7 +362,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getBase());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
@@ -370,7 +370,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
auto subscript = emitc::SubscriptOp::create(
rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
- operands.getValue());
+ operands.getValueToStore());
return success();
}
};
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..af1e25add0167 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -941,9 +941,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
// Per memref.load spec, the indices must be in-bounds:
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
- Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
- adaptor.getMemref(),
- adaptor.getIndices(), kNoWrapFlags);
+ Value dataPtr =
+ getStridedElementPtr(rewriter, loadOp.getLoc(), type, adaptor.getBase(),
+ adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
@@ -965,11 +965,11 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
Value dataPtr =
- getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
+ getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getBase(),
adaptor.getIndices(), kNoWrapFlags);
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
- op.getAlignment().value_or(0),
- false, op.getNontemporal());
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+ op, adaptor.getValueToStore(), dataPtr, op.getAlignment().value_or(0),
+ false, op.getNontemporal());
return success();
}
};
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index a90dcc8cc3ef1..03e896ce5f9b6 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -555,7 +555,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
- spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
adaptor.getIndices(), loc, rewriter);
if (!accessChain)
@@ -682,7 +682,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
"failed to lower memref in image storage class to storage buffer");
Value loadPtr = spirv::getElementPtr(
- *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
+ *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getBase(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
@@ -743,7 +743,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return rewriter.notifyMatchFailure(
loadOp, "failed to lower memref in non-image storage class to image");
- Value loadPtr = adaptor.getMemref();
+ Value loadPtr = adaptor.getBase();
auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
@@ -824,7 +824,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
auto loc = storeOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
- spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
adaptor.getIndices(), loc, rewriter);
if (!accessChain)
@@ -874,7 +874,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
storeOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value storeVal = adaptor.getValue();
+ Value storeVal = adaptor.getValueToStore();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
@@ -915,7 +915,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
clearBitsMask =
rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
- Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
+ Value storeVal =
+ shiftValue(loc, adaptor.getValueToStore(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
@@ -1020,7 +1021,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (memrefType.getElementType().isSignlessInteger())
return rewriter.notifyMatchFailure(storeOp, "signless int");
auto storePtr = spirv::getElementPtr(
- *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
+ *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getBase(),
adaptor.getIndices(), storeOp.getLoc(), rewriter);
if (!storePtr)
@@ -1033,7 +1034,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
auto [memoryAccess, alignment] = *memoryRequirements;
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
- storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
+ storeOp, storePtr, adaptor.getValueToStore(), memoryAccess, alignment);
return success();
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 09d4ffa61738a..1030faa212f11 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -284,7 +284,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+ auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getMemRefType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
@@ -298,7 +298,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Special case 0-rank memref loads.
Value bitsLoad;
if (convertedType.getRank() == 0) {
- bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
+ bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getBase(),
ValueRange{});
} else {
// Linearize the indices of the original load instruction. Do not account
@@ -307,7 +307,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
Value newLoad = memref::LoadOp::create(
- rewriter, loc, adaptor.getMemref(),
+ rewriter, loc, adaptor.getBase(),
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
dstBits));
@@ -414,7 +414,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+ auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
int srcBits = op.getMemRefType().getElementTypeBitWidth();
int dstBits = convertedType.getElementTypeBitWidth();
auto dstIntegerType = rewriter.getIntegerType(dstBits);
@@ -426,7 +426,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
Location loc = op.getLoc();
// Pad the input value with 0s on the left.
- Value input = adaptor.getValue();
+ Value input = adaptor.getValueToStore();
if (!input.getType().isInteger()) {
input = arith::BitcastOp::create(
rewriter, loc,
@@ -440,7 +440,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
- extendedInput, adaptor.getMemref(),
+ extendedInput, adaptor.getBase(),
ValueRange{});
rewriter.eraseOp(op);
return success();
@@ -460,10 +460,10 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Clear destination bits
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
- writeMask, adaptor.getMemref(), storeIndices);
+ writeMask, adaptor.getBase(), storeIndices);
// Write srcs bits to destination
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
- alignedVal, adaptor.getMemref(), storeIndices);
+ alignedVal, adaptor.getBase(), storeIndices);
rewriter.eraseOp(op);
return success();
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 6f815ae46904c..79282ecd79d5e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -66,9 +66,9 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
op.getMemRefType()));
- rewriter.replaceOpWithNewOp<memref::LoadOp>(
- op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
- op.getNontemporal());
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, newResTy, adaptor.getBase(),
+ adaptor.getIndices(),
+ op.getNontemporal());
return success();
}
};
@@ -90,7 +90,7 @@ struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
op.getMemRefType()));
rewriter.replaceOpWithNewOp<memref::StoreOp>(
- op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
+ op, adaptor.getValueToStore(), adaptor.getBase(), adaptor.getIndices(),
op.getNontemporal());
return success();
}
diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py
index 8f2142a74c7a1..d3af869889e10 100644
--- a/mlir/test/python/dialects/openacc.py
+++ b/mlir/test/python/dialects/openacc.py
@@ -121,8 +121,8 @@ def testParallelMemcpy():
with InsertionPoint(loop_block):
idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0])
- val = memref.load(memref=copied, indices=[idx])
- memref.store(value=val, memref=created, indices=[idx])
+ val = memref.load(base=copied, indices=[idx])
+ memref.store(value_to_store=val, base=created, indices=[idx])
openacc.YieldOp([])
openacc.YieldOp([])
``````````
</details>
https://github.com/llvm/llvm-project/pull/173185
More information about the Mlir-commits mailing list