[llvm] [DirectX] Legalize i8 allocas (PR #137399)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 25 14:19:33 PDT 2025


https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/137399

fixes #137202

While investigating i8 allocas, I found that our i8 legalization was missing handling for load, store, and select instructions.
This patch adds those three.

To legalize i8 allocas correctly, though, we need to walk the alloca's uses and find the casts applied to the values loaded from it.

After finding the casts, I chose the smallest cast destination type as the type to promote the alloca to. That lets the larger casts that come later be preserved; they now extend from the promoted type instead of from i8.

>From de3024d82bbefbc84cbeadfe32e7b3d1eedf6244 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 25 Apr 2025 16:51:23 -0400
Subject: [PATCH 1/2] [DirectX] Legalize i8 allocas

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp  | 96 +++++++++++++++++--
 .../CodeGen/DirectX/legalize-i8-alloca.ll     | 53 ++++++++++
 2 files changed, 140 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index b62ff4c52f70c..f4e443543c728 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -12,6 +12,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
   ToRemove.push_back(FI);
 }
 
-static void fixI8TruncUseChain(Instruction &I,
-                               SmallVectorImpl<Instruction *> &ToRemove,
-                               DenseMap<Value *, Value *> &ReplacedValues) {
+static void fixI8UseChain(Instruction &I,
+                          SmallVectorImpl<Instruction *> &ToRemove,
+                          DenseMap<Value *, Value *> &ReplacedValues) {
 
   auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
     Type *InstrType = IntegerType::get(I.getContext(), 32);
 
     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
       Value *Op = I.getOperand(OpIdx);
-      if (ReplacedValues.count(Op))
+      if (ReplacedValues.count(Op) &&
+          ReplacedValues[Op]->getType()->isIntegerTy())
         InstrType = ReplacedValues[Op]->getType();
     }
 
@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
     }
   }
 
+  if (auto *Store = dyn_cast<StoreInst>(&I)) {
+    if (!Store->getValueOperand()->getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
+    ReplacedValues[Store] = NewStore;
+    ToRemove.push_back(Store);
+    return;
+  }
+
+  if (auto *Load = dyn_cast<LoadInst>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Type *ElementType = NewOperands[0]->getType();
+    if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
+      ElementType = AI->getAllocatedType();
+    LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
+    ReplacedValues[Load] = NewLoad;
+    ToRemove.push_back(Load);
+    return;
+  }
+
   if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
     if (!I.getType()->isIntegerTy(8))
       return;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
     Value *NewInst =
         Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
-      if (OBO->hasNoSignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
-      if (OBO->hasNoUnsignedWrap())
-        cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
+      auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
+      if (NewBO && OBO->hasNoSignedWrap())
+        NewBO->setHasNoSignedWrap();
+      if (NewBO && OBO->hasNoUnsignedWrap())
+        NewBO->setHasNoUnsignedWrap();
     }
     ReplacedValues[BO] = NewInst;
     ToRemove.push_back(BO);
     return;
   }
 
+  if (auto *Sel = dyn_cast<SelectInst>(&I)) {
+    if (!I.getType()->isIntegerTy(8))
+      return;
+    SmallVector<Value *> NewOperands;
+    ProcessOperands(NewOperands);
+    Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
+                                          NewOperands[2]);
+    ReplacedValues[Sel] = NewInst;
+    ToRemove.push_back(Sel);
+    return;
+  }
+
   if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
     if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
       return;
@@ -105,6 +145,7 @@ static void fixI8TruncUseChain(Instruction &I,
   }
 
   if (auto *Cast = dyn_cast<CastInst>(&I)) {
+
     if (Cast->getSrcTy()->isIntegerTy(8)) {
       ToRemove.push_back(Cast);
       Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
@@ -112,6 +153,42 @@ static void fixI8TruncUseChain(Instruction &I,
   }
 }
 
+static void upcastI8AllocasAndUses(Instruction &I,
+                                   SmallVectorImpl<Instruction *> &ToRemove,
+                                   DenseMap<Value *, Value *> &ReplacedValues) {
+  auto *AI = dyn_cast<AllocaInst>(&I);
+  if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
+    return;
+
+  std::optional<Type *> TargetType;
+  bool Conflict = false;
+  for (User *U : AI->users()) {
+    auto *Load = dyn_cast<LoadInst>(U);
+    if (!Load)
+      continue;
+    for (User *LU : Load->users()) {
+      auto *Cast = dyn_cast<CastInst>(LU);
+      if (!Cast)
+        continue;
+      Type *T = Cast->getType();
+      if (!TargetType)
+        TargetType = T;
+
+      if (TargetType.value() != T) {
+        Conflict = true;
+        break;
+      }
+    }
+  }
+  if (!TargetType || Conflict)
+    return;
+
+  IRBuilder<> Builder(AI);
+  AllocaInst *NewAlloca = Builder.CreateAlloca(TargetType.value());
+  ReplacedValues[AI] = NewAlloca;
+  ToRemove.push_back(AI);
+}
+
 static void
 downcastI64toI32InsertExtractElements(Instruction &I,
                                       SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +255,8 @@ class DXILLegalizationPipeline {
       LegalizationPipeline;
 
   void initializeLegalizationPipeline() {
-    LegalizationPipeline.push_back(fixI8TruncUseChain);
+    LegalizationPipeline.push_back(upcastI8AllocasAndUses);
+    LegalizationPipeline.push_back(fixI8UseChain);
     LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
     LegalizationPipeline.push_back(legalizeFreeze);
   }
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
new file mode 100644
index 0000000000000..a34b9be300a38
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
@@ -0,0 +1,53 @@
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @const_i8_store() {
+    %accum.i.flat = alloca [1 x i32], align 4
+    %i = alloca i8, align 4
+    store i8 1, ptr %i
+    %i8.load = load i8, ptr %i
+    %z = zext i8 %i8.load to i32
+    %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+    store i32 %z, ptr %gep, align 4
+    ret void
+}
+
+define void @const_add_i8_store() {
+    %accum.i.flat = alloca [1 x i32], align 4
+    %i = alloca i8, align 4
+    %add_i8 = add nsw i8 3, 1
+    store i8 %add_i8, ptr %i
+    %i8.load = load i8, ptr %i
+    %z = zext i8 %i8.load to i32
+    %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+    store i32 %z, ptr %gep, align 4
+    ret void
+}
+
+define void @var_i8_store(i1 %cmp.i8) {
+    %accum.i.flat = alloca [1 x i32], align 4
+    %i = alloca i8, align 4
+    %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+    store i8 %select.i8, ptr %i
+    %i8.load = load i8, ptr %i
+    %z = zext i8 %i8.load to i32
+    %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+    store i32 %z, ptr %gep, align 4
+    ret void
+}
+
+define void @conflicting_cast(i1 %cmp.i8) {
+    %accum.i.flat = alloca [2 x i32], align 4
+    %i = alloca i8, align 4
+    %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+    store i8 %select.i8, ptr %i
+    %i8.load = load i8, ptr %i
+    %z = zext i8 %i8.load to i16
+    %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
+    store i16 %z, ptr %gep1, align 2
+    %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
+    store i16 %z, ptr %gep2, align 2
+    %z2 = zext i8 %i8.load to i32
+    %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
+    store i32 %z2, ptr %gep3, align 4
+    ret void
+}
\ No newline at end of file

>From c03a0f1f7e3ee0d65aaae515d0a8c4ecff7cf62b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 25 Apr 2025 17:14:05 -0400
Subject: [PATCH 2/2] Instead of detecting conflicts, let's pick the smallest
 type for the alloca, then keep the cast but change its input type.

---
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp  |  46 ++++---
 .../CodeGen/DirectX/legalize-i8-alloca.ll     | 128 ++++++++++++------
 2 files changed, 116 insertions(+), 58 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index f4e443543c728..b7b209fcecbc9 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -145,11 +145,23 @@ static void fixI8UseChain(Instruction &I,
   }
 
   if (auto *Cast = dyn_cast<CastInst>(&I)) {
-
-    if (Cast->getSrcTy()->isIntegerTy(8)) {
-      ToRemove.push_back(Cast);
-      Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
+    if (!Cast->getSrcTy()->isIntegerTy(8))
+      return;
+    
+    ToRemove.push_back(Cast);
+    auto* Replacement =ReplacedValues[Cast->getOperand(0)];
+    if (Cast->getType() == Replacement->getType()) {
+      Cast->replaceAllUsesWith(Replacement);
+      return;
     }
+    Value* AdjustedCast = nullptr;
+    if (Cast->getOpcode() == Instruction::ZExt)
+      AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
+    if (Cast->getOpcode() == Instruction::SExt)
+      AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
+  
+    if(AdjustedCast)
+      Cast->replaceAllUsesWith(AdjustedCast);
   }
 }
 
@@ -160,8 +172,9 @@ static void upcastI8AllocasAndUses(Instruction &I,
   if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
     return;
 
-  std::optional<Type *> TargetType;
-  bool Conflict = false;
+  Type *SmallestType = nullptr;
+
+  // Gather all cast targets
   for (User *U : AI->users()) {
     auto *Load = dyn_cast<LoadInst>(U);
     if (!Load)
@@ -170,21 +183,20 @@ static void upcastI8AllocasAndUses(Instruction &I,
       auto *Cast = dyn_cast<CastInst>(LU);
       if (!Cast)
         continue;
-      Type *T = Cast->getType();
-      if (!TargetType)
-        TargetType = T;
-
-      if (TargetType.value() != T) {
-        Conflict = true;
-        break;
-      }
+      Type *Ty = Cast->getType();
+      if (!SmallestType ||
+          Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
+        SmallestType = Ty;
     }
   }
-  if (!TargetType || Conflict)
-    return;
 
+  if (!SmallestType)
+    return; // no valid casts found
+
+  // Replace alloca
   IRBuilder<> Builder(AI);
-  AllocaInst *NewAlloca = Builder.CreateAlloca(TargetType.value());
+  auto *NewAlloca =
+      Builder.CreateAlloca(SmallestType);
   ReplacedValues[AI] = NewAlloca;
   ToRemove.push_back(AI);
 }
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
index a34b9be300a38..529a69fca5d34 100644
--- a/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/legalize-i8-alloca.ll
@@ -1,53 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 define void @const_i8_store() {
-    %accum.i.flat = alloca [1 x i32], align 4
-    %i = alloca i8, align 4
-    store i8 1, ptr %i
-    %i8.load = load i8, ptr %i
-    %z = zext i8 %i8.load to i32
-    %gep = getelementptr i32, ptr %accum.i.flat, i32 0
-    store i32 %z, ptr %gep, align 4
-    ret void
+; CHECK-LABEL: define void @const_i8_store() {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 1, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  store i8 1, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
 }
 
 define void @const_add_i8_store() {
-    %accum.i.flat = alloca [1 x i32], align 4
-    %i = alloca i8, align 4
-    %add_i8 = add nsw i8 3, 1
-    store i8 %add_i8, ptr %i
-    %i8.load = load i8, ptr %i
-    %z = zext i8 %i8.load to i32
-    %gep = getelementptr i32, ptr %accum.i.flat, i32 0
-    store i32 %z, ptr %gep, align 4
-    ret void
+; CHECK-LABEL: define void @const_add_i8_store() {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 4, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  %add_i8 = add nsw i8 3, 1
+  store i8 %add_i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
 }
 
 define void @var_i8_store(i1 %cmp.i8) {
-    %accum.i.flat = alloca [1 x i32], align 4
-    %i = alloca i8, align 4
-    %select.i8 = select i1 %cmp.i8, i8 1, i8 2
-    store i8 %select.i8, ptr %i
-    %i8.load = load i8, ptr %i
-    %z = zext i8 %i8.load to i32
-    %gep = getelementptr i32, ptr %accum.i.flat, i32 0
-    store i32 %z, ptr %gep, align 4
-    ret void
+; CHECK-LABEL: define void @var_i8_store(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[TMP3]], ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+  store i8 %select.i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i32
+  %gep = getelementptr i32, ptr %accum.i.flat, i32 0
+  store i32 %z, ptr %gep, align 4
+  ret void
 }
 
 define void @conflicting_cast(i1 %cmp.i8) {
-    %accum.i.flat = alloca [2 x i32], align 4
-    %i = alloca i8, align 4
-    %select.i8 = select i1 %cmp.i8, i8 1, i8 2
-    store i8 %select.i8, ptr %i
-    %i8.load = load i8, ptr %i
-    %z = zext i8 %i8.load to i16
-    %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
-    store i16 %z, ptr %gep1, align 2
-    %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
-    store i16 %z, ptr %gep2, align 2
-    %z2 = zext i8 %i8.load to i32
-    %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
-    store i32 %z2, ptr %gep3, align 4
-    ret void
-}
\ No newline at end of file
+; CHECK-LABEL: define void @conflicting_cast(
+; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca i16, align 2
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
+; CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i16 [[TMP3]], ptr [[GEP1]], align 2
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store i16 [[TMP3]], ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store i32 [[TMP4]], ptr [[GEP3]], align 4
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [2 x i32], align 4
+  %i = alloca i8, align 4
+  %select.i8 = select i1 %cmp.i8, i8 1, i8 2
+  store i8 %select.i8, ptr %i
+  %i8.load = load i8, ptr %i
+  %z = zext i8 %i8.load to i16
+  %gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
+  store i16 %z, ptr %gep1, align 2
+  %gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
+  store i16 %z, ptr %gep2, align 2
+  %z2 = zext i8 %i8.load to i32
+  %gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
+  store i32 %z2, ptr %gep3, align 4
+  ret void
+}



More information about the llvm-commits mailing list