[llvm] [DirectX] replace byte splitting via vector bitcast with scalar (PR #140167)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu May 22 21:10:27 PDT 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/140167
>From ef0efad21e717d6f0fa14defe15bb44a357acb87 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 15 May 2025 18:11:44 -0400
Subject: [PATCH 1/3] [DirectX] replace byte splitting via vector bitcast with
scalar instructions - instead of bitcasting and extract element, let's use
trunc or trunc and logical shift right to split. - fixes #139020
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 46 +++++++++++++++++++
.../legalize-i64-high-low-vec-spilt.ll | 18 ++++++++
2 files changed, 64 insertions(+)
create mode 100644 llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 23883c936a20d..295b31185476c 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -8,6 +8,8 @@
#include "DXILLegalizePass.h"
#include "DirectX.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
@@ -419,6 +421,49 @@ static void updateFnegToFsub(Instruction &I,
ToRemove.push_back(&I);
}
+static void
+legalizeGetHighLowi64Bytes(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) {
+ if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
+ if (BitCast->getDestTy() ==
+ FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
+ BitCast->getSrcTy()->isIntegerTy(64)) {
+ ToRemove.push_back(BitCast);
+ ReplacedValues[BitCast] = BitCast->getOperand(0);
+ }
+ }
+
+ if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
+ auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
+ if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
+ VecTy->getNumElements() == 2) {
+ if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) {
+ unsigned Idx = Index->getZExtValue();
+ IRBuilder<> Builder(&I);
+ assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
+ auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+ if (Idx == 0) {
+ Value *LowBytes = Builder.CreateTrunc(
+ Replacement, Type::getInt32Ty(I.getContext()));
+ ReplacedValues[Extract] = LowBytes;
+ } else {
+ assert(Idx == 1);
+ Value *LogicalShiftRight = Builder.CreateLShr(
+ Replacement,
+ ConstantInt::get(
+ Replacement->getType(),
+ APInt(Replacement->getType()->getIntegerBitWidth(), 32)));
+ Value *HighBytes = Builder.CreateTrunc(
+ LogicalShiftRight, Type::getInt32Ty(I.getContext()));
+ ReplacedValues[Extract] = HighBytes;
+ }
+ ToRemove.push_back(Extract);
+ }
+ }
+ }
+}
+
namespace {
class DXILLegalizationPipeline {
@@ -453,6 +498,7 @@ class DXILLegalizationPipeline {
LegalizationPipeline.push_back(legalizeMemCpy);
LegalizationPipeline.push_back(removeMemSet);
LegalizationPipeline.push_back(updateFnegToFsub);
+ LegalizationPipeline.push_back(legalizeGetHighLowi64Bytes);
}
};
diff --git a/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
new file mode 100644
index 0000000000000..17fd3bf54acda
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @split_via_extract(i64 noundef %a) {
+; CHECK-LABEL: define void @split_via_extract(
+; CHECK-SAME: i64 noundef [[A:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[A]] to i32
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A]], 32
+; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT: ret void
+;
+entry:
+ %vecA = bitcast i64 %a to <2 x i32>
+ %low = extractelement <2 x i32> %vecA, i32 0 ; low 32 bits
+ %high = extractelement <2 x i32> %vecA, i32 1 ; high 32 bits
+ ret void
+}
>From 70cc62ac1f492be23a777f0309faf5eb880c5c3d Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 16 May 2025 12:34:24 -0400
Subject: [PATCH 2/3] address pr comment
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 295b31185476c..a16b803829014 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -431,6 +431,7 @@ legalizeGetHighLowi64Bytes(Instruction &I,
BitCast->getSrcTy()->isIntegerTy(64)) {
ToRemove.push_back(BitCast);
ReplacedValues[BitCast] = BitCast->getOperand(0);
+ return;
}
}
@@ -443,6 +444,8 @@ legalizeGetHighLowi64Bytes(Instruction &I,
IRBuilder<> Builder(&I);
assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+ assert(Replacement && "The BitCast replacement should have been set "
+ "before working on ExtractElementInst.");
if (Idx == 0) {
Value *LowBytes = Builder.CreateTrunc(
Replacement, Type::getInt32Ty(I.getContext()));
>From ba88ccb2c2525651ad3d162f4f6efb08f3eecd1e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 16 May 2025 14:50:30 -0700
Subject: [PATCH 3/3] Add Legalization stages. Passes that change the same
instructions need to be staggered.
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 49 ++++++++++++--------
1 file changed, 29 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index a16b803829014..db17a8c9dc963 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -462,6 +462,7 @@ legalizeGetHighLowi64Bytes(Instruction &I,
ReplacedValues[Extract] = HighBytes;
}
ToRemove.push_back(Extract);
+ Extract->replaceAllUsesWith(ReplacedValues[Extract]);
}
}
}
@@ -474,34 +475,42 @@ class DXILLegalizationPipeline {
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
bool runLegalizationPipeline(Function &F) {
- SmallVector<Instruction *> ToRemove;
- DenseMap<Value *, Value *> ReplacedValues;
- for (auto &I : instructions(F)) {
- for (auto &LegalizationFn : LegalizationPipeline)
- LegalizationFn(I, ToRemove, ReplacedValues);
- }
+ bool MadeChange = false;
+ for (int Stage = 0; Stage < NumStages; ++Stage) {
+ SmallVector<Instruction *> ToRemove;
+ DenseMap<Value *, Value *> ReplacedValues;
+ for (auto &I : instructions(F)) {
+ for (auto &LegalizationFn : LegalizationPipeline[Stage])
+ LegalizationFn(I, ToRemove, ReplacedValues);
+ }
- for (auto *Inst : reverse(ToRemove))
- Inst->eraseFromParent();
+ for (auto *Inst : reverse(ToRemove))
+ Inst->eraseFromParent();
- return !ToRemove.empty();
+ MadeChange |= !ToRemove.empty();
+ }
+ return MadeChange;
}
private:
- SmallVector<
+ enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
+
+ using LegalizationFnTy =
std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
- DenseMap<Value *, Value *> &)>>
- LegalizationPipeline;
+ DenseMap<Value *, Value *> &)>;
+
+ SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
void initializeLegalizationPipeline() {
- LegalizationPipeline.push_back(upcastI8AllocasAndUses);
- LegalizationPipeline.push_back(fixI8UseChain);
- LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
- LegalizationPipeline.push_back(legalizeFreeze);
- LegalizationPipeline.push_back(legalizeMemCpy);
- LegalizationPipeline.push_back(removeMemSet);
- LegalizationPipeline.push_back(updateFnegToFsub);
- LegalizationPipeline.push_back(legalizeGetHighLowi64Bytes);
+ LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
+ LegalizationPipeline[Stage1].push_back(fixI8UseChain);
+ LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
+ LegalizationPipeline[Stage1].push_back(legalizeFreeze);
+ LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
+ LegalizationPipeline[Stage1].push_back(removeMemSet);
+ LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
+ LegalizationPipeline[Stage2].push_back(
+ downcastI64toI32InsertExtractElements);
}
};
More information about the llvm-commits mailing list