[llvm] [DirectX] replace byte splitting via vector bitcast with scalar (PR #140167)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Tue May 27 12:00:53 PDT 2025


================
@@ -419,40 +421,96 @@ static void updateFnegToFsub(Instruction &I,
   ToRemove.push_back(&I);
 }
 
+static void
+legalizeGetHighLowi64Bytes(Instruction &I,
+                           SmallVectorImpl<Instruction *> &ToRemove,
+                           DenseMap<Value *, Value *> &ReplacedValues) {
+  if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
+    if (BitCast->getDestTy() ==
+            FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
+        BitCast->getSrcTy()->isIntegerTy(64)) {
+      ToRemove.push_back(BitCast);
+      ReplacedValues[BitCast] = BitCast->getOperand(0);
+      return;
+    }
+  }
+
+  if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
+    auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
+    if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
+        VecTy->getNumElements() == 2) {
+      if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) {
+        unsigned Idx = Index->getZExtValue();
+        IRBuilder<> Builder(&I);
+        assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
+        auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+        assert(Replacement && "The BitCast replacement should have been set "
+                              "before working on ExtractElementInst.");
+        if (Idx == 0) {
+          Value *LowBytes = Builder.CreateTrunc(
+              Replacement, Type::getInt32Ty(I.getContext()));
+          ReplacedValues[Extract] = LowBytes;
+        } else {
+          assert(Idx == 1);
+          Value *LogicalShiftRight = Builder.CreateLShr(
+              Replacement,
+              ConstantInt::get(
+                  Replacement->getType(),
+                  APInt(Replacement->getType()->getIntegerBitWidth(), 32)));
+          Value *HighBytes = Builder.CreateTrunc(
+              LogicalShiftRight, Type::getInt32Ty(I.getContext()));
+          ReplacedValues[Extract] = HighBytes;
+        }
+        ToRemove.push_back(Extract);
+        Extract->replaceAllUsesWith(ReplacedValues[Extract]);
+      }
+    }
+  }
+}
+
 namespace {
 class DXILLegalizationPipeline {
 
 public:
   DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
 
   bool runLegalizationPipeline(Function &F) {
-    SmallVector<Instruction *> ToRemove;
-    DenseMap<Value *, Value *> ReplacedValues;
-    for (auto &I : instructions(F)) {
-      for (auto &LegalizationFn : LegalizationPipeline)
-        LegalizationFn(I, ToRemove, ReplacedValues);
-    }
+    bool MadeChange = false;
+    for (int Stage = 0; Stage < NumStages; ++Stage) {
+      SmallVector<Instruction *> ToRemove;
+      DenseMap<Value *, Value *> ReplacedValues;
+      for (auto &I : instructions(F)) {
+        for (auto &LegalizationFn : LegalizationPipeline[Stage])
+          LegalizationFn(I, ToRemove, ReplacedValues);
+      }
 
-    for (auto *Inst : reverse(ToRemove))
-      Inst->eraseFromParent();
+      for (auto *Inst : reverse(ToRemove))
+        Inst->eraseFromParent();
 
-    return !ToRemove.empty();
+      MadeChange |= !ToRemove.empty();
+    }
+    return MadeChange;
   }
 
 private:
-  SmallVector<
+  enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
+
+  using LegalizationFnTy =
       std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
-                         DenseMap<Value *, Value *> &)>>
-      LegalizationPipeline;
+                         DenseMap<Value *, Value *> &)>;
+
+  SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
 
   void initializeLegalizationPipeline() {
-    LegalizationPipeline.push_back(upcastI8AllocasAndUses);
-    LegalizationPipeline.push_back(fixI8UseChain);
-    LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
-    LegalizationPipeline.push_back(legalizeFreeze);
-    LegalizationPipeline.push_back(legalizeMemCpy);
-    LegalizationPipeline.push_back(removeMemSet);
-    LegalizationPipeline.push_back(updateFnegToFsub);
+    LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
+    LegalizationPipeline[Stage1].push_back(fixI8UseChain);
+    LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
+    LegalizationPipeline[Stage1].push_back(legalizeFreeze);
+    LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
+    LegalizationPipeline[Stage1].push_back(removeMemSet);
+    LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
+    LegalizationPipeline[Stage2].push_back(
----------------
inbelic wrote:

```suggestion
    // legalizeGetHighLowi64Bytes may replace extractelement instructions, so
    // this legalization is run in a later stage.
    LegalizationPipeline[Stage2].push_back(
```

https://github.com/llvm/llvm-project/pull/140167


More information about the llvm-commits mailing list