[llvm] [DirectX] Start the creation of a DXIL Instruction legalizer (PR #131221)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 13 14:34:03 PDT 2025


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/131221

>From 34939c8663d317afd8956e11d75f9fe7648b3b97 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 27 Feb 2025 18:04:32 -0500
Subject: [PATCH 1/2] [DirectX] Working i8 legalization pass

---
 llvm/lib/Target/DirectX/CMakeLists.txt        |   1 +
 llvm/lib/Target/DirectX/DirectX.h             |   6 +
 .../Target/DirectX/DirectXPassRegistry.def    |   1 +
 .../Target/DirectX/DirectXTargetMachine.cpp   |   3 +
 llvm/lib/Target/DirectX/LegalizeI8Pass.cpp    | 127 ++++++++++++++++++
 llvm/lib/Target/DirectX/LegalizeI8Pass.h      |  23 ++++
 llvm/test/CodeGen/DirectX/legalize-i8.ll      |  16 +++
 7 files changed, 177 insertions(+)
 create mode 100644 llvm/lib/Target/DirectX/LegalizeI8Pass.cpp
 create mode 100644 llvm/lib/Target/DirectX/LegalizeI8Pass.h
 create mode 100644 llvm/test/CodeGen/DirectX/legalize-i8.ll

diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 6904a1c0f1e73..0b3b6a23ce739 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -32,6 +32,7 @@ add_llvm_target(DirectXCodeGen
   DXILShaderFlags.cpp
   DXILTranslateMetadata.cpp
   DXILRootSignature.cpp
+  LegalizeI8Pass.cpp
   
   LINK_COMPONENTS
   Analysis
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 42aa0da16e8aa..482f3b5a8f694 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -47,6 +47,12 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
 /// Pass to flatten arrays into a one dimensional DXIL legal form
 ModulePass *createDXILFlattenArraysLegacyPass();
 
+/// Initializer I8 legalizationPass
+void initializeLegalizeI8LegacyPass(PassRegistry &);
+
+/// Pass to remove i8 truncations
+FunctionPass *createLegalizeI8LegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index aee0a4ff83d43..297c6c10f68a3 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -38,4 +38,5 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
 #define FUNCTION_PASS(NAME, CREATE_PASS)
 #endif
 FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
+FUNCTION_PASS("dxil-legalize-i8", LegalizeI8Pass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 82dc1c6af562a..ec6b69089fb64 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -25,6 +25,7 @@
 #include "DirectX.h"
 #include "DirectXSubtarget.h"
 #include "DirectXTargetTransformInfo.h"
+#include "LegalizeI8Pass.h"
 #include "TargetInfo/DirectXTargetInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/Passes.h"
@@ -52,6 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeDXILDataScalarizationLegacyPass(*PR);
   initializeDXILFlattenArraysLegacyPass(*PR);
   initializeScalarizerLegacyPassPass(*PR);
+  initializeLegalizeI8LegacyPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
   initializeWriteDXILPassPass(*PR);
@@ -100,6 +102,7 @@ class DirectXPassConfig : public TargetPassConfig {
     DxilScalarOptions.ScalarizeLoadStore = true;
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILTranslateMetadataLegacyPass());
+    addPass(createLegalizeI8LegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
   }
diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp b/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp
new file mode 100644
index 0000000000000..8faca2be8048c
--- /dev/null
+++ b/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp
@@ -0,0 +1,127 @@
+//===- LegalizeI8Pass.cpp - A pass that reverts i8 conversions-*- C++ ---*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
+///
+/// \file This file contains a pass to remove i8 truncations.
+///
+//===----------------------------------------------------------------------===//
+#include "DirectX.h"
+#include "LegalizeI8Pass.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <map>
+#include <stack>
+
+#define DEBUG_TYPE "dxil-legalize-i8"
+
+using namespace llvm;
+namespace {
+
+class LegalizeI8Legacy : public FunctionPass {
+
+public:
+  bool runOnFunction(Function &F) override;
+  LegalizeI8Legacy() : FunctionPass(ID) {}
+
+  static char ID; // Pass identification.
+};
+} // namespace
+
+static bool fixI8TruncUseChain(Function &F) {
+    std::stack<Instruction*> ToRemove;
+    std::map<Value*, Value*> ReplacedValues;
+    
+    for (auto &I : instructions(F)) {
+        if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
+            if (Trunc->getDestTy()->isIntegerTy(8)) {
+                ReplacedValues[Trunc] = Trunc->getOperand(0);
+                ToRemove.push(Trunc);
+            }
+        } else if (I.getType()->isIntegerTy(8)) {
+            IRBuilder<> Builder(&I);
+            
+            std::vector<Value*> NewOperands;
+            Type* InstrType = nullptr;
+            for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
+                Value *Op = I.getOperand(OpIdx);
+                if (ReplacedValues.count(Op)) {
+                    InstrType = ReplacedValues[Op]->getType();
+                    NewOperands.push_back(ReplacedValues[Op]);
+                }
+                else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
+                    APInt Value = Imm->getValue();
+                    unsigned NewBitWidth = InstrType->getIntegerBitWidth();
+                    // Note: options here are sext or sextOrTrunc. 
+                    // Since i8 isn't suppport we assume new values
+                    // will always have a higher bitness.
+                    APInt NewValue = Value.sext(NewBitWidth);
+                    NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
+                } else {
+                    assert(!Op->getType()->isIntegerTy(8));
+                    NewOperands.push_back(Op);
+                }
+                
+            }
+            
+            Value *NewInst = nullptr;
+            if (auto *BO = dyn_cast<BinaryOperator>(&I))
+                NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
+            else if (auto *Cmp = dyn_cast<CmpInst>(&I))
+                NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
+            else if (auto *Cast = dyn_cast<CastInst>(&I))
+                NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0], Cast->getDestTy());
+            else if (auto *UnaryOp = dyn_cast<UnaryOperator>(&I))
+                NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]);
+                
+            if (NewInst) {
+                ReplacedValues[&I] = NewInst;
+                ToRemove.push(&I);
+            }
+        } else if (auto *Sext = dyn_cast<SExtInst>(&I)) {
+            if (Sext->getSrcTy()->isIntegerTy(8)) {
+                ToRemove.push(Sext);
+                Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]);
+            }
+        }
+    }
+    
+    while (!ToRemove.empty()) {
+        Instruction *I = ToRemove.top();
+        I->eraseFromParent();
+        ToRemove.pop();
+    }
+    
+    return true;
+}
+
+PreservedAnalyses LegalizeI8Pass::run(Function &F, FunctionAnalysisManager &FAM) {
+    bool MadeChanges = fixI8TruncUseChain(F);
+    if (!MadeChanges)
+      return PreservedAnalyses::all();
+    PreservedAnalyses PA;
+    return PA;
+  }
+  
+  bool LegalizeI8Legacy::runOnFunction(Function &F) {
+    return fixI8TruncUseChain(F);
+  }
+  
+  char LegalizeI8Legacy::ID = 0;
+  
+  INITIALIZE_PASS_BEGIN(LegalizeI8Legacy, DEBUG_TYPE,
+                        "DXIL I8 Legalizer", false, false)
+  INITIALIZE_PASS_END(LegalizeI8Legacy, DEBUG_TYPE, "DXIL I8 Legalizer",
+                      false, false)
+  
+FunctionPass *llvm::createLegalizeI8LegacyPass() {
+    return new LegalizeI8Legacy();
+  }
\ No newline at end of file
diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.h b/llvm/lib/Target/DirectX/LegalizeI8Pass.h
new file mode 100644
index 0000000000000..30ba4a88b8176
--- /dev/null
+++ b/llvm/lib/Target/DirectX/LegalizeI8Pass.h
@@ -0,0 +1,23 @@
+//===- LegalizeI8Pass.h - A pass that reverts i8 conversions-*- C++ -----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_LEGALIZEI8_H
+#define LLVM_TARGET_DIRECTX_LEGALIZEI8_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// A pass that transforms multidimensional arrays into one-dimensional arrays.
+class LegalizeI8Pass : public PassInfoMixin<LegalizeI8Pass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_LEGALIZEI8_H
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll
new file mode 100644
index 0000000000000..d9787531ae1c4
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll
@@ -0,0 +1,16 @@
+; RUN: opt -S -passes='dxil-legalize-i8' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define i32 @i8trunc(float %0) #0 {
+  ; CHECK-NOT: %4 = trunc nsw i32 %3 to i8
+  ; CHECK: add i32
+  ; CHECK: srem i32
+  ; CHECK-NOT: %7 = sext i8 %6 to i32
+  
+  %2 = fptosi float %0 to i32
+  %3 = srem i32 %2, 8
+  %4 = trunc nsw i32 %3 to i8
+  %5 = add nsw i8 %4, 1
+  %6 = srem i8 %5, 8
+  %7 = sext i8 %6 to i32
+  ret i32 %7
+}
\ No newline at end of file

>From 9fd631af64af28917025827606e2155f55d67741 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 13 Mar 2025 16:54:16 -0400
Subject: [PATCH 2/2] modify the i8 legalization pass to be a more generic
 legalization so we can reduce i64 insert/extracts to i32

---
 llvm/lib/Target/DirectX/CMakeLists.txt        |   2 +-
 llvm/lib/Target/DirectX/DXILLegalizePass.cpp  | 198 ++++++++++++++++++
 .../{LegalizeI8Pass.h => DXILLegalizePass.h}  |  11 +-
 llvm/lib/Target/DirectX/DirectX.h             |   9 +-
 .../Target/DirectX/DirectXPassRegistry.def    |   2 +-
 .../Target/DirectX/DirectXTargetMachine.cpp   |   6 +-
 llvm/lib/Target/DirectX/LegalizeI8Pass.cpp    | 127 -----------
 .../legalize-i64-extract-insert-elements.ll   |  24 +++
 llvm/test/CodeGen/DirectX/legalize-i8.ll      |  32 ++-
 llvm/test/CodeGen/DirectX/llc-pipeline.ll     |   1 +
 .../DirectX/llc-vector-load-scalarize.ll      |  32 +--
 .../CodeGen/DirectX/scalarize-two-calls.ll    |  16 +-
 12 files changed, 289 insertions(+), 171 deletions(-)
 create mode 100644 llvm/lib/Target/DirectX/DXILLegalizePass.cpp
 rename llvm/lib/Target/DirectX/{LegalizeI8Pass.h => DXILLegalizePass.h} (55%)
 delete mode 100644 llvm/lib/Target/DirectX/LegalizeI8Pass.cpp
 create mode 100644 llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll

diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 0b3b6a23ce739..13f8adbe4f132 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -32,7 +32,7 @@ add_llvm_target(DirectXCodeGen
   DXILShaderFlags.cpp
   DXILTranslateMetadata.cpp
   DXILRootSignature.cpp
-  LegalizeI8Pass.cpp
+  DXILLegalizePass.cpp
   
   LINK_COMPONENTS
   Analysis
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
new file mode 100644
index 0000000000000..934a6e75ad844
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -0,0 +1,259 @@
+//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL-*- C++----------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains a pass that legalizes LLVM IR for DXIL. It currently:
+/// - removes i8 truncations, re-expressing the i8 binary operators,
+///   comparisons and sign extensions that consume them in terms of the
+///   truncation's wider source value;
+/// - narrows constant i64 extractelement/insertelement indices to i32.
+///
+//===----------------------------------------------------------------------===//
+#include "DXILLegalizePass.h"
+#include "DirectX.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <functional>
+#include <map>
+#include <stack>
+#include <vector>
+
+#define DEBUG_TYPE "dxil-legalize"
+
+using namespace llvm;
+
+/// Returns the value that replaced \p V during legalization, or nullptr if
+/// \p V has not been replaced.
+static Value *
+lookupReplacement(const std::map<Value *, Value *> &ReplacedValues, Value *V) {
+  auto It = ReplacedValues.find(V);
+  return It == ReplacedValues.end() ? nullptr : It->second;
+}
+
+/// Returns the type that every replaced i8 operand of \p I was widened to, or
+/// nullptr if no i8 operand has been replaced or the widened types disagree.
+static Type *
+getWidenedI8Type(Instruction &I,
+                 const std::map<Value *, Value *> &ReplacedValues) {
+  Type *WideTy = nullptr;
+  for (Value *Op : I.operands()) {
+    if (!Op->getType()->isIntegerTy(8))
+      continue;
+    Value *Repl = lookupReplacement(ReplacedValues, Op);
+    if (!Repl)
+      continue;
+    if (WideTy && WideTy != Repl->getType())
+      return nullptr;
+    WideTy = Repl->getType();
+  }
+  return WideTy;
+}
+
+/// Removes i8 truncations. A `trunc ... to i8` is dropped in favor of its
+/// wider source operand; i8 binary operators and i8 comparisons that consume
+/// such values are rebuilt in the wider type, and `sext i8` users are rewired
+/// to the widened value. This treats the wider source as a stand-in for the
+/// truncated value, i.e. it assumes that value fits in i8.
+///
+/// \returns true if the IR was changed while visiting \p I.
+static bool fixI8TruncUseChain(Instruction &I,
+                               std::stack<Instruction *> &ToRemove,
+                               std::map<Value *, Value *> &ReplacedValues) {
+  if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
+    if (!Trunc->getDestTy()->isIntegerTy(8))
+      return false;
+    ReplacedValues[Trunc] = Trunc->getOperand(0);
+    ToRemove.push(Trunc);
+    return true;
+  }
+
+  // Comparisons of i8 values produce i1, so they are recognized by operand
+  // type rather than by result type.
+  bool ProducesI8 = I.getType()->isIntegerTy(8);
+  bool ComparesI8 =
+      isa<CmpInst>(I) && I.getOperand(0)->getType()->isIntegerTy(8);
+  if (ProducesI8 || ComparesI8) {
+    // Without a widened operand there is no type to rebuild the instruction
+    // in (e.g. i8 values that did not come from a legalized trunc).
+    Type *WideTy = getWidenedI8Type(I, ReplacedValues);
+    if (!WideTy)
+      return false;
+
+    std::vector<Value *> NewOperands;
+    for (Value *Op : I.operands()) {
+      if (Value *Repl = lookupReplacement(ReplacedValues, Op)) {
+        NewOperands.push_back(Repl);
+      } else if (Op->getType()->isIntegerTy(8)) {
+        // Only i8 constants can be widened here; any other i8 operand has no
+        // legal replacement, so leave the instruction alone.
+        auto *Imm = dyn_cast<ConstantInt>(Op);
+        if (!Imm)
+          return false;
+        // WideTy is the source type of a trunc to i8, so it is strictly wider
+        // than i8 and sign extension is always a widening conversion.
+        APInt NewValue = Imm->getValue().sext(WideTy->getIntegerBitWidth());
+        NewOperands.push_back(ConstantInt::get(WideTy, NewValue));
+      } else {
+        NewOperands.push_back(Op);
+      }
+    }
+
+    IRBuilder<> Builder(&I);
+    Value *NewInst = nullptr;
+    if (auto *BO = dyn_cast<BinaryOperator>(&I))
+      NewInst =
+          Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
+    else if (auto *Cmp = dyn_cast<CmpInst>(&I))
+      NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
+                                  NewOperands[1]);
+    // Other i8 producers are not rewritten.
+    if (!NewInst)
+      return false;
+
+    // A rebuilt comparison still yields i1, so its users can switch over
+    // immediately. Widened i8 results are picked up via ReplacedValues when
+    // their own users are rewritten.
+    if (NewInst->getType() == I.getType())
+      I.replaceAllUsesWith(NewInst);
+    ReplacedValues[&I] = NewInst;
+    ToRemove.push(&I);
+    return true;
+  }
+
+  if (auto *Sext = dyn_cast<SExtInst>(&I)) {
+    if (!Sext->getSrcTy()->isIntegerTy(8))
+      return false;
+    // An i8 operand that was never widened has nothing to replace it with.
+    Value *Repl = lookupReplacement(ReplacedValues, Sext->getOperand(0));
+    if (!Repl)
+      return false;
+    // The widened value may be wider or narrower than the sext result.
+    if (Repl->getType() != Sext->getDestTy()) {
+      IRBuilder<> Builder(Sext);
+      Repl = Builder.CreateSExtOrTrunc(Repl, Sext->getDestTy());
+    }
+    Sext->replaceAllUsesWith(Repl);
+    ToRemove.push(Sext);
+    return true;
+  }
+
+  return false;
+}
+
+/// Narrows a constant i64 element index of an extractelement or
+/// insertelement to i32. The instruction is updated in place, so nothing is
+/// queued for removal and no replacement is recorded.
+///
+/// \returns true if the IR was changed while visiting \p I.
+static bool
+downcastI64toI32InsertExtractElements(Instruction &I,
+                                      std::stack<Instruction *> &,
+                                      std::map<Value *, Value *> &) {
+  unsigned IdxOperandNo;
+  if (isa<ExtractElementInst>(I))
+    IdxOperandNo = 1;
+  else if (isa<InsertElementInst>(I))
+    IdxOperandNo = 2;
+  else
+    return false;
+
+  auto *CI = dyn_cast<ConstantInt>(I.getOperand(IdxOperandNo));
+  if (!CI || CI->getBitWidth() != 64)
+    return false;
+
+  int64_t IndexValue = CI->getSExtValue();
+  I.setOperand(IdxOperandNo,
+               ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue));
+  return true;
+}
+
+namespace {
+
+/// A single legalization step: inspects one instruction, queues instructions
+/// to erase, records replaced values, and returns true if the IR changed.
+using LegalizationStep =
+    std::function<bool(Instruction &, std::stack<Instruction *> &,
+                       std::map<Value *, Value *> &)>;
+
+class DXILLegalizationPipeline {
+
+public:
+  DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
+
+  bool runLegalizationPipeline(Function &F) {
+    std::stack<Instruction *> ToRemove;
+    std::map<Value *, Value *> ReplacedValues;
+    bool MadeChanges = false;
+    for (auto &I : instructions(F))
+      for (auto &Step : LegalizationPipeline)
+        MadeChanges |= Step(I, ToRemove, ReplacedValues);
+
+    // Popping the stack erases instructions in reverse visiting order, so
+    // rewritten users normally go before the values they use. Anything that
+    // still has users (e.g. an i8 value consumed by an instruction no step
+    // rewrote) is kept so the IR stays well formed.
+    while (!ToRemove.empty()) {
+      Instruction *I = ToRemove.top();
+      ToRemove.pop();
+      if (I->use_empty())
+        I->eraseFromParent();
+    }
+
+    return MadeChanges;
+  }
+
+private:
+  std::vector<LegalizationStep> LegalizationPipeline;
+
+  void initializeLegalizationPipeline() {
+    LegalizationPipeline.push_back(fixI8TruncUseChain);
+    LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
+  }
+};
+
+class DXILLegalizeLegacy : public FunctionPass {
+
+public:
+  bool runOnFunction(Function &F) override;
+  DXILLegalizeLegacy() : FunctionPass(ID) {}
+
+  static char ID; // Pass identification.
+};
+} // namespace
+
+PreservedAnalyses DXILLegalizePass::run(Function &F,
+                                        FunctionAnalysisManager &FAM) {
+  DXILLegalizationPipeline DXLegalize;
+  if (!DXLegalize.runLegalizationPipeline(F))
+    return PreservedAnalyses::all();
+  // No step adds or removes basic blocks or changes branch successors.
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
+
+bool DXILLegalizeLegacy::runOnFunction(Function &F) {
+  DXILLegalizationPipeline DXLegalize;
+  return DXLegalize.runLegalizationPipeline(F);
+}
+
+char DXILLegalizeLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
+                      false)
+INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
+                    false)
+
+FunctionPass *llvm::createDXILLegalizeLegacyPass() {
+  return new DXILLegalizeLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.h b/llvm/lib/Target/DirectX/DXILLegalizePass.h
similarity index 55%
rename from llvm/lib/Target/DirectX/LegalizeI8Pass.h
rename to llvm/lib/Target/DirectX/DXILLegalizePass.h
index 30ba4a88b8176..39ef6f532dca0 100644
--- a/llvm/lib/Target/DirectX/LegalizeI8Pass.h
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.h
@@ -1,4 +1,4 @@
-//===- LegalizeI8Pass.h - A pass that reverts i8 conversions-*- C++ -----*-===//
+//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL-*- C++------------*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,17 @@
 //
 //===---------------------------------------------------------------------===//
 
-#ifndef LLVM_TARGET_DIRECTX_LEGALIZEI8_H
-#define LLVM_TARGET_DIRECTX_LEGALIZEI8_H
+#ifndef LLVM_TARGET_DIRECTX_LEGALIZE_H
+#define LLVM_TARGET_DIRECTX_LEGALIZE_H
 
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
 
-/// A pass that transforms multidimensional arrays into one-dimensional arrays.
-class LegalizeI8Pass : public PassInfoMixin<LegalizeI8Pass> {
+class DXILLegalizePass : public PassInfoMixin<DXILLegalizePass> {
 public:
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
 };
 } // namespace llvm
 
-#endif // LLVM_TARGET_DIRECTX_LEGALIZEI8_H
+#endif // LLVM_TARGET_DIRECTX_LEGALIZE_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 482f3b5a8f694..96a8a08c875f8 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -47,11 +47,12 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
 /// Pass to flatten arrays into a one dimensional DXIL legal form
 ModulePass *createDXILFlattenArraysLegacyPass();
 
-/// Initializer I8 legalizationPass
-void initializeLegalizeI8LegacyPass(PassRegistry &);
+/// Initializer for DXILLegalize
+void initializeDXILLegalizeLegacyPass(PassRegistry &);
 
-/// Pass to remove i8 truncations
-FunctionPass *createLegalizeI8LegacyPass();
+/// Pass to legalize DXIL by removing i8 truncations and narrowing i64
+/// insert/extract element indices to i32
+FunctionPass *createDXILLegalizeLegacyPass();
 
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index 297c6c10f68a3..87d91ead1896f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -38,5 +38,5 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
 #define FUNCTION_PASS(NAME, CREATE_PASS)
 #endif
 FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
-FUNCTION_PASS("dxil-legalize-i8", LegalizeI8Pass())
+FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index ec6b69089fb64..ce408b4034f83 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -15,6 +15,7 @@
 #include "DXILDataScalarization.h"
 #include "DXILFlattenArrays.h"
 #include "DXILIntrinsicExpansion.h"
+#include "DXILLegalizePass.h"
 #include "DXILOpLowering.h"
 #include "DXILPrettyPrinter.h"
 #include "DXILResourceAccess.h"
@@ -25,7 +26,6 @@
 #include "DirectX.h"
 #include "DirectXSubtarget.h"
 #include "DirectXTargetTransformInfo.h"
-#include "LegalizeI8Pass.h"
 #include "TargetInfo/DirectXTargetInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/Passes.h"
@@ -53,7 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeDXILDataScalarizationLegacyPass(*PR);
   initializeDXILFlattenArraysLegacyPass(*PR);
   initializeScalarizerLegacyPassPass(*PR);
-  initializeLegalizeI8LegacyPass(*PR);
+  initializeDXILLegalizeLegacyPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
   initializeWriteDXILPassPass(*PR);
@@ -101,8 +101,8 @@ class DirectXPassConfig : public TargetPassConfig {
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
     addPass(createScalarizerPass(DxilScalarOptions));
+    addPass(createDXILLegalizeLegacyPass());
     addPass(createDXILTranslateMetadataLegacyPass());
-    addPass(createLegalizeI8LegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
   }
diff --git a/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp b/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp
deleted file mode 100644
index 8faca2be8048c..0000000000000
--- a/llvm/lib/Target/DirectX/LegalizeI8Pass.cpp
+++ /dev/null
@@ -1,127 +0,0 @@
-//===- LegalizeI8Pass.cpp - A pass that reverts i8 conversions-*- C++ ---*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===---------------------------------------------------------------------===//
-//===---------------------------------------------------------------------===//
-///
-/// \file This file contains a pass to remove i8 truncations.
-///
-//===----------------------------------------------------------------------===//
-#include "DirectX.h"
-#include "LegalizeI8Pass.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/InstIterator.h"
-#include "llvm/IR/Instruction.h"
-#include "llvm/Pass.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include <map>
-#include <stack>
-
-#define DEBUG_TYPE "dxil-legalize-i8"
-
-using namespace llvm;
-namespace {
-
-class LegalizeI8Legacy : public FunctionPass {
-
-public:
-  bool runOnFunction(Function &F) override;
-  LegalizeI8Legacy() : FunctionPass(ID) {}
-
-  static char ID; // Pass identification.
-};
-} // namespace
-
-static bool fixI8TruncUseChain(Function &F) {
-    std::stack<Instruction*> ToRemove;
-    std::map<Value*, Value*> ReplacedValues;
-    
-    for (auto &I : instructions(F)) {
-        if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
-            if (Trunc->getDestTy()->isIntegerTy(8)) {
-                ReplacedValues[Trunc] = Trunc->getOperand(0);
-                ToRemove.push(Trunc);
-            }
-        } else if (I.getType()->isIntegerTy(8)) {
-            IRBuilder<> Builder(&I);
-            
-            std::vector<Value*> NewOperands;
-            Type* InstrType = nullptr;
-            for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
-                Value *Op = I.getOperand(OpIdx);
-                if (ReplacedValues.count(Op)) {
-                    InstrType = ReplacedValues[Op]->getType();
-                    NewOperands.push_back(ReplacedValues[Op]);
-                }
-                else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
-                    APInt Value = Imm->getValue();
-                    unsigned NewBitWidth = InstrType->getIntegerBitWidth();
-                    // Note: options here are sext or sextOrTrunc. 
-                    // Since i8 isn't suppport we assume new values
-                    // will always have a higher bitness.
-                    APInt NewValue = Value.sext(NewBitWidth);
-                    NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
-                } else {
-                    assert(!Op->getType()->isIntegerTy(8));
-                    NewOperands.push_back(Op);
-                }
-                
-            }
-            
-            Value *NewInst = nullptr;
-            if (auto *BO = dyn_cast<BinaryOperator>(&I))
-                NewInst = Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
-            else if (auto *Cmp = dyn_cast<CmpInst>(&I))
-                NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
-            else if (auto *Cast = dyn_cast<CastInst>(&I))
-                NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0], Cast->getDestTy());
-            else if (auto *UnaryOp = dyn_cast<UnaryOperator>(&I))
-                NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]);
-                
-            if (NewInst) {
-                ReplacedValues[&I] = NewInst;
-                ToRemove.push(&I);
-            }
-        } else if (auto *Sext = dyn_cast<SExtInst>(&I)) {
-            if (Sext->getSrcTy()->isIntegerTy(8)) {
-                ToRemove.push(Sext);
-                Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]);
-            }
-        }
-    }
-    
-    while (!ToRemove.empty()) {
-        Instruction *I = ToRemove.top();
-        I->eraseFromParent();
-        ToRemove.pop();
-    }
-    
-    return true;
-}
-
-PreservedAnalyses LegalizeI8Pass::run(Function &F, FunctionAnalysisManager &FAM) {
-    bool MadeChanges = fixI8TruncUseChain(F);
-    if (!MadeChanges)
-      return PreservedAnalyses::all();
-    PreservedAnalyses PA;
-    return PA;
-  }
-  
-  bool LegalizeI8Legacy::runOnFunction(Function &F) {
-    return fixI8TruncUseChain(F);
-  }
-  
-  char LegalizeI8Legacy::ID = 0;
-  
-  INITIALIZE_PASS_BEGIN(LegalizeI8Legacy, DEBUG_TYPE,
-                        "DXIL I8 Legalizer", false, false)
-  INITIALIZE_PASS_END(LegalizeI8Legacy, DEBUG_TYPE, "DXIL I8 Legalizer",
-                      false, false)
-  
-FunctionPass *llvm::createLegalizeI8LegacyPass() {
-    return new LegalizeI8Legacy();
-  }
\ No newline at end of file
diff --git a/llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll b/llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll
new file mode 100644
index 0000000000000..8a59986524c90
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll
@@ -0,0 +1,24 @@
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define noundef <4 x float> @float4_extract(<4 x float> noundef %a) {
+entry:
+  ; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0
+  ; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1
+  ; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2
+  ; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3
+  ; CHECK: insertelement <4 x float> poison, float [[ee0]], i32 0
+  ; CHECK: insertelement <4 x float> %{{.*}}, float [[ee1]], i32 1
+  ; CHECK: insertelement <4 x float> %{{.*}}, float [[ee2]], i32 2
+  ; CHECK: insertelement <4 x float> %{{.*}}, float [[ee3]], i32 3
+
+  %a.i0 = extractelement <4 x float> %a, i64 0
+  %a.i1 = extractelement <4 x float> %a, i64 1
+  %a.i2 = extractelement <4 x float> %a, i64 2
+  %a.i3 = extractelement <4 x float> %a, i64 3
+  
+  %.upto0 = insertelement <4 x float> poison, float %a.i0, i64 0
+  %.upto1 = insertelement <4 x float> %.upto0, float %a.i1, i64 1
+  %.upto2 = insertelement <4 x float> %.upto1, float %a.i2, i64 2
+  %0 = insertelement <4 x float> %.upto2, float %a.i3, i64 3
+  ret <4 x float> %0
+}
diff --git a/llvm/test/CodeGen/DirectX/legalize-i8.ll b/llvm/test/CodeGen/DirectX/legalize-i8.ll
index d9787531ae1c4..d7cea585a056b 100644
--- a/llvm/test/CodeGen/DirectX/legalize-i8.ll
+++ b/llvm/test/CodeGen/DirectX/legalize-i8.ll
@@ -1,9 +1,21 @@
-; RUN: opt -S -passes='dxil-legalize-i8' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 define i32 @i8trunc(float %0) #0 {
   ; CHECK-NOT: %4 = trunc nsw i32 %3 to i8
   ; CHECK: add i32
-  ; CHECK: srem i32
+  ; CHECK-NEXT: srem i32
+  ; CHECK-NEXT: sub i32
+  ; CHECK-NEXT: mul i32
+  ; CHECK-NEXT: udiv i32
+  ; CHECK-NEXT: sdiv i32
+  ; CHECK-NEXT: urem i32
+  ; CHECK-NEXT: and i32
+  ; CHECK-NEXT: or i32
+  ; CHECK-NEXT: xor i32
+  ; CHECK-NEXT: shl i32
+  ; CHECK-NEXT: lshr i32
+  ; CHECK-NEXT: ashr i32
-  ; CHECK-NOT: %7 = sext i8 %6 to i32
+  ; Match any leftover i8 sext, not a value name that shifts as the test grows.
+  ; CHECK-NOT: sext i8
   
   %2 = fptosi float %0 to i32
@@ -11,6 +23,17 @@ define i32 @i8trunc(float %0) #0 {
   %4 = trunc nsw i32 %3 to i8
   %5 = add nsw i8 %4, 1
   %6 = srem i8 %5, 8
-  %7 = sext i8 %6 to i32
-  ret i32 %7
-}
\ No newline at end of file
+  %7 = sub i8 %6, 1
+  %8 = mul i8 %7, 1
+  %9 = udiv i8 %8, 1
+  %10 = sdiv i8 %9, 1
+  %11 = urem i8 %10, 1
+  %12 = and i8 %11, 1
+  %13 = or i8 %12, 1
+  %14 = xor i8 %13, 1
+  %15 = shl i8 %14, 1
+  %16 = lshr i8 %15, 1
+  %17 = ashr i8 %16, 1
+  %18 = sext i8 %17 to i32
+  ret i32 %18
+}
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 3a9af4d744f98..ee70cec534bc5 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -21,6 +21,7 @@
 ; CHECK-NEXT:     DXIL Resource Access
 ; CHECK-NEXT:     Dominator Tree Construction
 ; CHECK-NEXT:     Scalarize vector operations
+; CHECK-NEXT:   DXIL Legalizer
 ; CHECK-NEXT:   DXIL Resource Binding Analysis
 ; CHECK-NEXT:   DXIL Module Metadata analysis
 ; CHECK-NEXT:   DXIL Shader Flag Analysis
diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
index 4e522c6ef5da7..7e5a92e1311f8 100644
--- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
+++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
@@ -44,10 +44,10 @@ define <4 x i32> @load_array_vec_test() #0 {
 ; CHECK-NEXT:    [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
 ; CHECK-NEXT:    [[DOTI210:%.*]] = add i32 [[TMP6]], [[DOTI25]]
 ; CHECK-NEXT:    [[DOTI311:%.*]] = add i32 [[TMP8]], [[DOTI37]]
-; CHECK-NEXT:    [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i64 0
-; CHECK-NEXT:    [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i64 1
-; CHECK-NEXT:    [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i64 2
-; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i64 3
+; CHECK-NEXT:    [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i32 0
+; CHECK-NEXT:    [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i32 1
+; CHECK-NEXT:    [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i32 2
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i32 3
 ; CHECK-NEXT:    ret <4 x i32> [[TMP16]]
 ;
   %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([2 x <4 x i32>], [2 x <4 x i32>] addrspace(3)* @"arrayofVecData", i32 0, i32 0), align 4
@@ -68,10 +68,10 @@ define <4 x i32> @load_vec_test() #0 {
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @vecData.scalarized, i32 3) to ptr addrspace(3)
 ; CHECK-NEXT:    [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
-; CHECK-NEXT:    [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i64 0
-; CHECK-NEXT:    [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[TMP4]], i64 1
-; CHECK-NEXT:    [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[TMP6]], i64 2
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[TMP8]], i64 3
+; CHECK-NEXT:    [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i32 0
+; CHECK-NEXT:    [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[TMP4]], i32 1
+; CHECK-NEXT:    [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[TMP6]], i32 2
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[TMP8]], i32 3
 ; CHECK-NEXT:    ret <4 x i32> [[TMP9]]
 ;
   %1 = load <4 x i32>, <4 x i32> addrspace(3)* @"vecData", align 4
@@ -93,10 +93,10 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
 ; CHECK-NEXT:    [[TMP5:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
 ; CHECK-NEXT:    [[DOTFLAT_I3:%.*]] = getelementptr i32, ptr [[TMP5]], i32 3
 ; CHECK-NEXT:    [[DOTI3:%.*]] = load i32, ptr [[DOTFLAT_I3]], align 4
-; CHECK-NEXT:    [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i64 0
-; CHECK-NEXT:    [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[DOTI1]], i64 1
-; CHECK-NEXT:    [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[DOTI2]], i64 2
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[DOTI3]], i64 3
+; CHECK-NEXT:    [[DOTUPTO0:%.*]] = insertelement <4 x i32> poison, i32 [[TMP2]], i32 0
+; CHECK-NEXT:    [[DOTUPTO1:%.*]] = insertelement <4 x i32> [[DOTUPTO0]], i32 [[DOTI1]], i32 1
+; CHECK-NEXT:    [[DOTUPTO2:%.*]] = insertelement <4 x i32> [[DOTUPTO1]], i32 [[DOTI2]], i32 2
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i32> [[DOTUPTO2]], i32 [[DOTI3]], i32 3
 ; CHECK-NEXT:    ret <4 x i32> [[TMP6]]
 ;
   %3 = getelementptr inbounds [3 x <4 x i32>], [3 x <4 x i32>]* @staticArrayOfVecData, i32 0, i32 %index
@@ -127,10 +127,10 @@ define <4 x i32> @multid_load_test() #0 {
 ; CHECK-NEXT:    [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
 ; CHECK-NEXT:    [[DOTI210:%.*]] = add i32 [[TMP6]], [[DOTI25]]
 ; CHECK-NEXT:    [[DOTI311:%.*]] = add i32 [[TMP8]], [[DOTI37]]
-; CHECK-NEXT:    [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i64 0
-; CHECK-NEXT:    [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i64 1
-; CHECK-NEXT:    [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i64 2
-; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i64 3
+; CHECK-NEXT:    [[DOTUPTO015:%.*]] = insertelement <4 x i32> poison, i32 [[DOTI08]], i32 0
+; CHECK-NEXT:    [[DOTUPTO116:%.*]] = insertelement <4 x i32> [[DOTUPTO015]], i32 [[DOTI19]], i32 1
+; CHECK-NEXT:    [[DOTUPTO217:%.*]] = insertelement <4 x i32> [[DOTUPTO116]], i32 [[DOTI210]], i32 2
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i32 3
 ; CHECK-NEXT:    ret <4 x i32> [[TMP16]]
 ;
   %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4
diff --git a/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll b/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll
index 0546a5505416f..7e8f58c0576f0 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-two-calls.ll
@@ -3,22 +3,22 @@
 ; CHECK: target triple = "dxilv1.3-pc-shadermodel6.3-library"
 ; CHECK-LABEL: cos_sin_float_test
 define noundef <4 x float> @cos_sin_float_test(<4 x float> noundef %a) #0 {
-    ; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i64 0
+    ; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0
     ; CHECK: [[ie0:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee0]])
-    ; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i64 1
+    ; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1
     ; CHECK: [[ie1:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee1]])
-    ; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i64 2
+    ; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2
     ; CHECK: [[ie2:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee2]])
-    ; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i64 3
+    ; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3
     ; CHECK: [[ie3:%.*]] = call float @dx.op.unary.f32(i32 13, float [[ee3]])
     ; CHECK: [[ie4:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie0]])
     ; CHECK: [[ie5:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie1]])
     ; CHECK: [[ie6:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie2]])
     ; CHECK: [[ie7:%.*]] = call float @dx.op.unary.f32(i32 12, float [[ie3]])
-    ; CHECK: insertelement <4 x float> poison, float [[ie4]], i64 0
-    ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie5]], i64 1
-    ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie6]], i64 2
-    ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie7]], i64 3
+    ; CHECK: insertelement <4 x float> poison, float [[ie4]], i32 0
+    ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie5]], i32 1
+    ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie6]], i32 2
+    ; CHECK: insertelement <4 x float> %{{.*}}, float [[ie7]], i32 3
     %2 = tail call <4 x float> @llvm.sin.v4f32(<4 x float> %a) 
     %3 = tail call <4 x float> @llvm.cos.v4f32(<4 x float> %2) 
     ret <4 x float> %3 



More information about the llvm-commits mailing list