[Mlir-commits] [mlir] [mlir][memref] Rename `memref.load`/`store` members to align with vector dialect (PR #173185)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 22 04:08:20 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

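Rename the `memref`/`value` operands of `memref.load`/`memref.store` to `base`/`valueToStore`, matching `vector.load`/`vector.store`. Backward-compatibility accessors (`getMemref()`/`getMemrefMutable()` on both ops, and `getValue()` on `memref.store`) are added via `extraClassDeclaration`. Unfortunately, `extraClassDeclaration` cannot be applied to op adaptors, so this is still a breaking change: adaptor accessors (e.g. `adaptor.getMemref()` becomes `adaptor.getBase()`) and the Python builder keyword arguments (`memref=`/`value=` become `base=`/`value_to_store=`) are renamed.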

Discussion https://discourse.llvm.org/t/rfc-aligning-member-names-between-vector-and-memref-load-store-ops/89175

---
Full diff: https://github.com/llvm/llvm-project/pull/173185.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+15-10) 
- (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+3-3) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+7-7) 
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+9-8) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+8-8) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp (+4-4) 
- (modified) mlir/test/python/dialects/openacc.py (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..cf5c93fc136d9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1233,8 +1233,8 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
 //===----------------------------------------------------------------------===//
 
 def LoadOp : MemRef_Op<"load",
-     [TypesMatchWith<"result type matches element type of 'memref'",
-                     "memref", "result",
+     [TypesMatchWith<"result type matches element type of 'base'",
+                     "base", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
       DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
@@ -1273,7 +1273,7 @@ def LoadOp : MemRef_Op<"load",
   }];
 
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
-                           [MemRead]>:$memref,
+                           [MemRead]>:$base,
                        Variadic<Index>:$indices,
                        DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
                        OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
@@ -1310,6 +1310,8 @@ def LoadOp : MemRef_Op<"load",
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
+    ::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); }
+    ::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); }
     Value getMemRef() { return getOperand(0); }
     void setMemRef(Value value) { setOperand(0, value); }
     MemRefType getMemRefType() {
@@ -1320,7 +1322,7 @@ def LoadOp : MemRef_Op<"load",
   let hasFolder = 1;
   let hasVerifier = 1;
 
-  let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
+  let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -2010,8 +2012,8 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
 //===----------------------------------------------------------------------===//
 
 def MemRef_StoreOp : MemRef_Op<"store",
-     [TypesMatchWith<"type of 'value' matches element type of 'memref'",
-                     "memref", "value",
+     [TypesMatchWith<"type of 'valueToStore' matches element type of 'base'",
+                     "base", "valueToStore",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
       DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
@@ -2046,9 +2048,9 @@ def MemRef_StoreOp : MemRef_Op<"store",
     ```
   }];
 
-  let arguments = (ins AnyType:$value,
+  let arguments = (ins AnyType:$valueToStore,
                        Arg<AnyMemRef, "the reference to store to",
-                           [MemWrite]>:$memref,
+                           [MemWrite]>:$base,
                        Variadic<Index>:$indices,
                        DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
                        OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
@@ -2070,7 +2072,10 @@ def MemRef_StoreOp : MemRef_Op<"store",
   ];
 
   let extraClassDeclaration = [{
-      Value getValueToStore() { return getOperand(0); }
+      ::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); }
+      ::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); }
+
+      Value getValue() { return getOperand(0); }
 
       Value getMemRef() { return getOperand(1); }
       void setMemRef(Value value) { setOperand(1, value); }
@@ -2083,7 +2088,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let hasVerifier = 1;
 
   let assemblyFormat = [{
-    $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
+    $valueToStore `,` $base `[` $indices `]` attr-dict `:` type($base)
   }];
 }
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0a382d812f362..caefe0bde3cff 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -342,7 +342,7 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
     }
 
     auto arrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getBase());
     if (!arrayValue) {
       return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
     }
@@ -362,7 +362,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
   matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto arrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getBase());
     if (!arrayValue) {
       return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
     }
@@ -370,7 +370,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
     auto subscript = emitc::SubscriptOp::create(
         rewriter, op.getLoc(), arrayValue, operands.getIndices());
     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
-                                                 operands.getValue());
+                                                 operands.getValueToStore());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..af1e25add0167 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -941,9 +941,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
     // Per memref.load spec, the indices must be in-bounds:
     // 0 <= idx < dim_size, and additionally all offsets are non-negative,
     // hence inbounds and nuw are used when lowering to llvm.getelementptr.
-    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
-                                         adaptor.getMemref(),
-                                         adaptor.getIndices(), kNoWrapFlags);
+    Value dataPtr =
+        getStridedElementPtr(rewriter, loadOp.getLoc(), type, adaptor.getBase(),
+                             adaptor.getIndices(), kNoWrapFlags);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
         loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
         loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
@@ -965,11 +965,11 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
     // 0 <= idx < dim_size, and additionally all offsets are non-negative,
     // hence inbounds and nuw are used when lowering to llvm.getelementptr.
     Value dataPtr =
-        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
+        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getBase(),
                              adaptor.getIndices(), kNoWrapFlags);
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
-                                               op.getAlignment().value_or(0),
-                                               false, op.getNontemporal());
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+        op, adaptor.getValueToStore(), dataPtr, op.getAlignment().value_or(0),
+        false, op.getNontemporal());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index a90dcc8cc3ef1..03e896ce5f9b6 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -555,7 +555,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
 
   const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   Value accessChain =
-      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
                            adaptor.getIndices(), loc, rewriter);
 
   if (!accessChain)
@@ -682,7 +682,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
         "failed to lower memref in image storage class to storage buffer");
 
   Value loadPtr = spirv::getElementPtr(
-      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
+      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getBase(),
       adaptor.getIndices(), loadOp.getLoc(), rewriter);
 
   if (!loadPtr)
@@ -743,7 +743,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
     return rewriter.notifyMatchFailure(
         loadOp, "failed to lower memref in non-image storage class to image");
 
-  Value loadPtr = adaptor.getMemref();
+  Value loadPtr = adaptor.getBase();
   auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
   if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
@@ -824,7 +824,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   auto loc = storeOp.getLoc();
   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   Value accessChain =
-      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
                            adaptor.getIndices(), loc, rewriter);
 
   if (!accessChain)
@@ -874,7 +874,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
           storeOp, "failed to determine memory requirements");
 
     auto [memoryAccess, alignment] = *memoryRequirements;
-    Value storeVal = adaptor.getValue();
+    Value storeVal = adaptor.getValueToStore();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
@@ -915,7 +915,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   clearBitsMask =
       rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
 
-  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
+  Value storeVal =
+      shiftValue(loc, adaptor.getValueToStore(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
   std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
@@ -1020,7 +1021,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   if (memrefType.getElementType().isSignlessInteger())
     return rewriter.notifyMatchFailure(storeOp, "signless int");
   auto storePtr = spirv::getElementPtr(
-      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
+      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getBase(),
       adaptor.getIndices(), storeOp.getLoc(), rewriter);
 
   if (!storePtr)
@@ -1033,7 +1034,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 
   auto [memoryAccess, alignment] = *memoryRequirements;
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(
-      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
+      storeOp, storePtr, adaptor.getValueToStore(), memoryAccess, alignment);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 09d4ffa61738a..1030faa212f11 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -284,7 +284,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     auto convertedElementType = convertedType.getElementType();
     auto oldElementType = op.getMemRefType().getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
@@ -298,7 +298,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     // Special case 0-rank memref loads.
     Value bitsLoad;
     if (convertedType.getRank() == 0) {
-      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
+      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getBase(),
                                         ValueRange{});
     } else {
       // Linearize the indices of the original load instruction. Do not account
@@ -307,7 +307,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
           rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
 
       Value newLoad = memref::LoadOp::create(
-          rewriter, loc, adaptor.getMemref(),
+          rewriter, loc, adaptor.getBase(),
           getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                    dstBits));
 
@@ -414,7 +414,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
   LogicalResult
   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     int srcBits = op.getMemRefType().getElementTypeBitWidth();
     int dstBits = convertedType.getElementTypeBitWidth();
     auto dstIntegerType = rewriter.getIntegerType(dstBits);
@@ -426,7 +426,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     Location loc = op.getLoc();
 
     // Pad the input value with 0s on the left.
-    Value input = adaptor.getValue();
+    Value input = adaptor.getValueToStore();
     if (!input.getType().isInteger()) {
       input = arith::BitcastOp::create(
           rewriter, loc,
@@ -440,7 +440,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     // Special case 0-rank memref stores. No need for masking.
     if (convertedType.getRank() == 0) {
       memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
-                                  extendedInput, adaptor.getMemref(),
+                                  extendedInput, adaptor.getBase(),
                                   ValueRange{});
       rewriter.eraseOp(op);
       return success();
@@ -460,10 +460,10 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 
     // Clear destination bits
     memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
-                                writeMask, adaptor.getMemref(), storeIndices);
+                                writeMask, adaptor.getBase(), storeIndices);
     // Write srcs bits to destination
     memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
-                                alignedVal, adaptor.getMemref(), storeIndices);
+                                alignedVal, adaptor.getBase(), storeIndices);
     rewriter.eraseOp(op);
     return success();
   }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 6f815ae46904c..79282ecd79d5e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -66,9 +66,9 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                       op.getMemRefType()));
 
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(
-        op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
-        op.getNontemporal());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, newResTy, adaptor.getBase(),
+                                                adaptor.getIndices(),
+                                                op.getNontemporal());
     return success();
   }
 };
@@ -90,7 +90,7 @@ struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
                                       op.getMemRefType()));
 
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
+        op, adaptor.getValueToStore(), adaptor.getBase(), adaptor.getIndices(),
         op.getNontemporal());
     return success();
   }
diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py
index 8f2142a74c7a1..d3af869889e10 100644
--- a/mlir/test/python/dialects/openacc.py
+++ b/mlir/test/python/dialects/openacc.py
@@ -121,8 +121,8 @@ def testParallelMemcpy():
 
             with InsertionPoint(loop_block):
                 idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0])
-                val = memref.load(memref=copied, indices=[idx])
-                memref.store(value=val, memref=created, indices=[idx])
+                val = memref.load(base=copied, indices=[idx])
+                memref.store(value_to_store=val, base=created, indices=[idx])
                 openacc.YieldOp([])
 
             openacc.YieldOp([])

``````````

</details>


https://github.com/llvm/llvm-project/pull/173185


More information about the Mlir-commits mailing list