[Mlir-commits] [mlir] [mlir][llvm] Add support for memset.inline (PR #115711)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 11 04:51:07 PST 2024
https://github.com/PikachuHyA created https://github.com/llvm/llvm-project/pull/115711
Support `llvm.intr.memset.inline` in the llvm-project repo before we add support for `__builtin_memset_inline` in ClangIR.
cc @bcardosolopes
>From 83950f2ba1fdc8738f626c5b581d56bffb3ad9f9 Mon Sep 17 00:00:00 2001
From: PikachuHy <pikachuhy at linux.alibaba.com>
Date: Mon, 11 Nov 2024 20:44:47 +0800
Subject: [PATCH] [mlir][llvm] Add support for memset.inline
---
.../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 26 ++
mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp | 4 +
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 255 +++++++++++++-----
.../test/Target/LLVMIR/llvmir-intrinsics.mlir | 4 +
4 files changed, 224 insertions(+), 65 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index d07ebbacc60434..85785938405859 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -256,6 +256,32 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
];
}
+def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
+ [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+ DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
+ /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
+ /*immArgAttrNames=*/["len", "isVolatile"]> {
+ dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
+ I8:$val, APIntAttr:$len, I1Attr:$isVolatile);
+ // Append the alias attributes defined by LLVM_IntrOpBase.
+ let arguments = !con(args, aliasAttrs);
+ let builders = [
+ OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
+ "bool":$isVolatile), [{
+ build($_builder, $_state, dst, val, len,
+ $_builder.getBoolAttr(isVolatile));
+ }]>,
+ OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
+ "IntegerAttr":$isVolatile), [{
+ build($_builder, $_state, dst, val, len, isVolatile,
+ /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ }]>
+ ];
+}
+
def LLVM_NoAliasScopeDeclOp
: LLVM_ZeroResultIntrOp<"experimental.noalias.scope.decl"> {
let arguments = (ins LLVM_AliasScopeAttr:$scope);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index cff16afc73af3f..a59900745d026e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -94,6 +94,10 @@ SmallVector<Value> mlir::LLVM::MemsetOp::getAccessedOperands() {
return {getDst()};
}
+SmallVector<Value> mlir::LLVM::MemsetInlineOp::getAccessedOperands() {
+ return {getDst()};
+}
+
SmallVector<Value> mlir::LLVM::CallOp::getAccessedOperands() {
return llvm::to_vector(
llvm::make_filter_range(getArgOperands(), [](Value arg) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 5dc506c14ef96f..e262964c9f348c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -904,6 +904,14 @@ std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
return memIntrLen.getZExtValue();
}
+template <>
+std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) {
+ APInt memIntrLen = op.getLen();
+ if (memIntrLen.getBitWidth() > 64)
+ return {};
+ return memIntrLen.getZExtValue();
+}
+
} // namespace
/// Returns whether one can be sure the memory intrinsic does not write outside
@@ -931,23 +939,37 @@ static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
}
//===----------------------------------------------------------------------===//
-// Interfaces for memset
+// Interfaces for memset and memset.inline
//===----------------------------------------------------------------------===//
-bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+template <class MemsetLike>
+static bool memsetCanRewire(MemsetLike op, const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
+ return false;
-bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
- return getDst() == slot.ptr;
+ if (op.getIsVolatile())
+ return false;
+
+ if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
+ return false;
+
+ if (!areAllIndicesI32(slot))
+ return false;
+
+ return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
}
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
- Value reachingDef,
- const DataLayout &dataLayout) {
+template <class MemsetLike>
+static Value memsetGetStored(MemsetLike op, const MemorySlot &slot,
+ OpBuilder &builder) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
if (intType.getWidth() == 8)
- return getVal();
+ return op.getVal();
assert(intType.getWidth() % 8 == 0);
@@ -955,14 +977,14 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
// or-ing it with the previous value.
uint64_t coveredBits = 8;
Value currentValue =
- builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
+ builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
while (coveredBits < intType.getWidth()) {
- Value shiftBy =
- builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
+ Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
+ coveredBits);
Value shifted =
- builder.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
+ builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
currentValue =
- builder.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
+ builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
coveredBits *= 2;
}
@@ -974,10 +996,12 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
});
}
-bool LLVM::MemsetOp::canUsesBeRemoved(
- const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses,
- const DataLayout &dataLayout) {
+template <class MemsetLike>
+static bool
+memsetCanUsesBeRemoved(MemsetLike op, const MemorySlot &slot,
+ const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
// TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
@@ -988,62 +1012,74 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
if (!canConvertType)
return false;
- if (getIsVolatile())
+ if (op.getIsVolatile())
return false;
- return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
+ return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}
-
-DeletionKind LLVM::MemsetOp::removeBlockingUses(
- const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- OpBuilder &builder, Value reachingDefinition,
- const DataLayout &dataLayout) {
- return DeletionKind::Delete;
+namespace {
+template <class MemsetLike>
+void createMemsetLenAttr(MemsetLike op, IntegerAttr &memsetLenAttr) {
+ bool successfulMatch =
+ matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
+ (void)successfulMatch;
+ assert(successfulMatch);
}
-
-LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
- const DataLayout &dataLayout) {
- return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+template <>
+void createMemsetLenAttr(LLVM::MemsetInlineOp op, IntegerAttr &memsetLenAttr) {
+ memsetLenAttr = op.getLenAttr();
+}
+template <class MemsetLike>
+void createMemsetLikeToReplace(OpBuilder &builder, MemsetLike toReplace,
+ IntegerAttr memsetLenAttr,
+ uint64_t newMemsetSize,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ Attribute index) {
+
+ Value newMemsetSizeValue =
+ builder
+ .create<LLVM::ConstantOp>(
+ toReplace.getLen().getLoc(),
+ IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
+ .getResult();
+
+ builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr,
+ toReplace.getVal(), newMemsetSizeValue,
+ toReplace.getIsVolatile());
}
+template <>
+void createMemsetLikeToReplace(OpBuilder &builder,
+ LLVM::MemsetInlineOp toReplace,
+ IntegerAttr memsetLenAttr,
+ uint64_t newMemsetSize,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ Attribute index) {
-bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
- SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
- const DataLayout &dataLayout) {
- if (&slot.elemType.getDialect() != getOperation()->getDialect())
- return false;
-
- if (getIsVolatile())
- return false;
-
- if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
- return false;
-
- if (!areAllIndicesI32(slot))
- return false;
+ auto newMemsetSizeValue =
+ IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize);
- return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
+ builder.create<LLVM::MemsetInlineOp>(
+ toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(),
+ newMemsetSizeValue, toReplace.getIsVolatile());
}
+} // namespace
+template <class MemsetLike>
+static DeletionKind
+memsetRewire(MemsetLike op, const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
+ const DataLayout &dataLayout) {
-DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
- DenseMap<Attribute, MemorySlot> &subslots,
- OpBuilder &builder,
- const DataLayout &dataLayout) {
std::optional<DenseMap<Attribute, Type>> types =
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
IntegerAttr memsetLenAttr;
- bool successfulMatch =
- matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
- (void)successfulMatch;
- assert(successfulMatch);
+ createMemsetLenAttr(op, memsetLenAttr);
bool packed = false;
if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
packed = structType.isPacked();
- Type i32 = IntegerType::get(getContext(), 32);
+ Type i32 = IntegerType::get(op.getContext(), 32);
uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
uint64_t covered = 0;
for (size_t i = 0; i < types->size(); i++) {
@@ -1063,16 +1099,8 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
// Otherwise, only compute its offset within the original memset.
if (subslots.contains(index)) {
uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
-
- Value newMemsetSizeValue =
- builder
- .create<LLVM::ConstantOp>(
- getLen().getLoc(),
- IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
- .getResult();
-
- builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr, getVal(),
- newMemsetSizeValue, getIsVolatile());
+ createMemsetLikeToReplace(builder, op, memsetLenAttr, newMemsetSize,
+ subslots, index);
}
covered += typeSize;
@@ -1081,6 +1109,103 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
return DeletionKind::Delete;
}
+bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
+ return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
+ Value reachingDef,
+ const DataLayout &dataLayout) {
+ return memsetGetStored(*this, slot, builder);
+}
+
+bool LLVM::MemsetOp::canUsesBeRemoved(
+ const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
+ return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+ dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::removeBlockingUses(
+ const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ OpBuilder &builder, Value reachingDefinition,
+ const DataLayout &dataLayout) {
+ return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+ dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ OpBuilder &builder,
+ const DataLayout &dataLayout) {
+ return memsetRewire(*this, slot, subslots, builder, dataLayout);
+}
+
+bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
+ return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
+ OpBuilder &builder, Value reachingDef,
+ const DataLayout &dataLayout) {
+ return memsetGetStored(*this, slot, builder);
+}
+
+bool LLVM::MemsetInlineOp::canUsesBeRemoved(
+ const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
+ return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+ dataLayout);
+}
+
+DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
+ const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ OpBuilder &builder, Value reachingDefinition,
+ const DataLayout &dataLayout) {
+ return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetInlineOp::canRewire(
+ const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+ dataLayout);
+}
+
+DeletionKind
+LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ OpBuilder &builder, const DataLayout &dataLayout) {
+ return memsetRewire(*this, slot, subslots, builder, dataLayout);
+}
+
//===----------------------------------------------------------------------===//
// Interfaces for memcpy/memmove
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index cb712eb4e1262d..9d45f219cf746e 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -533,6 +533,10 @@ llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) {
%i1 = llvm.mlir.constant(false) : i1
// CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
"llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true
+ "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
+ // CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true
+ "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
llvm.return
}
More information about the Mlir-commits
mailing list