[llvm] [DirectX] Address PR comments to #131221 (PR #131706)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 20 11:08:08 PDT 2025


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/131706

>From 723e1837593131c10d6ed096674a273d5c530532 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 17 Mar 2025 21:44:37 -0400
Subject: [PATCH 1/3] [DirectX] Address PR comments to #131221

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 55 +++++++-------------
 1 file changed, 20 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index f9a494ce63dd3..317bff40caf7e 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -5,12 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===---------------------------------------------------------------------===//
-//===---------------------------------------------------------------------===//
-///
-/// \file This file contains a pass to remove i8 truncations and i64 extract
-/// and insert elements.
-///
-//===----------------------------------------------------------------------===//
+
 #include "DXILLegalizePass.h"
 #include "DirectX.h"
 #include "llvm/IR/Function.h"
@@ -20,31 +15,27 @@
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <functional>
-#include <map>
-#include <stack>
-#include <vector>
 
 #define DEBUG_TYPE "dxil-legalize"
 
 using namespace llvm;
 namespace {
 
-static void fixI8TruncUseChain(Instruction &I,
-                               std::stack<Instruction *> &ToRemove,
-                               std::map<Value *, Value *> &ReplacedValues) {
+void fixI8TruncUseChain(Instruction &I, SmallVector<Instruction *> &ToRemove,
+                        DenseMap<Value *, Value *> &ReplacedValues) {
 
   auto *Cmp = dyn_cast<CmpInst>(&I);
 
   if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
     if (Trunc->getDestTy()->isIntegerTy(8)) {
       ReplacedValues[Trunc] = Trunc->getOperand(0);
-      ToRemove.push(Trunc);
+      ToRemove.push_back(Trunc);
     }
   } else if (I.getType()->isIntegerTy(8) ||
              (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
     IRBuilder<> Builder(&I);
 
-    std::vector<Value *> NewOperands;
+    SmallVector<Value *> NewOperands;
     Type *InstrType = IntegerType::get(I.getContext(), 32);
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
@@ -88,20 +79,19 @@ static void fixI8TruncUseChain(Instruction &I,
 
     if (NewInst) {
       ReplacedValues[&I] = NewInst;
-      ToRemove.push(&I);
+      ToRemove.push_back(&I);
     }
   } else if (auto *Cast = dyn_cast<CastInst>(&I)) {
     if (Cast->getSrcTy()->isIntegerTy(8)) {
-      ToRemove.push(Cast);
+      ToRemove.push_back(Cast);
       Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
     }
   }
 }
 
-static void
-downcastI64toI32InsertExtractElements(Instruction &I,
-                                      std::stack<Instruction *> &ToRemove,
-                                      std::map<Value *, Value *> &) {
+void downcastI64toI32InsertExtractElements(Instruction &I,
+                                           SmallVector<Instruction *> &ToRemove,
+                                           DenseMap<Value *, Value *> &) {
 
   if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
     Value *Idx = Extract->getIndexOperand();
@@ -115,7 +105,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
           Extract->getVectorOperand(), Idx32, Extract->getName());
 
       Extract->replaceAllUsesWith(NewExtract);
-      ToRemove.push(Extract);
+      ToRemove.push_back(Extract);
     }
   }
 
@@ -132,7 +122,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
           Insert->getName());
 
       Insert->replaceAllUsesWith(Insert32Index);
-      ToRemove.push(Insert);
+      ToRemove.push_back(Insert);
     }
   }
 }
@@ -143,27 +133,22 @@ class DXILLegalizationPipeline {
   DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
 
   bool runLegalizationPipeline(Function &F) {
-    std::stack<Instruction *> ToRemove;
-    std::map<Value *, Value *> ReplacedValues;
+    SmallVector<Instruction *> ToRemove;
+    DenseMap<Value *, Value *> ReplacedValues;
     for (auto &I : instructions(F)) {
-      for (auto &LegalizationFn : LegalizationPipeline) {
+      for (auto &LegalizationFn : LegalizationPipeline)
         LegalizationFn(I, ToRemove, ReplacedValues);
-      }
     }
-    bool MadeChanges = !ToRemove.empty();
 
-    while (!ToRemove.empty()) {
-      Instruction *I = ToRemove.top();
-      I->eraseFromParent();
-      ToRemove.pop();
-    }
+    for (auto *Inst : reverse(ToRemove))
+      Inst->eraseFromParent();
 
-    return MadeChanges;
+    return !ToRemove.empty();
   }
 
 private:
-  std::vector<std::function<void(Instruction &, std::stack<Instruction *> &,
-                                 std::map<Value *, Value *> &)>>
+  SmallVector<std::function<void(Instruction &, SmallVector<Instruction *> &,
+                                 DenseMap<Value *, Value *> &)>>
       LegalizationPipeline;
 
   void initializeLegalizationPipeline() {

>From 137873b88b4d51906beef1b8e0cc7cf7e8d3d96b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 17 Mar 2025 22:16:52 -0400
Subject: [PATCH 2/3] [DirectX] Make legalization helpers static and take SmallVectorImpl

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 317bff40caf7e..44cd92cfc51f8 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -19,10 +19,10 @@
 #define DEBUG_TYPE "dxil-legalize"
 
 using namespace llvm;
-namespace {
 
-void fixI8TruncUseChain(Instruction &I, SmallVector<Instruction *> &ToRemove,
-                        DenseMap<Value *, Value *> &ReplacedValues) {
+static void fixI8TruncUseChain(Instruction &I,
+                               SmallVectorImpl<Instruction *> &ToRemove,
+                               DenseMap<Value *, Value *> &ReplacedValues) {
 
   auto *Cmp = dyn_cast<CmpInst>(&I);
 
@@ -89,9 +89,10 @@ void fixI8TruncUseChain(Instruction &I, SmallVector<Instruction *> &ToRemove,
   }
 }
 
-void downcastI64toI32InsertExtractElements(Instruction &I,
-                                           SmallVector<Instruction *> &ToRemove,
-                                           DenseMap<Value *, Value *> &) {
+static void
+downcastI64toI32InsertExtractElements(Instruction &I,
+                                      SmallVectorImpl<Instruction *> &ToRemove,
+                                      DenseMap<Value *, Value *> &) {
 
   if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
     Value *Idx = Extract->getIndexOperand();
@@ -127,6 +128,7 @@ void downcastI64toI32InsertExtractElements(Instruction &I,
   }
 }
 
+namespace {
 class DXILLegalizationPipeline {
 
 public:
@@ -147,8 +149,9 @@ class DXILLegalizationPipeline {
   }
 
 private:
-  SmallVector<std::function<void(Instruction &, SmallVector<Instruction *> &,
-                                 DenseMap<Value *, Value *> &)>>
+  SmallVector<
+      std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
+                         DenseMap<Value *, Value *> &)>>
       LegalizationPipeline;
 
   void initializeLegalizationPipeline() {

>From adb61f8e189ac7a28d135aa0d7fad8265d5ff45c Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 20 Mar 2025 14:02:13 -0400
Subject: [PATCH 3/3] [DirectX] Refactor fixI8TruncUseChain to reduce nesting and assert widening

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 76 +++++++++++---------
 1 file changed, 43 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 44cd92cfc51f8..a9edf77aa08af 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -24,24 +24,15 @@ static void fixI8TruncUseChain(Instruction &I,
                                SmallVectorImpl<Instruction *> &ToRemove,
                                DenseMap<Value *, Value *> &ReplacedValues) {
 
-  auto *Cmp = dyn_cast<CmpInst>(&I);
-
-  if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
-    if (Trunc->getDestTy()->isIntegerTy(8)) {
-      ReplacedValues[Trunc] = Trunc->getOperand(0);
-      ToRemove.push_back(Trunc);
-    }
-  } else if (I.getType()->isIntegerTy(8) ||
-             (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
-    IRBuilder<> Builder(&I);
-
-    SmallVector<Value *> NewOperands;
+  auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
     Type *InstrType = IntegerType::get(I.getContext(), 32);
+
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
       if (ReplacedValues.count(Op))
         InstrType = ReplacedValues[Op]->getType();
     }
+
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
       if (ReplacedValues.count(Op))
@@ -52,6 +43,8 @@ static void fixI8TruncUseChain(Instruction &I,
         // Note: options here are sext or sextOrTrunc.
         // Since i8 isn't supported, we assume new values
         // will always have a higher bitness.
+        assert(NewBitWidth > Value.getBitWidth() &&
+               "Replacement's BitWidth should be larger than Current.");
         APInt NewValue = Value.sext(NewBitWidth);
         NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
       } else {
@@ -59,29 +52,46 @@ static void fixI8TruncUseChain(Instruction &I,
         NewOperands.push_back(Op);
       }
     }
-
-    Value *NewInst = nullptr;
-    if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
-      NewInst =
-          Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
-
-      if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
-        if (OBO->hasNoSignedWrap())
-          cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
-        if (OBO->hasNoUnsignedWrap())
-          cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
-      }
-    } else if (Cmp) {
-      NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
-                                  NewOperands[1]);
-      Cmp->replaceAllUsesWith(NewInst);
+  };
+  IRBuilder<> Builder(&I);
+  if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
+    if (Trunc->getDestTy()->isIntegerTy(8)) {
+      ReplacedValues[Trunc] = Trunc->getOperand(0);
+      ToRemove.push_back(Trunc);
     }
-
-    if (NewInst) {
-      ReplacedValues[&I] = NewInst;
-      ToRemove.push_back(&I);
+  }
+  Value *NewInst = nullptr;
+  if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    NewInst =
+        Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
+    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
+      if (OBO->hasNoSignedWrap())
+        cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
+      if (OBO->hasNoUnsignedWrap())
+        cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
     }
-  } else if (auto *Cast = dyn_cast<CastInst>(&I)) {
+  }
+
+  if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
+    if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    NewInst =
+        Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
+    Cmp->replaceAllUsesWith(NewInst);
+  }
+  if (NewInst) {
+    ReplacedValues[&I] = NewInst;
+    ToRemove.push_back(&I);
+    return;
+  }
+
+  if (auto *Cast = dyn_cast<CastInst>(&I)) {
     if (Cast->getSrcTy()->isIntegerTy(8)) {
       ToRemove.push_back(Cast);
       Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);



More information about the llvm-commits mailing list