[llvm] [DirectX] Data Scalarization of Vectors in Global Scope (PR #110029)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 25 17:42:41 PDT 2024
================
@@ -0,0 +1,284 @@
+//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------===//
+
+#include "DXILDataScalarization.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+#define DEBUG_TYPE "dxil-data-scalarization"
+#define Max_VEC_SIZE 4
+
+using namespace llvm;
+
+static void findAndReplaceVectors(Module &M);
+
+class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
+public:
+ DataScalarizerVisitor() : GlobalMap() {}
+ bool visit(Function &F);
+ // InstVisitor methods. They return true if the instruction was scalarized,
+ // false if nothing changed.
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitSelectInst(SelectInst &SI) { return false; }
+ bool visitICmpInst(ICmpInst &ICI) { return false; }
+ bool visitFCmpInst(FCmpInst &FCI) { return false; }
+ bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+ bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+ bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+ bool visitCastInst(CastInst &CI) { return false; }
+ bool visitBitCastInst(BitCastInst &BCI) { return false; }
+ bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+ bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+ bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+ bool visitPHINode(PHINode &PHI) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
+ bool visitCallInst(CallInst &ICI) { return false; }
+ bool visitFreezeInst(FreezeInst &FI) { return false; }
+ friend void findAndReplaceVectors(llvm::Module &M);
+
+private:
+ GlobalVariable *getNewGlobalIfExists(Value *CurrOperand);
+ DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+ bool finish();
+};
+
+bool DataScalarizerVisitor::visit(Function &F) {
+ assert(!GlobalMap.empty());
+ ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
+ for (BasicBlock *BB : RPOT) {
+ for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
+ Instruction *I = &*II;
+ bool Done = InstVisitor::visit(I);
+ ++II;
+ if (Done && I->getType()->isVoidTy())
+ I->eraseFromParent();
+ }
+ }
+ return finish();
+}
+
+bool DataScalarizerVisitor::finish() {
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+ return true;
+}
+
+GlobalVariable *
+DataScalarizerVisitor::getNewGlobalIfExists(Value *CurrOperand) {
+ if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
+ auto It = GlobalMap.find(OldGlobal);
+ if (It != GlobalMap.end()) {
+ return It->second; // Found, return the new global
+ }
+ }
+ return nullptr; // Not found
+}
+
+bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
+ for (unsigned I = 0; I < LI.getNumOperands(); ++I) {
+ Value *CurrOpperand = LI.getOperand(I);
+ GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+ if (NewGlobal)
+ LI.setOperand(I, NewGlobal);
+ }
+ return false;
+}
+
+bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
+ for (unsigned I = 0; I < SI.getNumOperands(); ++I) {
+ Value *CurrOpperand = SI.getOperand(I);
+ GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+ if (NewGlobal) {
+ SI.setOperand(I, NewGlobal);
+ }
+ }
+ return false;
+}
+
+bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+ for (unsigned I = 0; I < GEPI.getNumOperands(); ++I) {
+ Value *CurrOpperand = GEPI.getOperand(I);
+ GlobalVariable *NewGlobal = getNewGlobalIfExists(CurrOpperand);
+ if (NewGlobal) {
+ IRBuilder<> Builder(&GEPI);
+
+ SmallVector<Value *, Max_VEC_SIZE> Indices;
+ for (auto &Index : GEPI.indices())
+ Indices.push_back(Index);
+
+ Value *NewGEP =
+ Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
+
+ GEPI.replaceAllUsesWith(NewGEP);
+ PotentiallyDeadInstrs.emplace_back(&GEPI);
+ }
+ }
+ return true;
+}
+
+// Recursively Creates and Array like version of the given vector like type.
+static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
+ if (auto *VecTy = dyn_cast<VectorType>(T))
+ return ArrayType::get(VecTy->getElementType(),
+ cast<FixedVectorType>(VecTy)->getNumElements());
+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
+ Type *NewElementType =
+ replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
+ return ArrayType::get(NewElementType, ArrayTy->getNumElements());
+ }
+ // If it's not a vector or array, return the original type.
+ return T;
+}
+
+Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
+ LLVMContext &Ctx) {
+ // Handle ConstantAggregateZero (zero-initialized constants)
+ if (isa<ConstantAggregateZero>(Init)) {
+ return ConstantAggregateZero::get(NewType);
+ }
+
+ // Handle UndefValue (undefined constants)
+ if (isa<UndefValue>(Init)) {
+ return UndefValue::get(NewType);
+ }
+
+ // Handle vector to array transformation
+ if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
+ // Convert vector initializer to array initializer
+ SmallVector<Constant *, Max_VEC_SIZE> ArrayElements;
+ if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
+ for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
+ ArrayElements.push_back(ConstVecInit->getOperand(I));
+ } else if (ConstantDataVector *ConstDataVecInit =
+ llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
+ for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
+ ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
+ } else {
+ llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
+ "vector initializer!");
+ }
+
+ return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
+ }
+
+ // Handle array of vectors transformation
+ if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
+
+ auto *ArrayInit = dyn_cast<ConstantArray>(Init);
+ if (!ArrayInit) {
+ llvm_unreachable("Expected a ConstantArray for array initializer!");
+ }
----------------
farzonl wrote:
I'm assuming the other `llvm_unreachable` calls should also become asserts.
https://github.com/llvm/llvm-project/pull/110029
More information about the llvm-commits
mailing list