[llvm] [DirectX] replace byte splitting via vector bitcast with scalar (PR #140167)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Fri May 16 14:50:46 PDT 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/140167
>From 4d087227c830a88f7a958d1f5b6cf88886d5bd50 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 15 May 2025 18:11:44 -0400
Subject: [PATCH 1/3] [DirectX] replace byte splitting via vector bitcast with
scalar instructions - instead of bitcasting and extracting elements, let's use
trunc, or logical shift right followed by trunc, to split. - fixes #139020
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 45 +++++++++++++++++++
.../legalize-i64-high-low-vec-spilt.ll | 18 ++++++++
2 files changed, 63 insertions(+)
create mode 100644 llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index be77a70fa46ba..a99f706763c0f 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -8,6 +8,8 @@
#include "DXILLegalizePass.h"
#include "DirectX.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
@@ -317,6 +319,48 @@ static void removeMemSet(Instruction &I,
ToRemove.push_back(CI);
}
+static void // Legalization step: turns extractelement reads of an i64 bitcast to <2 x i32> into trunc / lshr+trunc on the i64.
+legalizeGetHighLowi64Bytes(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) { // I: instruction being visited; ToRemove: instructions queued here; ReplacedValues: rewritten value -> stand-in.
+ if (auto *BitCast = dyn_cast<BitCastInst>(&I)) { // Case 1: a bitcast whose source is i64 and whose destination is <2 x i32>.
+ if (BitCast->getDestTy() ==
+ FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
+ BitCast->getSrcTy()->isIntegerTy(64)) {
+ ToRemove.push_back(BitCast);
+ ReplacedValues[BitCast] = BitCast->getOperand(0);
+ }
+ }
+
+ if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) { // Case 2: an extractelement reading a <2 x i32> at a constant index.
+ auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
+ if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
+ VecTy->getNumElements() == 2) {
+ if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) { // NOTE(review): the assert on the vector operand is the only guard here; if that operand has no entry in ReplacedValues the lookup yields nullptr - confirm this cannot happen.
+ unsigned Idx = Index->getZExtValue(); // 0 -> plain trunc (low 32 bits); 1 -> shift right by 32, then trunc (high 32 bits).
+ IRBuilder<> Builder(&I);
+ assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
+ auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+ if (Idx == 0) {
+ Value *LowBytes = Builder.CreateTrunc(
+ Replacement, Type::getInt32Ty(I.getContext()));
+ ReplacedValues[Extract] = LowBytes; // Record the scalar stand-in for this extract.
+ } else {
+ assert(Idx == 1);
+ Value *LogicalShiftRight = Builder.CreateLShr( // Move the upper half of the source into the low 32 bits.
+ Replacement,
+ ConstantInt::get(
+ Replacement->getType(),
+ APInt(Replacement->getType()->getIntegerBitWidth(), 32))); // Shift amount 32, built at the source's own bit width.
+ Value *HighBytes = Builder.CreateTrunc(
+ LogicalShiftRight, Type::getInt32Ty(I.getContext()));
+ ReplacedValues[Extract] = HighBytes;
+ }
+ ToRemove.push_back(Extract);
+ }
+ }
+ }
+}
namespace {
class DXILLegalizationPipeline {
@@ -349,6 +393,7 @@ class DXILLegalizationPipeline {
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
LegalizationPipeline.push_back(removeMemSet);
+ LegalizationPipeline.push_back(legalizeGetHighLowi64Bytes);
}
};
diff --git a/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
new file mode 100644
index 0000000000000..17fd3bf54acda
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @split_via_extract(i64 noundef %a) {
+; CHECK-LABEL: define void @split_via_extract(
+; CHECK-SAME: i64 noundef [[A:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[A]] to i32
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A]], 32
+; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT: ret void
+;
+entry:
+ %vecA = bitcast i64 %a to <2 x i32>
+ %low = extractelement <2 x i32> %vecA, i32 0 ; low 32 bits
+ %high = extractelement <2 x i32> %vecA, i32 1 ; high 32 bits
+ ret void
+}
>From ac5b8a01d0b3fa27c4e6f683d2ee1db35008cb81 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 16 May 2025 12:34:24 -0400
Subject: [PATCH 2/3] address pr comment
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index a99f706763c0f..ec0d9d7a300fe 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -329,6 +329,7 @@ legalizeGetHighLowi64Bytes(Instruction &I,
BitCast->getSrcTy()->isIntegerTy(64)) {
ToRemove.push_back(BitCast);
ReplacedValues[BitCast] = BitCast->getOperand(0);
+ return;
}
}
@@ -341,6 +342,8 @@ legalizeGetHighLowi64Bytes(Instruction &I,
IRBuilder<> Builder(&I);
assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+ assert(Replacement && "The BitCast replacement should have been set "
+ "before working on ExtractElementInst.");
if (Idx == 0) {
Value *LowBytes = Builder.CreateTrunc(
Replacement, Type::getInt32Ty(I.getContext()));
>From 0f831294438305aa94ba07d9aeceb36ac7734e70 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 16 May 2025 14:50:30 -0700
Subject: [PATCH 3/3] Add Legalization stages. Passes that change the same
instructions need to be staggered.
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 45 ++++++++++++--------
1 file changed, 27 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index ec0d9d7a300fe..d92701431a59c 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -360,6 +360,7 @@ legalizeGetHighLowi64Bytes(Instruction &I,
ReplacedValues[Extract] = HighBytes;
}
ToRemove.push_back(Extract);
+ Extract->replaceAllUsesWith(ReplacedValues[Extract]);
}
}
}
@@ -371,32 +372,40 @@ class DXILLegalizationPipeline {
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
bool runLegalizationPipeline(Function &F) {
- SmallVector<Instruction *> ToRemove;
- DenseMap<Value *, Value *> ReplacedValues;
- for (auto &I : instructions(F)) {
- for (auto &LegalizationFn : LegalizationPipeline)
- LegalizationFn(I, ToRemove, ReplacedValues);
- }
+ bool MadeChange = false;
+ for (int Stage = 0; Stage < NumStages; ++Stage) {
+ SmallVector<Instruction *> ToRemove;
+ DenseMap<Value *, Value *> ReplacedValues;
+ for (auto &I : instructions(F)) {
+ for (auto &LegalizationFn : LegalizationPipeline[Stage])
+ LegalizationFn(I, ToRemove, ReplacedValues);
+ }
- for (auto *Inst : reverse(ToRemove))
- Inst->eraseFromParent();
+ for (auto *Inst : reverse(ToRemove))
+ Inst->eraseFromParent();
- return !ToRemove.empty();
+ MadeChange |= !ToRemove.empty();
+ }
+ return MadeChange;
}
private:
- SmallVector<
+ enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
+
+ using LegalizationFnTy =
std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
- DenseMap<Value *, Value *> &)>>
- LegalizationPipeline;
+ DenseMap<Value *, Value *> &)>;
+
+ SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
void initializeLegalizationPipeline() {
- LegalizationPipeline.push_back(upcastI8AllocasAndUses);
- LegalizationPipeline.push_back(fixI8UseChain);
- LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
- LegalizationPipeline.push_back(legalizeFreeze);
- LegalizationPipeline.push_back(removeMemSet);
- LegalizationPipeline.push_back(legalizeGetHighLowi64Bytes);
+ LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
+ LegalizationPipeline[Stage1].push_back(fixI8UseChain);
+ LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
+ LegalizationPipeline[Stage1].push_back(legalizeFreeze);
+ LegalizationPipeline[Stage1].push_back(removeMemSet);
+ LegalizationPipeline[Stage2].push_back(
+ downcastI64toI32InsertExtractElements);
}
};
More information about the llvm-commits
mailing list