[llvm] [DirectX] replace byte splitting via vector bitcast with scalar instructions (PR #140167)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu May 22 21:10:27 PDT 2025


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/140167

>From ef0efad21e717d6f0fa14defe15bb44a357acb87 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 15 May 2025 18:11:44 -0400
Subject: [PATCH 1/3] [DirectX] replace byte splitting via vector bitcast with
 scalar instructions - instead of bitcasting and using extractelement, use
 trunc (low half) or lshr by 32 plus trunc (high half) to split. - fixes #139020

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp  | 46 +++++++++++++++++++
 .../legalize-i64-high-low-vec-spilt.ll        | 18 ++++++++
 2 files changed, 64 insertions(+)
 create mode 100644 llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 23883c936a20d..295b31185476c 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -8,6 +8,8 @@
 
 #include "DXILLegalizePass.h"
 #include "DirectX.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
@@ -419,6 +421,49 @@ static void updateFnegToFsub(Instruction &I,
   ToRemove.push_back(&I);
 }
 
+static void
+legalizeGetHighLowi64Bytes(Instruction &I,
+                           SmallVectorImpl<Instruction *> &ToRemove,
+                           DenseMap<Value *, Value *> &ReplacedValues) {
+  if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
+    if (BitCast->getDestTy() ==
+            FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
+        BitCast->getSrcTy()->isIntegerTy(64)) {
+      ToRemove.push_back(BitCast);
+      ReplacedValues[BitCast] = BitCast->getOperand(0);
+    }
+  }
+
+  if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
+    auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
+    if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
+        VecTy->getNumElements() == 2) {
+      if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) {
+        unsigned Idx = Index->getZExtValue();
+        IRBuilder<> Builder(&I);
+        assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
+        auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+        if (Idx == 0) {
+          Value *LowBytes = Builder.CreateTrunc(
+              Replacement, Type::getInt32Ty(I.getContext()));
+          ReplacedValues[Extract] = LowBytes;
+        } else {
+          assert(Idx == 1);
+          Value *LogicalShiftRight = Builder.CreateLShr(
+              Replacement,
+              ConstantInt::get(
+                  Replacement->getType(),
+                  APInt(Replacement->getType()->getIntegerBitWidth(), 32)));
+          Value *HighBytes = Builder.CreateTrunc(
+              LogicalShiftRight, Type::getInt32Ty(I.getContext()));
+          ReplacedValues[Extract] = HighBytes;
+        }
+        ToRemove.push_back(Extract);
+      }
+    }
+  }
+}
+
 namespace {
 class DXILLegalizationPipeline {
 
@@ -453,6 +498,7 @@ class DXILLegalizationPipeline {
     LegalizationPipeline.push_back(legalizeMemCpy);
     LegalizationPipeline.push_back(removeMemSet);
     LegalizationPipeline.push_back(updateFnegToFsub);
+    LegalizationPipeline.push_back(legalizeGetHighLowi64Bytes);
   }
 };
 
diff --git a/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
new file mode 100644
index 0000000000000..17fd3bf54acda
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i64-high-low-vec-spilt.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @split_via_extract(i64 noundef %a) {
+; CHECK-LABEL: define void @split_via_extract(
+; CHECK-SAME: i64 noundef [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[A]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[A]], 32
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vecA = bitcast i64 %a to <2 x i32>
+  %low = extractelement <2 x i32> %vecA, i32 0 ; low 32 bits
+  %high = extractelement <2 x i32> %vecA, i32 1 ; high 32 bits
+  ret void
+}

>From 70cc62ac1f492be23a777f0309faf5eb880c5c3d Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 16 May 2025 12:34:24 -0400
Subject: [PATCH 2/3] Address PR comment: return early after handling the
 bitcast and assert that its replacement was recorded

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 295b31185476c..a16b803829014 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -431,6 +431,7 @@ legalizeGetHighLowi64Bytes(Instruction &I,
         BitCast->getSrcTy()->isIntegerTy(64)) {
       ToRemove.push_back(BitCast);
       ReplacedValues[BitCast] = BitCast->getOperand(0);
+      return;
     }
   }
 
@@ -443,6 +444,8 @@ legalizeGetHighLowi64Bytes(Instruction &I,
         IRBuilder<> Builder(&I);
         assert(dyn_cast<BitCastInst>(Extract->getVectorOperand()));
         auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
+        assert(Replacement && "The BitCast replacement should have been set "
+                              "before working on ExtractElementInst.");
         if (Idx == 0) {
           Value *LowBytes = Builder.CreateTrunc(
               Replacement, Type::getInt32Ty(I.getContext()));

>From ba88ccb2c2525651ad3d162f4f6efb08f3eecd1e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 16 May 2025 14:50:30 -0700
Subject: [PATCH 3/3] Add legalization stages. Passes that modify the same
 instructions need to run in separate, ordered stages.

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 49 ++++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index a16b803829014..db17a8c9dc963 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -462,6 +462,7 @@ legalizeGetHighLowi64Bytes(Instruction &I,
           ReplacedValues[Extract] = HighBytes;
         }
         ToRemove.push_back(Extract);
+        Extract->replaceAllUsesWith(ReplacedValues[Extract]);
       }
     }
   }
@@ -474,34 +475,42 @@ class DXILLegalizationPipeline {
   DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
 
   bool runLegalizationPipeline(Function &F) {
-    SmallVector<Instruction *> ToRemove;
-    DenseMap<Value *, Value *> ReplacedValues;
-    for (auto &I : instructions(F)) {
-      for (auto &LegalizationFn : LegalizationPipeline)
-        LegalizationFn(I, ToRemove, ReplacedValues);
-    }
+    bool MadeChange = false;
+    for (int Stage = 0; Stage < NumStages; ++Stage) {
+      SmallVector<Instruction *> ToRemove;
+      DenseMap<Value *, Value *> ReplacedValues;
+      for (auto &I : instructions(F)) {
+        for (auto &LegalizationFn : LegalizationPipeline[Stage])
+          LegalizationFn(I, ToRemove, ReplacedValues);
+      }
 
-    for (auto *Inst : reverse(ToRemove))
-      Inst->eraseFromParent();
+      for (auto *Inst : reverse(ToRemove))
+        Inst->eraseFromParent();
 
-    return !ToRemove.empty();
+      MadeChange |= !ToRemove.empty();
+    }
+    return MadeChange;
   }
 
 private:
-  SmallVector<
+  enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
+
+  using LegalizationFnTy =
       std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
-                         DenseMap<Value *, Value *> &)>>
-      LegalizationPipeline;
+                         DenseMap<Value *, Value *> &)>;
+
+  SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
 
   void initializeLegalizationPipeline() {
-    LegalizationPipeline.push_back(upcastI8AllocasAndUses);
-    LegalizationPipeline.push_back(fixI8UseChain);
-    LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
-    LegalizationPipeline.push_back(legalizeFreeze);
-    LegalizationPipeline.push_back(legalizeMemCpy);
-    LegalizationPipeline.push_back(removeMemSet);
-    LegalizationPipeline.push_back(updateFnegToFsub);
-    LegalizationPipeline.push_back(legalizeGetHighLowi64Bytes);
+    LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
+    LegalizationPipeline[Stage1].push_back(fixI8UseChain);
+    LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
+    LegalizationPipeline[Stage1].push_back(legalizeFreeze);
+    LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
+    LegalizationPipeline[Stage1].push_back(removeMemSet);
+    LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
+    LegalizationPipeline[Stage2].push_back(
+        downcastI64toI32InsertExtractElements);
   }
 };
 



More information about the llvm-commits mailing list