[Mlir-commits] [mlir] 04ad8d4 - Emit inbounds and nuw attributes in memref. (#138984)
llvmlistbot at llvm.org
Tue May 20 14:16:26 PDT 2025
Author: Peiyong Lin
Date: 2025-05-20T14:16:22-07:00
New Revision: 04ad8d4900fb4534a2120335e26d1a1a310ef256
URL: https://github.com/llvm/llvm-project/commit/04ad8d4900fb4534a2120335e26d1a1a310ef256
DIFF: https://github.com/llvm/llvm-project/commit/04ad8d4900fb4534a2120335e26d1a1a310ef256.diff
LOG: Emit inbounds and nuw attributes in memref. (#138984)
Now that MLIR accepts the nuw and nusw flags on getelementptr, this patch
emits the inbounds and nuw attributes when lowering the memref load and
store operations to LLVM.
This patch also strengthens the memref.load and memref.store specs to spell
out the undefined behaviour that such a lowering may introduce for
out-of-bounds or overflowing accesses.
This patch also moves the |rewriter| parameter of getStridedElementPtr to
the front so that an LLVM::GEPNoWrapFlags parameter with a default value
can be added at the end, grouped with the other GEP-related arguments.
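As an illustration, a minimal sketch of the intended lowering, mirroring the
CHECK lines in the updated tests (SSA names are illustrative; the flags are
emitted by the memref-to-LLVM conversion, e.g. via --finalize-memref-to-llvm):

  %0 = memref.load %static[%i, %j] : memref<10x42xf32>

  // The address computation now carries both no-wrap flags:
  %addr = llvm.getelementptr inbounds|nuw %ptr[%off] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %0 = llvm.load %addr : !llvm.ptr -> f32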
Signed-off-by: Lin, Peiyong <linpyong at gmail.com>
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index ddbac85aa34fd..2bf9a021f48e1 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -109,9 +109,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
- Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
- ValueRange indices,
- ConversionPatternRewriter &rewriter) const;
+ Value getStridedElementPtr(
+ ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
+ Value memRefDesc, ValueRange indices,
+ LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
/// Returns if the given memref type is convertible to LLVM and has an
/// identity layout map.
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 54ac899f96f06..f33ecb28d27cd 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1202,7 +1202,12 @@ def LoadOp : MemRef_Op<"load",
The `load` op reads an element from a memref at the specified indices.
The number of indices must match the rank of the memref. The indices must
- be in-bounds: `0 <= idx < dim_size`
+ be in-bounds: `0 <= idx < dim_size`.
+
+ Lowerings of `memref.load` may emit attributes, e.g. `inbounds` + `nuw`
+ when converting to LLVM's `llvm.getelementptr`, that would cause undefined
+ behavior if indices are out of bounds or if computing the offset in the
+ memref would cause signed overflow of the `index` type.
The single result of `memref.load` is a value with the same type as the
element type of the memref.
@@ -1896,7 +1901,12 @@ def MemRef_StoreOp : MemRef_Op<"store",
The `store` op stores an element into a memref at the specified indices.
The number of indices must match the rank of the memref. The indices must
- be in-bounds: `0 <= idx < dim_size`
+ be in-bounds: `0 <= idx < dim_size`.
+
+ Lowerings of `memref.store` may emit attributes, e.g. `inbounds` + `nuw`
+ when converting to LLVM's `llvm.getelementptr`, that would cause undefined
+ behavior if indices are out of bounds or if computing the offset in the
+ memref would cause signed overflow of the `index` type.
A set `nontemporal` attribute indicates that this store is not expected to
be reused in the cache. For details, refer to the
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d0093b8dc8c2a..0694cf27faff4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1117,10 +1117,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
return op.emitOpError("chipset unsupported element size");
- Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
- (adaptor.getSrcIndices()), rewriter);
- Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
- (adaptor.getDstIndices()), rewriter);
+ Value srcPtr =
+ getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()));
+ Value dstPtr =
+ getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
+ (adaptor.getDstIndices()));
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 417555792b44f..0c3f942b5cbd9 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
- loc, llvm::cast<MemRefType>(tileMemory.getType()),
- descriptor.getResult(0), {sliceIndexI64, zero},
- static_cast<ConversionPatternRewriter &>(rewriter));
+ static_cast<ConversionPatternRewriter &>(rewriter), loc,
+ llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
+ {sliceIndexI64, zero});
}
/// Emits an in-place swap of a slice of a tile in ZA and a slice of a
@@ -507,9 +507,9 @@ struct LoadTileSliceConversion
if (!tileId)
return failure();
- Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
- adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = this->getStridedElementPtr(
+ rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
+ adaptor.getIndices());
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
@@ -554,8 +554,8 @@ struct StoreTileSliceConversion
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
Value ptr = this->getStridedElementPtr(
- loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
+ adaptor.getIndices());
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..45fd933d58857 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering
// Create nvvm.mma_load op according to the operand types.
Value dataPtr = getStridedElementPtr(
- loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
- adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
+ rewriter, loc,
+ cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
+ adaptor.getSrcMemref(), adaptor.getIndices());
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
@@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
}
Value dataPtr = getStridedElementPtr(
- loc,
+ rewriter, loc,
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
- adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
+ adaptor.getDstMemref(), adaptor.getIndices());
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index b975d6f7a6a3c..8da850678878d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
}
Value ConvertToLLVMPattern::getStridedElementPtr(
- Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
- ConversionPatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
+ Value memRefDesc, ValueRange indices,
+ LLVM::GEPNoWrapFlags noWrapFlags) const {
auto [strides, offset] = type.getStridesAndOffset();
@@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
return index ? rewriter.create<LLVM::GEPOp>(
loc, elementPtrType,
getTypeConverter()->convertType(type.getElementType()),
- base, index)
+ base, index, noWrapFlags)
: base;
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 7f45904fab7e1..ade4e4d3de8ec 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -35,6 +35,9 @@ namespace mlir {
using namespace mlir;
+static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
+ LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
+
namespace {
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
@@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
auto loc = op.getLoc();
auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
- Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
- rewriter);
+ Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
+ /*indices=*/{});
// Emit llvm.assume(true) ["align"(memref, alignment)].
// This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -643,8 +646,8 @@ struct GenericAtomicRMWOpLowering
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
- auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ auto dataPtr = getStridedElementPtr(
+ rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
Value init = rewriter.create<LLVM::LoadOp>(
loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -828,9 +831,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
auto type = loadOp.getMemRefType();
- Value dataPtr =
- getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ // Per memref.load spec, the indices must be in-bounds:
+ // 0 <= idx < dim_size, and additionally all offsets are non-negative,
+ // hence inbounds and nuw are used when lowering to llvm.getelementptr.
+ Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
+ adaptor.getMemref(),
+ adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
false, loadOp.getNontemporal());
@@ -848,8 +854,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
- Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ // Per memref.store spec, the indices must be in-bounds:
+ // 0 <= idx < dim_size, and additionally all offsets are non-negative,
+ // hence inbounds and nuw are used when lowering to llvm.getelementptr.
+ Value dataPtr =
+ getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
+ adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
0, false, op.getNontemporal());
return success();
@@ -867,8 +877,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
auto type = prefetchOp.getMemRefType();
auto loc = prefetchOp.getLoc();
- Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ Value dataPtr = getStridedElementPtr(
+ rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
// Replace with llvm.prefetch.
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -1808,8 +1818,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return failure();
auto dataPtr =
- getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
+ adaptor.getMemref(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
LLVM::AtomicOrdering::acq_rel);
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 69fa62c8196e4..eb3558d2460e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
- getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
- adaptor.getIndices(), rewriter);
+ getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
+ adaptor.getSrcMemref(), adaptor.getIndices());
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
Location loc = op.getLoc();
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
Value dstPtr =
- getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
- adaptor.getDstIndices(), rewriter);
+ getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
+ adaptor.getDst(), adaptor.getDstIndices());
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
if (failed(dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
return rewriter.notifyMatchFailure(
loc, "source memref address space not convertible to integer");
- Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
- adaptor.getSrcIndices(), rewriter);
+ Value scrPtr =
+ getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
+ adaptor.getSrcIndices());
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
MemRefType mbarrierMemrefType =
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
return ConvertToLLVMPattern::getStridedElementPtr(
- b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
+ rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
}
};
@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
- Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
- adaptor.getDst(), {}, rewriter);
+ Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
+ adaptor.getDst(), {});
Value barrier =
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
- Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
- adaptor.getSrc(), {}, rewriter);
+ Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
+ adaptor.getSrc(), {});
SmallVector<Value> coords = adaptor.getCoordinates();
for (auto [index, value] : llvm::enumerate(coords)) {
coords[index] = truncToI32(b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
Value leadDim = makeConst(leadDimVal);
Value baseAddr = getStridedElementPtr(
- op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
- adaptor.getTensor(), {}, rewriter);
+ rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
+ adaptor.getTensor(), {});
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 400003d37bf20..f725993635672 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve address.
auto vtype = cast<VectorType>(
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
- Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value dataPtr = this->getStridedElementPtr(
+ rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
rewriter);
return success();
@@ -337,8 +337,8 @@ class VectorGatherOpConversion
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
// Resolve address.
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+ adaptor.getBase(), adaptor.getIndices());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ class VectorScatterOpConversion
"could not resolve alignment");
// Resolve address.
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+ adaptor.getBase(), adaptor.getIndices());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+ adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+ adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 4cb777b03b196..2168409184549 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
if (failed(stride))
return failure();
// Replace operation with intrinsic.
- Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
+ adaptor.getBase(), adaptor.getIndices());
Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride.value());
@@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
if (failed(stride))
return failure();
// Replace operation with intrinsic.
- Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
+ adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
return success();
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 058b69b8e3596..3b52d8fd76464 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -266,7 +266,7 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 :
// CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
// CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
- // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
+ // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR]]
// CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
@@ -295,12 +295,12 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind
// CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
// CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
- // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
+ // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR]]
// CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
// CHECK: %[[ALIGNEDPTR0:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
- // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR0]]
+ // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR0]]
// CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
%0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index be3ddc20c17b7..9ca8bcd1491bc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -177,7 +177,7 @@ func.func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
%0 = memref.load %mixed[%i, %j] : memref<42x?xf32>
return
@@ -194,7 +194,7 @@ func.func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
%0 = memref.load %dynamic[%i, %j] : memref<?x?xf32>
return
@@ -232,7 +232,7 @@ func.func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %va
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
memref.store %val, %dynamic[%i, %j] : memref<?x?xf32>
return
@@ -249,7 +249,7 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val :
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
memref.store %val, %mixed[%i, %j] : memref<42x?xf32>
return
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 0a92c7cf7b216..b03ac2c20112b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -140,7 +140,7 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
-// CHECK: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.load %[[addr]] : !llvm.ptr -> f32
%0 = memref.load %static[%i, %j] : memref<10x42xf32>
return
@@ -168,7 +168,7 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
// CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
-// CHECK: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
memref.store %val, %static[%i, %j] : memref<10x42xf32>
@@ -307,7 +307,7 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
// CHECK: %[[three_hundred_and_eighty_four:.*]] = llvm.mlir.constant(384 : index) : i64
// CHECK: %[[one1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
+ // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr inbounds|nuw %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
// CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0]] : !llvm.ptr -> i64
// CHECK: %[[insert7:.*]] = llvm.insertvalue %[[shape_load0]], %[[insert6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert8:.*]] = llvm.insertvalue %[[three_hundred_and_eighty_four]], %[[insert7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
@@ -315,7 +315,7 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
// CHECK: %[[mul:.*]] = llvm.mul %19, %23 : i64
// CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[shape_ptr1:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
+ // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr inbounds|nuw %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
// CHECK: %[[shape_load1:.*]] = llvm.load %[[shape_gep1]] : !llvm.ptr -> i64
// CHECK: %[[insert9:.*]] = llvm.insertvalue %[[shape_load1]], %[[insert8]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert10:.*]] = llvm.insertvalue %[[mul]], %[[insert9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
@@ -347,7 +347,7 @@ func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>)
// CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
+ // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr inbounds|nuw %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
// CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr -> i64
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 8dd7edf3e29b1..68c3e9f5e26ec 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -676,7 +676,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i64)] : i1
-// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr inbounds|nuw %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
// CHECK: return %[[VAL]] : f32
func.func @load_and_assume(
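For downstream conversion patterns, a minimal sketch of a call under the new
parameter order (the pattern name is hypothetical; the trailing flags argument
defaults to LLVM::GEPNoWrapFlags::none, so existing callers only need to
reorder their arguments):

  // Assumes mlir/Conversion/LLVMCommon/Pattern.h and the usual MemRef/LLVM
  // dialect headers are included.
  struct MyLoadLowering : public ConvertOpToLLVMPattern<memref::LoadOp> {
    using ConvertOpToLLVMPattern<memref::LoadOp>::ConvertOpToLLVMPattern;

    LogicalResult
    matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter) const override {
      MemRefType type = op.getMemRefType();
      // Rewriter first; optional no-wrap flags last. Passing inbounds | nuw
      // is only valid because memref.load guarantees in-bounds indices and
      // non-negative offsets.
      Value ptr = getStridedElementPtr(
          rewriter, op.getLoc(), type, adaptor.getMemref(),
          adaptor.getIndices(),
          LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
          op, typeConverter->convertType(type.getElementType()), ptr);
      return success();
    }
  };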