[llvm] [DirectX] Legalize i8 allocas (PR #137399)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 25 14:20:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

fixes #<!-- -->137202

While investigating i8 allocas, I found that our i8 legalization was missing handling for load, store, and select instructions.
This change adds those three.

To handle i8 allocas correctly, though, we need to walk the alloca's uses and find the casts.

After finding the casts, I chose the smallest cast type as the type to transform the alloca to. That lets the larger casts that come later be preserved.

---
Full diff: https://github.com/llvm/llvm-project/pull/137399.diff


2 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILLegalizePass.cpp (+102-12) 
- (added) llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll (+99) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index b62ff4c52f70c..b7b209fcecbc9 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -12,6 +12,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
   ToRemove.push_back(FI);
 }
 
-static void fixI8TruncUseChain(Instruction &I,
-                               SmallVectorImpl<Instruction *> &ToRemove,
-                               DenseMap<Value *, Value *> &ReplacedValues) {
+static void fixI8UseChain(Instruction &I,
+                          SmallVectorImpl<Instruction *> &ToRemove,
+                          DenseMap<Value *, Value *> &ReplacedValues) {
 
   auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
     Type *InstrType = IntegerType::get(I.getContext(), 32);
 
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
-      if (ReplacedValues.count(Op))
+      if (ReplacedValues.count(Op) &&
+          ReplacedValues[Op]->getType()->isIntegerTy())
         InstrType = ReplacedValues[Op]->getType();
     }
 
@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
     }
   }
 
+  if (auto *Store = dyn_cast<StoreInst>(&I)) {
+    if (!Store->getValueOperand()->getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
+    ReplacedValues[Store] = NewStore;
+    ToRemove.push_back(Store);
+    return;
+  }
+
+  if (auto *Load = dyn_cast<LoadInst>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Type *ElementType = NewOperands[0]->getType();
+    if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
+      ElementType = AI->getAllocatedType();
+    LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
+    ReplacedValues[Load] = NewLoad;
+    ToRemove.push_back(Load);
+    return;
+  }
+
   if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
     if (!I.getType()->isIntegerTy(8))
       return;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
     Value *NewInst =
         Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
-      if (OBO->hasNoSignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
-      if (OBO->hasNoUnsignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
+      auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
+      if (NewBO && OBO->hasNoSignedWrap())
+        NewBO->setHasNoSignedWrap();
+      if (NewBO && OBO->hasNoUnsignedWrap())
+        NewBO->setHasNoUnsignedWrap();
     }
     ReplacedValues[BO] = NewInst;
     ToRemove.push_back(BO);
     return;
   }
 
+  if (auto *Sel = dyn_cast<SelectInst>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
+                                          NewOperands[2]);
+    ReplacedValues[Sel] = NewInst;
+    ToRemove.push_back(Sel);
+    return;
+  }
+
   if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
     if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
       return;
@@ -105,13 +145,62 @@ static void fixI8TruncUseChain(Instruction &I,
   }
 
   if (auto *Cast = dyn_cast<CastInst>(&I)) {
-    if (Cast->getSrcTy()->isIntegerTy(8)) {
-      ToRemove.push_back(Cast);
-      Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
+    if (!Cast->getSrcTy()->isIntegerTy(8))
+      return;
+    
+    ToRemove.push_back(Cast);
+    auto* Replacement =ReplacedValues[Cast->getOperand(0)];
+    if (Cast->getType() == Replacement->getType()) {
+      Cast->replaceAllUsesWith(Replacement);
+      return;
     }
+    Value* AdjustedCast = nullptr;
+    if (Cast->getOpcode() == Instruction::ZExt)
+      AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
+    if (Cast->getOpcode() == Instruction::SExt)
+      AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
+  
+    if(AdjustedCast)
+      Cast->replaceAllUsesWith(AdjustedCast);
   }
 }
 
+static void upcastI8AllocasAndUses(Instruction &I,
+                                   SmallVectorImpl<Instruction *> &ToRemove,
+                                   DenseMap<Value *, Value *> &ReplacedValues) {
+  auto *AI = dyn_cast<AllocaInst>(&I);
+  if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
+    return; // Skip anything that is not an i8 alloca.
+
+  Type *SmallestType = nullptr;
+
+  // Find the narrowest type that any direct load of this alloca is cast to.
+  for (User *U : AI->users()) {
+    auto *Load = dyn_cast<LoadInst>(U);
+    if (!Load)
+      continue;
+    for (User *LU : Load->users()) {
+      auto *Cast = dyn_cast<CastInst>(LU);
+      if (!Cast)
+        continue;
+      Type *Ty = Cast->getType();
+      if (!SmallestType ||
+          Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
+        SmallestType = Ty;
+    }
+  }
+
+  if (!SmallestType)
+    return; // No cast of a load was found; the i8 alloca is left as-is.
+
+  // NOTE(review): may be narrower than values later stored to it (i32 into i16 in conflicting_cast test) -- confirm.
+  IRBuilder<> Builder(AI);
+  auto *NewAlloca = // NOTE(review): AI's alignment is not passed along here.
+      Builder.CreateAlloca(SmallestType);
+  ReplacedValues[AI] = NewAlloca; // fixI8UseChain's ProcessOperands consults this map.
+  ToRemove.push_back(AI);
+}
+
 static void
 downcastI64toI32InsertExtractElements(Instruction &I,
                                       SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +267,8 @@ class DXILLegalizationPipeline {
       LegalizationPipeline;
 
   void initializeLegalizationPipeline() {
-    LegalizationPipeline.push_back(fixI8TruncUseChain);
+    LegalizationPipeline.push_back(upcastI8AllocasAndUses);
+    LegalizationPipeline.push_back(fixI8UseChain);
     LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
     LegalizationPipeline.push_back(legalizeFreeze);
   }
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
new file mode 100644
index 0000000000000..529a69fca5d34
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @const_i8_store() {
+; CHECK-LABEL: define void @const_i8_store() {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 1, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  store i8 1, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
+}
+
+define void @const_add_i8_store() {
+; CHECK-LABEL: define void @const_add_i8_store() {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 4, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  %add_i8 = add nsw i8 3, 1
+  store i8 %add_i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
+}
+
+define void @var_i8_store(i1 %cmp.i8) {
+; CHECK-LABEL: define void @var_i8_store(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP3]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+  store i8 %select.i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
+}
+
+define void @conflicting_cast(i1 %cmp.i8) {
+; CHECK-LABEL: define void @conflicting_cast(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i16, align 2
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i16 [[TMP3]], ptr [[GEP1]], align 2
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store i16 [[TMP3]], ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store i32 [[TMP4]], ptr [[GEP3]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [2 x i32], align 4
+  %i = alloca i8, align 4 ; NOTE(review): CHECKs expect an i32 store (4 bytes) into the new i16 alloca (2 bytes) -- confirm intended
+  %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+  store i8 %select.i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i16 ; narrowest cast of the load, so the alloca is rewritten as i16
+  %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
+  store i16 %z, ptr %gep1, align 2
+  %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
+  store i16 %z, ptr %gep2, align 2
+  %z2 = zext i8 %i8.load to i32
+  %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
+  store i32 %z2, ptr %gep3, align 4
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/137399


More information about the llvm-commits mailing list