[llvm] d3d35ad - [DirectX] Legalize i8 allocas (#137399)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 29 13:07:48 PDT 2025
Author: Farzon Lotfi
Date: 2025-04-29T16:07:44-04:00
New Revision: d3d35adcd32c91e9076be6bb242dd6c82c490c4b
URL: https://github.com/llvm/llvm-project/commit/d3d35adcd32c91e9076be6bb242dd6c82c490c4b
DIFF: https://github.com/llvm/llvm-project/commit/d3d35adcd32c91e9076be6bb242dd6c82c490c4b.diff
LOG: [DirectX] Legalize i8 allocas (#137399)
fixes #137202
investingating i8 allocas I came to find some missing instructions from
out i8 legalization around load, store, and select.
Added those three.
To do i8 allocas right though we needed to walk the uses and find the
casts.
After finding the casts I chose to pick the smallest cast as the cast to
transform to. That would then let me preserve the larger casts that come
later
Added:
llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
Modified:
llvm/lib/Target/DirectX/DXILLegalizePass.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index b62ff4c52f70c..7da5a71ab729b 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -12,6 +12,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
ToRemove.push_back(FI);
}
-static void fixI8TruncUseChain(Instruction &I,
- SmallVectorImpl<Instruction *> &ToRemove,
- DenseMap<Value *, Value *> &ReplacedValues) {
+static void fixI8UseChain(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) {
auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
Type *InstrType = IntegerType::get(I.getContext(), 32);
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
- if (ReplacedValues.count(Op))
+ if (ReplacedValues.count(Op) &&
+ ReplacedValues[Op]->getType()->isIntegerTy())
InstrType = ReplacedValues[Op]->getType();
}
@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
}
}
+ if (auto *Store = dyn_cast<StoreInst>(&I)) {
+ if (!Store->getValueOperand()->getType()->isIntegerTy(8))
+ return;
+ SmallVector<Value *> NewOperands;
+ ProcessOperands(NewOperands);
+ Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
+ ReplacedValues[Store] = NewStore;
+ ToRemove.push_back(Store);
+ return;
+ }
+
+ if (auto *Load = dyn_cast<LoadInst>(&I)) {
+ if (!I.getType()->isIntegerTy(8))
+ return;
+ SmallVector<Value *> NewOperands;
+ ProcessOperands(NewOperands);
+ Type *ElementType = NewOperands[0]->getType();
+ if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
+ ElementType = AI->getAllocatedType();
+ LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
+ ReplacedValues[Load] = NewLoad;
+ ToRemove.push_back(Load);
+ return;
+ }
+
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
if (!I.getType()->isIntegerTy(8))
return;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
Value *NewInst =
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
- if (OBO->hasNoSignedWrap())
- cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
- if (OBO->hasNoUnsignedWrap())
- cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
+ auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
+ if (NewBO && OBO->hasNoSignedWrap())
+ NewBO->setHasNoSignedWrap();
+ if (NewBO && OBO->hasNoUnsignedWrap())
+ NewBO->setHasNoUnsignedWrap();
}
ReplacedValues[BO] = NewInst;
ToRemove.push_back(BO);
return;
}
+ if (auto *Sel = dyn_cast<SelectInst>(&I)) {
+ if (!I.getType()->isIntegerTy(8))
+ return;
+ SmallVector<Value *> NewOperands;
+ ProcessOperands(NewOperands);
+ Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
+ NewOperands[2]);
+ ReplacedValues[Sel] = NewInst;
+ ToRemove.push_back(Sel);
+ return;
+ }
+
if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
return;
@@ -105,13 +145,61 @@ static void fixI8TruncUseChain(Instruction &I,
}
if (auto *Cast = dyn_cast<CastInst>(&I)) {
- if (Cast->getSrcTy()->isIntegerTy(8)) {
- ToRemove.push_back(Cast);
- Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
+ if (!Cast->getSrcTy()->isIntegerTy(8))
+ return;
+
+ ToRemove.push_back(Cast);
+ auto *Replacement = ReplacedValues[Cast->getOperand(0)];
+ if (Cast->getType() == Replacement->getType()) {
+ Cast->replaceAllUsesWith(Replacement);
+ return;
}
+ Value *AdjustedCast = nullptr;
+ if (Cast->getOpcode() == Instruction::ZExt)
+ AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
+ if (Cast->getOpcode() == Instruction::SExt)
+ AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
+
+ if (AdjustedCast)
+ Cast->replaceAllUsesWith(AdjustedCast);
}
}
+static void upcastI8AllocasAndUses(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) {
+ auto *AI = dyn_cast<AllocaInst>(&I);
+ if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
+ return;
+
+ Type *SmallestType = nullptr;
+
+ // Gather all cast targets
+ for (User *U : AI->users()) {
+ auto *Load = dyn_cast<LoadInst>(U);
+ if (!Load)
+ continue;
+ for (User *LU : Load->users()) {
+ auto *Cast = dyn_cast<CastInst>(LU);
+ if (!Cast)
+ continue;
+ Type *Ty = Cast->getType();
+ if (!SmallestType ||
+ Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
+ SmallestType = Ty;
+ }
+ }
+
+ if (!SmallestType)
+ return; // no valid casts found
+
+ // Replace alloca
+ IRBuilder<> Builder(AI);
+ auto *NewAlloca = Builder.CreateAlloca(SmallestType);
+ ReplacedValues[AI] = NewAlloca;
+ ToRemove.push_back(AI);
+}
+
static void
downcastI64toI32InsertExtractElements(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +266,8 @@ class DXILLegalizationPipeline {
LegalizationPipeline;
void initializeLegalizationPipeline() {
- LegalizationPipeline.push_back(fixI8TruncUseChain);
+ LegalizationPipeline.push_back(upcastI8AllocasAndUses);
+ LegalizationPipeline.push_back(fixI8UseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
}
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
new file mode 100644
index 0000000000000..529a69fca5d34
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @const_i8_store() {
+; CHECK-LABEL: define void @const_i8_store() {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT: store i32 1, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [1 x i32], align 4
+ %i = alloca i8, align 4
+ store i8 1, ptr %i
+ %i8.load = load i8, ptr %i
+ %z = zext i8 %i8.load to i32
+ %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+ store i32 %z, ptr %gep, align 4
+ ret void
+}
+
+define void @const_add_i8_store() {
+; CHECK-LABEL: define void @const_add_i8_store() {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT: store i32 4, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [1 x i32], align 4
+ %i = alloca i8, align 4
+ %add_i8 = add nsw i8 3, 1
+ store i8 %add_i8, ptr %i
+ %i8.load = load i8, ptr %i
+ %z = zext i8 %i8.load to i32
+ %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+ store i32 %z, ptr %gep, align 4
+ ret void
+}
+
+define void @var_i8_store(i1 %cmp.i8) {
+; CHECK-LABEL: define void @var_i8_store(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP]], align 4
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [1 x i32], align 4
+ %i = alloca i8, align 4
+ %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+ store i8 %select.i8, ptr %i
+ %i8.load = load i8, ptr %i
+ %z = zext i8 %i8.load to i32
+ %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+ store i32 %z, ptr %gep, align 4
+ ret void
+}
+
+define void @conflicting_cast(i1 %cmp.i8) {
+; CHECK-LABEL: define void @conflicting_cast(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = alloca i16, align 2
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
+; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP2]], align 2
+; CHECK-NEXT: [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
+; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
+; CHECK-NEXT: ret void
+;
+ %accum.i.flat = alloca [2 x i32], align 4
+ %i = alloca i8, align 4
+ %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+ store i8 %select.i8, ptr %i
+ %i8.load = load i8, ptr %i
+ %z = zext i8 %i8.load to i16
+ %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
+ store i16 %z, ptr %gep1, align 2
+ %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
+ store i16 %z, ptr %gep2, align 2
+ %z2 = zext i8 %i8.load to i32
+ %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
+ store i32 %z2, ptr %gep3, align 4
+ ret void
+}
More information about the llvm-commits
mailing list