[llvm] [DirectX] Address PR comments to #131221 (PR #131706)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 20 11:08:08 PDT 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/131706
From 723e1837593131c10d6ed096674a273d5c530532 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 17 Mar 2025 21:44:37 -0400
Subject: [PATCH 1/3] [DirectX] Address PR comments to #131221
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 55 +++++++-------------
1 file changed, 20 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index f9a494ce63dd3..317bff40caf7e 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -5,12 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
-//===---------------------------------------------------------------------===//
-///
-/// \file This file contains a pass to remove i8 truncations and i64 extract
-/// and insert elements.
-///
-//===----------------------------------------------------------------------===//
+
#include "DXILLegalizePass.h"
#include "DirectX.h"
#include "llvm/IR/Function.h"
@@ -20,31 +15,27 @@
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
-#include <map>
-#include <stack>
-#include <vector>
#define DEBUG_TYPE "dxil-legalize"
using namespace llvm;
namespace {
-static void fixI8TruncUseChain(Instruction &I,
- std::stack<Instruction *> &ToRemove,
- std::map<Value *, Value *> &ReplacedValues) {
+void fixI8TruncUseChain(Instruction &I, SmallVector<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) {
auto *Cmp = dyn_cast<CmpInst>(&I);
if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
if (Trunc->getDestTy()->isIntegerTy(8)) {
ReplacedValues[Trunc] = Trunc->getOperand(0);
- ToRemove.push(Trunc);
+ ToRemove.push_back(Trunc);
}
} else if (I.getType()->isIntegerTy(8) ||
(Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
IRBuilder<> Builder(&I);
- std::vector<Value *> NewOperands;
+ SmallVector<Value *> NewOperands;
Type *InstrType = IntegerType::get(I.getContext(), 32);
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
@@ -88,20 +79,19 @@ static void fixI8TruncUseChain(Instruction &I,
if (NewInst) {
ReplacedValues[&I] = NewInst;
- ToRemove.push(&I);
+ ToRemove.push_back(&I);
}
} else if (auto *Cast = dyn_cast<CastInst>(&I)) {
if (Cast->getSrcTy()->isIntegerTy(8)) {
- ToRemove.push(Cast);
+ ToRemove.push_back(Cast);
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
}
}
}
-static void
-downcastI64toI32InsertExtractElements(Instruction &I,
- std::stack<Instruction *> &ToRemove,
- std::map<Value *, Value *> &) {
+void downcastI64toI32InsertExtractElements(Instruction &I,
+ SmallVector<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &) {
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
Value *Idx = Extract->getIndexOperand();
@@ -115,7 +105,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
Extract->getVectorOperand(), Idx32, Extract->getName());
Extract->replaceAllUsesWith(NewExtract);
- ToRemove.push(Extract);
+ ToRemove.push_back(Extract);
}
}
@@ -132,7 +122,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
Insert->getName());
Insert->replaceAllUsesWith(Insert32Index);
- ToRemove.push(Insert);
+ ToRemove.push_back(Insert);
}
}
}
@@ -143,27 +133,22 @@ class DXILLegalizationPipeline {
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
bool runLegalizationPipeline(Function &F) {
- std::stack<Instruction *> ToRemove;
- std::map<Value *, Value *> ReplacedValues;
+ SmallVector<Instruction *> ToRemove;
+ DenseMap<Value *, Value *> ReplacedValues;
for (auto &I : instructions(F)) {
- for (auto &LegalizationFn : LegalizationPipeline) {
+ for (auto &LegalizationFn : LegalizationPipeline)
LegalizationFn(I, ToRemove, ReplacedValues);
- }
}
- bool MadeChanges = !ToRemove.empty();
- while (!ToRemove.empty()) {
- Instruction *I = ToRemove.top();
- I->eraseFromParent();
- ToRemove.pop();
- }
+ for (auto *Inst : reverse(ToRemove))
+ Inst->eraseFromParent();
- return MadeChanges;
+ return !ToRemove.empty();
}
private:
- std::vector<std::function<void(Instruction &, std::stack<Instruction *> &,
- std::map<Value *, Value *> &)>>
+ SmallVector<std::function<void(Instruction &, SmallVector<Instruction *> &,
+ DenseMap<Value *, Value *> &)>>
LegalizationPipeline;
void initializeLegalizationPipeline() {
From 137873b88b4d51906beef1b8e0cc7cf7e8d3d96b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 17 Mar 2025 22:16:52 -0400
Subject: [PATCH 2/3] address pr comments
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 19 +++++++++++--------
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 317bff40caf7e..44cd92cfc51f8 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -19,10 +19,10 @@
#define DEBUG_TYPE "dxil-legalize"
using namespace llvm;
-namespace {
-void fixI8TruncUseChain(Instruction &I, SmallVector<Instruction *> &ToRemove,
- DenseMap<Value *, Value *> &ReplacedValues) {
+static void fixI8TruncUseChain(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) {
auto *Cmp = dyn_cast<CmpInst>(&I);
@@ -89,9 +89,10 @@ void fixI8TruncUseChain(Instruction &I, SmallVector<Instruction *> &ToRemove,
}
}
-void downcastI64toI32InsertExtractElements(Instruction &I,
- SmallVector<Instruction *> &ToRemove,
- DenseMap<Value *, Value *> &) {
+static void
+downcastI64toI32InsertExtractElements(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &) {
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
Value *Idx = Extract->getIndexOperand();
@@ -127,6 +128,7 @@ void downcastI64toI32InsertExtractElements(Instruction &I,
}
}
+namespace {
class DXILLegalizationPipeline {
public:
@@ -147,8 +149,9 @@ class DXILLegalizationPipeline {
}
private:
- SmallVector<std::function<void(Instruction &, SmallVector<Instruction *> &,
- DenseMap<Value *, Value *> &)>>
+ SmallVector<
+ std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
+ DenseMap<Value *, Value *> &)>>
LegalizationPipeline;
void initializeLegalizationPipeline() {
From adb61f8e189ac7a28d135aa0d7fad8265d5ff45c Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 20 Mar 2025 14:02:13 -0400
Subject: [PATCH 3/3] refactor fixI8TruncUseChain to remove nesting
---
llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 76 +++++++++++---------
1 file changed, 43 insertions(+), 33 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 44cd92cfc51f8..a9edf77aa08af 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -24,24 +24,15 @@ static void fixI8TruncUseChain(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
DenseMap<Value *, Value *> &ReplacedValues) {
- auto *Cmp = dyn_cast<CmpInst>(&I);
-
- if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
- if (Trunc->getDestTy()->isIntegerTy(8)) {
- ReplacedValues[Trunc] = Trunc->getOperand(0);
- ToRemove.push_back(Trunc);
- }
- } else if (I.getType()->isIntegerTy(8) ||
- (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
- IRBuilder<> Builder(&I);
-
- SmallVector<Value *> NewOperands;
+ auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
Type *InstrType = IntegerType::get(I.getContext(), 32);
+
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
if (ReplacedValues.count(Op))
InstrType = ReplacedValues[Op]->getType();
}
+
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
if (ReplacedValues.count(Op))
@@ -52,6 +43,8 @@ static void fixI8TruncUseChain(Instruction &I,
// Note: options here are sext or sextOrTrunc.
// Since i8 isn't supported, we assume new values
// will always have a higher bitness.
+ assert(NewBitWidth > Value.getBitWidth() &&
+ "Replacement's BitWidth should be larger than Current.");
APInt NewValue = Value.sext(NewBitWidth);
NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
} else {
@@ -59,29 +52,46 @@ static void fixI8TruncUseChain(Instruction &I,
NewOperands.push_back(Op);
}
}
-
- Value *NewInst = nullptr;
- if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
- NewInst =
- Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
-
- if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
- if (OBO->hasNoSignedWrap())
- cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
- if (OBO->hasNoUnsignedWrap())
- cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
- }
- } else if (Cmp) {
- NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
- NewOperands[1]);
- Cmp->replaceAllUsesWith(NewInst);
+ };
+ IRBuilder<> Builder(&I);
+ if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
+ if (Trunc->getDestTy()->isIntegerTy(8)) {
+ ReplacedValues[Trunc] = Trunc->getOperand(0);
+ ToRemove.push_back(Trunc);
}
-
- if (NewInst) {
- ReplacedValues[&I] = NewInst;
- ToRemove.push_back(&I);
+ }
+ Value *NewInst = nullptr;
+ if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
+ if (!I.getType()->isIntegerTy(8))
+ return;
+ SmallVector<Value *> NewOperands;
+ ProcessOperands(NewOperands);
+ NewInst =
+ Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
+ if (OBO->hasNoSignedWrap())
+ cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
+ if (OBO->hasNoUnsignedWrap())
+ cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
}
- } else if (auto *Cast = dyn_cast<CastInst>(&I)) {
+ }
+
+ if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
+ if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
+ return;
+ SmallVector<Value *> NewOperands;
+ ProcessOperands(NewOperands);
+ NewInst =
+ Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
+ Cmp->replaceAllUsesWith(NewInst);
+ }
+ if (NewInst) {
+ ReplacedValues[&I] = NewInst;
+ ToRemove.push_back(&I);
+ return;
+ }
+
+ if (auto *Cast = dyn_cast<CastInst>(&I)) {
if (Cast->getSrcTy()->isIntegerTy(8)) {
ToRemove.push_back(Cast);
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
More information about the llvm-commits mailing list