[llvm] 02e316c - [DirectX] legalize memset (#136244)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 30 14:09:31 PDT 2025
Author: Farzon Lotfi
Date: 2025-04-30T17:09:28-04:00
New Revision: 02e316cf8c9c73aad580e8c0d1b3b691567601ca
URL: https://github.com/llvm/llvm-project/commit/02e316cf8c9c73aad580e8c0d1b3b691567601ca
DIFF: https://github.com/llvm/llvm-project/commit/02e316cf8c9c73aad580e8c0d1b3b691567601ca.diff
LOG: [DirectX] legalize memset (#136244)
fixes #136243
This change converts memset into a series of GEPs and stores. It is
intentionally limited to memsets of fixed size. It also converts the byte
stores into stores of the array element type, because DXIL does not
support i8 and because this reduces the total number of GEP and store
instructions.
This change also moves DXILFinalizeLinkage to run after legalization so
that it can clean up any intrinsic declarations left dead by legalization.
Added:
llvm/test/CodeGen/DirectX/legalize-memset.ll
Modified:
llvm/lib/Target/DirectX/DXILLegalizePass.cpp
llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
llvm/test/CodeGen/DirectX/llc-pipeline.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 7da5a71ab729b..be77a70fa46ba 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -13,6 +13,7 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
@@ -174,16 +175,22 @@ static void upcastI8AllocasAndUses(Instruction &I,
Type *SmallestType = nullptr;
- // Gather all cast targets
for (User *U : AI->users()) {
auto *Load = dyn_cast<LoadInst>(U);
if (!Load)
continue;
for (User *LU : Load->users()) {
- auto *Cast = dyn_cast<CastInst>(LU);
- if (!Cast)
+ Type *Ty = nullptr;
+ if (auto *Cast = dyn_cast<CastInst>(LU))
+ Ty = Cast->getType();
+ if (CallInst *CI = dyn_cast<CallInst>(LU)) {
+ if (CI->getIntrinsicID() == Intrinsic::memset)
+ Ty = Type::getInt32Ty(CI->getContext());
+ }
+
+ if (!Ty)
continue;
- Type *Ty = Cast->getType();
+
if (!SmallestType ||
Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
SmallestType = Ty;
@@ -239,6 +246,77 @@ downcastI64toI32InsertExtractElements(Instruction &I,
}
}
+/// Expand a fixed-size llvm.memset whose destination is an array alloca into
+/// one GEP and one element-typed store per array element.
+///
+/// \param Builder Used to insert the new instructions (positioned at the
+///        memset being replaced).
+/// \param Dst The memset destination; asserted to be an AllocaInst of an
+///        ArrayType whose allocation size equals \p SizeCI.
+/// \param Val The fill-byte operand of the memset.
+/// \param SizeCI The constant memset length in bytes.
+/// \param ReplacedValues Values already rewritten by earlier legalization
+///        steps; if \p Val has a non-null entry, that replacement is stored.
+///        This function only reads the map, it never inserts into it.
+static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
+                                ConstantInt *SizeCI,
+                                DenseMap<Value *, Value *> &ReplacedValues) {
+  const DataLayout &DL =
+      Builder.GetInsertBlock()->getModule()->getDataLayout();
+  [[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
+
+  AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
+
+  assert(Alloca && "Expected memset on an Alloca");
+  assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
+         "Expected for memset size to match DataLayout size");
+
+  Type *AllocatedTy = Alloca->getAllocatedType();
+  ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
+  assert(ArrTy && "Expected Alloca for an Array Type");
+
+  Type *ElemTy = ArrTy->getElementType();
+  uint64_t NumElts = ArrTy->getArrayNumElements();
+
+  [[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
+
+  assert(ElemSize > 0 && "Size must be set");
+  assert(OrigSize == ElemSize * NumElts && "Size in bytes must match");
+
+  Value *TypedVal = Val;
+  if (Val->getType() != ElemTy) {
+    auto *FillByte = dyn_cast<ConstantInt>(Val);
+    uint64_t ElemBits = DL.getTypeSizeInBits(ElemTy).getFixedValue();
+    // A constant byte can be replicated into any integer/FP element whose
+    // width is a whole number of bytes.
+    bool CanSplat = FillByte &&
+                    (ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy()) &&
+                    ElemBits % 8 == 0;
+
+    if (Value *Replacement = ReplacedValues.lookup(Val)) {
+      // Val is an i8 value that an earlier step already upcast; store the
+      // upcast value. lookup() (unlike operator[]) does not insert a null
+      // entry into the shared map for values that have no replacement.
+      TypedVal = Replacement;
+    } else if (CanSplat) {
+      // memset writes the fill byte into every byte of the destination, so
+      // each element must hold that byte repeated across its width (a 0xFF
+      // fill of an i32 element is 0xFFFFFFFF, i.e. -1, not 255). Build the
+      // repeated bit pattern and reinterpret it as ElemTy; the bitcast of a
+      // constant is folded by the builder, so no instruction is emitted.
+      APInt Splat = APInt::getSplat(ElemBits, FillByte->getValue());
+      TypedVal = Builder.CreateBitCast(Builder.getInt(Splat), ElemTy);
+    } else {
+      // NOTE(review): for a non-constant fill this stores the zero-extended
+      // byte, which matches memset's byte-replication semantics only when
+      // the fill value is zero.
+      TypedVal = Builder.CreateIntCast(Val, ElemTy, /*isSigned=*/false);
+    }
+  }
+
+  // One i32-indexed GEP and store per element of the array.
+  for (uint64_t I = 0; I < NumElts; ++I) {
+    Value *Offset = ConstantInt::get(Builder.getInt32Ty(), I);
+    Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
+    Builder.CreateStore(TypedVal, Ptr);
+  }
+}
+
+/// Legalization step: if \p I is a call to llvm.memset, expand it in place
+/// via emitMemsetExpansion and queue the call in \p ToRemove. Any other
+/// instruction is left untouched.
+///
+/// Only memsets with a constant length are handled; a non-constant length
+/// trips the assertion below.
+static void removeMemSet(Instruction &I,
+                         SmallVectorImpl<Instruction *> &ToRemove,
+                         DenseMap<Value *, Value *> &ReplacedValues) {
+  auto *Call = dyn_cast<CallInst>(&I);
+  if (!Call || Call->getIntrinsicID() != Intrinsic::memset)
+    return;
+
+  // llvm.memset operands: (dest, fill byte, length, isvolatile).
+  auto *Len = dyn_cast<ConstantInt>(Call->getArgOperand(2));
+  assert(Len && "Expected Size to be a ConstantInt");
+
+  IRBuilder<> Builder(Call);
+  emitMemsetExpansion(Builder, Call->getArgOperand(0), Call->getArgOperand(1),
+                      Len, ReplacedValues);
+  ToRemove.push_back(Call);
+}
+
namespace {
class DXILLegalizationPipeline {
@@ -270,6 +348,7 @@ class DXILLegalizationPipeline {
LegalizationPipeline.push_back(fixI8UseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
+ LegalizationPipeline.push_back(removeMemSet);
}
};
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 398abd66dda16..10f4b4ee76619 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -98,7 +98,6 @@ class DirectXPassConfig : public TargetPassConfig {
FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
void addCodeGenPrepare() override {
- addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILIntrinsicExpansionLegacyPass());
addPass(createDXILCBufferAccessLegacyPass());
addPass(createDXILDataScalarizationLegacyPass());
@@ -109,6 +108,7 @@ class DirectXPassConfig : public TargetPassConfig {
addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILForwardHandleAccessesLegacyPass());
addPass(createDXILLegalizeLegacyPass());
+ addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILPrepareModulePass());
diff --git a/llvm/test/CodeGen/DirectX/legalize-memset.ll b/llvm/test/CodeGen/DirectX/legalize-memset.ll
new file mode 100644
index 0000000000000..e97817ba824ed
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-memset.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Zero-fill of an 8-byte [2 x float] alloca: the memset is expected to become
+; one float GEP + store of 0.0 per element.
+define void @replace_float_memset_test() #0 {
+; CHECK-LABEL: define void @replace_float_memset_test(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x float], align 4
+; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP]], align 4
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP1]], align 4
+; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [2 x float], align 4
+ call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %accum.i.flat)
+ call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 8, i1 false)
+ call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %accum.i.flat)
+ ret void
+}
+
+; Zero-fill of a 4-byte [2 x half] alloca: the memset is expected to become
+; one half GEP + store per element. The dereferenceable(4) on the memset
+; destination matches the 4-byte length and allocation size.
+define void @replace_half_memset_test() #0 {
+; CHECK-LABEL: define void @replace_half_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x half], align 4
+; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store half 0xH0000, ptr [[GEP]], align 2
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT: store half 0xH0000, ptr [[GEP1]], align 2
+; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [2 x half], align 4
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
+ call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(4) %accum.i.flat, i8 0, i32 4, i1 false)
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
+ ret void
+}
+
+; Zero-fill of a 16-byte [2 x double] alloca: the memset is expected to become
+; one double GEP + store per element. The dereferenceable(16) on the memset
+; destination matches the 16-byte length and allocation size.
+define void @replace_double_memset_test() #0 {
+; CHECK-LABEL: define void @replace_double_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x double], align 4
+; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP]], align 8
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP1]], align 8
+; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [2 x double], align 4
+ call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %accum.i.flat)
+ call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(16) %accum.i.flat, i8 0, i32 16, i1 false)
+ call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %accum.i.flat)
+ ret void
+}
+
+; Zero-fill of a 4-byte [2 x i16] alloca: the memset is expected to become
+; one i16 GEP + store per element.
+define void @replace_int16_memset_test() #0 {
+; CHECK-LABEL: define void @replace_int16_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT: [[CACHE_I:%.*]] = alloca [2 x i16], align 2
+; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[CACHE_I]])
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 0
+; CHECK-NEXT: store i16 0, ptr [[GEP]], align 2
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 1
+; CHECK-NEXT: store i16 0, ptr [[GEP1]], align 2
+; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[CACHE_I]])
+; CHECK-NEXT: ret void
+;
+ %cache.i = alloca [2 x i16], align 2
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %cache.i)
+ call void @llvm.memset.p0.i32(ptr nonnull align 2 dereferenceable(4) %cache.i, i8 0, i32 4, i1 false)
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %cache.i)
+ ret void
+}
+
+; Zero-fill of a 4-byte [1 x i32] alloca: the memset is expected to become a
+; single i32 GEP + store. The dereferenceable(4) on the memset destination
+; matches the 4-byte length and allocation size.
+define void @replace_int_memset_test() #0 {
+; CHECK-LABEL: define void @replace_int_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store i32 0, ptr [[GEP]], align 4
+; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [1 x i32], align 4
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
+ call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(4) %accum.i.flat, i8 0, i32 4, i1 false)
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
+ ret void
+}
+
+; Fill of a 4-byte [1 x i32] alloca from a runtime i8 value whose i8 alloca is
+; upcast to i32 by the legalizer: the upcast i32 load is stored directly.
+; NOTE(review): llvm.memset semantics replicate the fill byte (0x01 becomes
+; 0x01010101 in an i32 element); this test pins the current behavior, which
+; stores the upcast value (1) instead.
+define void @replace_int_memset_to_var_test() #0 {
+; CHECK-LABEL: define void @replace_int_memset_to_var_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
+; CHECK-NEXT: store i32 1, ptr [[I]], align 4
+; CHECK-NEXT: [[I8_LOAD:%.*]] = load i32, ptr [[I]], align 4
+; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store i32 [[I8_LOAD]], ptr [[GEP]], align 4
+; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [1 x i32], align 4
+ %i = alloca i8, align 4
+ store i8 1, ptr %i
+ %i8.load = load i8, ptr %i
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
+ call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(4) %accum.i.flat, i8 %i8.load, i32 4, i1 false)
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
+ ret void
+}
+
+attributes #0 = {"hlsl.export"}
+
+
+; Declarations of the intrinsics called by the test functions in this file.
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr captures(none))
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr captures(none))
+declare void @llvm.memset.p0.i32(ptr writeonly captures(none), i8, i32, i1 immarg)
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index a2412b6324a05..55dd86c9fad1d 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -13,7 +13,6 @@
; CHECK-OBJ-NEXT: Create Garbage Collector Module Metadata
; CHECK-NEXT: ModulePass Manager
-; CHECK-NEXT: DXIL Finalize Linkage
; CHECK-NEXT: DXIL Intrinsic Expansion
; CHECK-NEXT: DXIL CBuffer Access
; CHECK-NEXT: DXIL Data Scalarization
@@ -24,6 +23,7 @@
; CHECK-NEXT: Scalarize vector operations
; CHECK-NEXT: DXIL Forward Handle Accesses
; CHECK-NEXT: DXIL Legalizer
+; CHECK-NEXT: DXIL Finalize Linkage
; CHECK-NEXT: DXIL Resources Analysis
; CHECK-NEXT: DXIL Module Metadata analysis
; CHECK-NEXT: DXIL Shader Flag Analysis
More information about the llvm-commits
mailing list