[llvm] [DirectX] Start the creation of a DXIL Instruction legalizer (PR #131221)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 13 19:03:00 PDT 2025
================
@@ -0,0 +1,201 @@
+//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL-*- C++----------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
+///
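+/// \file This file contains a pass that removes i8 truncations (redoing the
+/// affected arithmetic on the wider type) and rewrites extractelement and
+/// insertelement instructions with i64 constant indices to use i32 indices.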
+///
+//===----------------------------------------------------------------------===//
+#include "DXILLegalizePass.h"
+#include "DirectX.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <functional>
+#include <map>
+#include <stack>
+#include <vector>
+
+#define DEBUG_TYPE "dxil-legalize"
+
+using namespace llvm;
+namespace {
+
+/// Legalizes i8 values, which DXIL does not support, by redoing the
+/// computation on the wider integer that was truncated to i8.
+///
+/// - `trunc ... to i8` is recorded in \p ReplacedValues as an alias for its
+///   (wider) source operand and queued for removal.
+/// - An instruction producing i8 (binary/unary operator or cast), or a compare
+///   whose operands are i8, is re-created on the recorded wide operands and
+///   queued for removal. Integer constant operands are sign-extended to the
+///   wide type. Only compares have their uses rewired here; users of an i8
+///   result are expected to pick up the new value from \p ReplacedValues when
+///   they are visited.
+/// - `sext i8 ... to iN` is replaced by the recorded wide value, adjusted to
+///   iN when the widths differ.
+///
+/// An i8 instruction none of whose operands appear in \p ReplacedValues is
+/// left untouched.
+///
+/// \param I              The instruction being visited.
+/// \param ToRemove       Instructions to erase after the whole function has
+///                       been visited.
+/// \param ReplacedValues Maps each removed i8 value to its wide stand-in.
+/// \returns true if \p ToRemove is non-empty. This reflects every call made
+///          so far, not only this one.
+static bool fixI8TruncUseChain(Instruction &I,
+                               std::stack<Instruction *> &ToRemove,
+                               std::map<Value *, Value *> &ReplacedValues) {
+  // A compare yields i1 even when it compares i8 values, so it has to be
+  // recognized by its operand type rather than its result type.
+  const bool IsI8Cmp =
+      isa<CmpInst>(&I) && I.getOperand(0)->getType()->isIntegerTy(8);
+
+  if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
+    if (Trunc->getDestTy()->isIntegerTy(8)) {
+      ReplacedValues[Trunc] = Trunc->getOperand(0);
+      ToRemove.push(Trunc);
+    }
+  } else if (I.getType()->isIntegerTy(8) || IsI8Cmp) {
+    // Pick the wide type before visiting operands: a constant may come before
+    // the first replaced operand (e.g. `sub i8 1, %x`) and still needs to be
+    // widened to that type.
+    Type *InstrType = nullptr;
+    for (Value *Op : I.operands()) {
+      auto It = ReplacedValues.find(Op);
+      if (It != ReplacedValues.end()) {
+        InstrType = It->second->getType();
+        break;
+      }
+    }
+    // Nothing here depends on a value queued for removal, so the instruction
+    // stays valid as-is.
+    if (!InstrType)
+      return !ToRemove.empty();
+
+    IRBuilder<> Builder(&I);
+    std::vector<Value *> NewOperands;
+    NewOperands.reserve(I.getNumOperands());
+    for (Value *Op : I.operands()) {
+      auto It = ReplacedValues.find(Op);
+      if (It != ReplacedValues.end()) {
+        NewOperands.push_back(It->second);
+      } else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
+        // Note: options here are sext or sextOrTrunc. Since i8 isn't
+        // supported, we assume new values will always have a higher bitness.
+        APInt NewValue = Imm->getValue().sext(InstrType->getIntegerBitWidth());
+        NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
+      } else {
+        assert(!Op->getType()->isIntegerTy(8) && "unlegalized i8 operand");
+        NewOperands.push_back(Op);
+      }
+    }
+
+    Value *NewInst = nullptr;
+    if (auto *BO = dyn_cast<BinaryOperator>(&I))
+      NewInst =
+          Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
+    else if (auto *Cmp = dyn_cast<CmpInst>(&I))
+      NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
+                                  NewOperands[1]);
+    else if (auto *Cast = dyn_cast<CastInst>(&I))
+      NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0],
+                                   Cast->getDestTy());
+    else if (auto *UnaryOp = dyn_cast<UnaryOperator>(&I))
+      NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]);
+
+    if (NewInst) {
+      // The new compare has the same i1 type as the old one, and its users
+      // are not rewritten by this function, so redirect them now; erasing the
+      // old compare while it still has uses would be invalid.
+      if (IsI8Cmp)
+        I.replaceAllUsesWith(NewInst);
+      ReplacedValues[&I] = NewInst;
+      ToRemove.push(&I);
+    }
+  } else if (auto *Sext = dyn_cast<SExtInst>(&I)) {
+    if (Sext->getSrcTy()->isIntegerTy(8)) {
+      // find() rather than operator[]: an i8 source that was never recorded
+      // must not turn into a null replacement.
+      auto It = ReplacedValues.find(Sext->getOperand(0));
+      if (It != ReplacedValues.end()) {
+        Value *Replacement = It->second;
+        // RAUW requires identical types; the recorded wide value may be
+        // narrower or wider than the sext result (e.g. i32 vs. i64).
+        if (Replacement->getType() != Sext->getDestTy()) {
+          IRBuilder<> Builder(Sext);
+          Replacement =
+              Builder.CreateSExtOrTrunc(Replacement, Sext->getDestTy());
+        }
+        Sext->replaceAllUsesWith(Replacement);
+        ToRemove.push(Sext);
+      }
+    }
+  }
+
+  return !ToRemove.empty();
+}
+
+/// Rewrites `extractelement` / `insertelement` instructions whose index is a
+/// 64-bit integer constant so they use a 32-bit constant index instead.
+///
+/// The replacement is inserted immediately before the original instruction,
+/// all uses are redirected to it, and the original is queued in \p ToRemove.
+/// Non-constant indices and constants of other widths are left unchanged.
+///
+/// \param I        The instruction being visited.
+/// \param ToRemove Instructions to erase after the whole function has been
+///                 visited.
+/// \returns true if \p ToRemove is non-empty. This reflects every call made
+///          so far, not only this one. The unnamed map parameter is unused;
+///          it keeps the signature uniform with the other legalization steps.
+static bool
+downcastI64toI32InsertExtractElements(Instruction &I,
+                                      std::stack<Instruction *> &ToRemove,
+                                      std::map<Value *, Value *> &) {
+  // Returns the i32 replacement for a 64-bit constant index, or nullptr when
+  // the index is not such a constant. Indices that do not fit in 32 bits are
+  // clamped to UINT32_MAX rather than truncated: an out-of-range index yields
+  // poison, and truncation could wrap it onto a valid lane (2^32 + 1 -> 1),
+  // whereas UINT32_MAX is never a valid lane of a fixed-width vector.
+  // In-range indices are unchanged.
+  auto GetI32Index = [](Value *Idx) -> Value * {
+    auto *CI = dyn_cast<ConstantInt>(Idx);
+    if (!CI || CI->getBitWidth() != 64)
+      return nullptr;
+    return ConstantInt::get(Type::getInt32Ty(Idx->getContext()),
+                            CI->getLimitedValue(UINT32_MAX));
+  };
+
+  if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
+    if (Value *Idx32 = GetI32Index(Extract->getIndexOperand())) {
+      IRBuilder<> Builder(Extract);
+      Value *NewExtract = Builder.CreateExtractElement(
+          Extract->getVectorOperand(), Idx32, Extract->getName());
+      Extract->replaceAllUsesWith(NewExtract);
+      ToRemove.push(Extract);
+    }
+  } else if (auto *Insert = dyn_cast<InsertElementInst>(&I)) {
+    // Operand 2 of insertelement is the index.
+    if (Value *Idx32 = GetI32Index(Insert->getOperand(2))) {
+      IRBuilder<> Builder(Insert);
+      Value *NewInsert = Builder.CreateInsertElement(
+          Insert->getOperand(0), Insert->getOperand(1), Idx32,
+          Insert->getName());
+      Insert->replaceAllUsesWith(NewInsert);
+      ToRemove.push(Insert);
+    }
+  }
+
+  return !ToRemove.empty();
+}
+
+class DXILLegalizationPipeline {
+
+public:
+ DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
+
+ bool runLegalizationPipeline(Function &F) {
+ std::stack<Instruction *> ToRemove;
+ std::map<Value *, Value *> ReplacedValues;
+ bool MadeChanges = false;
+ for (auto &I : instructions(F)) {
+ for (auto &LegalizationFn : LegalizationPipeline) {
+ MadeChanges = LegalizationFn(I, ToRemove, ReplacedValues);
+ }
+ }
+ while (!ToRemove.empty()) {
+ Instruction *I = ToRemove.top();
+ I->eraseFromParent();
+ ToRemove.pop();
+ }
+
+ return MadeChanges;
----------------
farzonl wrote:
It's supposed to be `MadeChanges |= .. ;`, not `MadeChanges = .. ;`. I have that change staged; it's coming after I fix a different bug first.
https://github.com/llvm/llvm-project/pull/131221
More information about the llvm-commits
mailing list