[llvm] [AMDGPU] Add TDM Descriptor Optimization Pass (PR #173324)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 23 03:15:30 PST 2025
================
@@ -0,0 +1,515 @@
+//===-- AMDGPUTDMOptimization.cpp - TDM Descriptor Optimization ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass optimizes Tensor Data Movement (TDM) descriptor creation patterns.
+// It identifies insertelement chains that create descriptors and transforms
+// them to use alloca+field updates, which SROA later optimizes to
+// INSERT_SUBREG.
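+//
+// For illustration (a hypothetical input, not taken from the patch's tests),
+// a chain such as:
+//
+//   %desc = insertelement <4 x i32> <i32 7, i32 9, i32 3, i32 0>, i32 %x, i32 3
+//
+// keeps fields 0-2 constant across descriptor creations, so only field 3
+// needs to be rewritten through the alloca-backed storage.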
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "AMDGPUSubtarget.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-tdm-optimization"
+
+static cl::opt<unsigned>
+ TDMOptBenefitThreshold("amdgpu-tdm-opt-threshold", cl::Hidden, cl::init(10),
+ cl::desc("Minimum optimization benefit threshold "
+ "for TDM descriptor optimization"));
+
+namespace llvm {
+void initializeAMDGPUTDMOptimizationPass(PassRegistry &);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection Data Structures
+//===----------------------------------------------------------------------===//
+
+/// Represents a single descriptor creation pattern
+struct DescriptorPattern {
+ Type *DescType; ///< <4 x i32> or <8 x i32>
+ Value *BaseValue; ///< Base template (constant or computed)
+ SmallVector<InsertElementInst *, 8>
+ Chain; ///< Chain of insertelement instructions
+ SmallVector<unsigned, 8> VariableFields; ///< Fields that change
+ SmallVector<unsigned, 8> ConstantFields; ///< Fields that stay constant
+ BasicBlock *Location; ///< Where the pattern is located
+ Loop *ContainingLoop; ///< Loop containing this pattern (if any)
+
+ /// Calculate field reuse ratio (constant fields / total fields)
+ float getFieldReuseRatio() const {
+    unsigned TotalFields = cast<FixedVectorType>(DescType)->getNumElements();
+    return (float)ConstantFields.size() / TotalFields;
+ }
+
+ /// Check if this pattern is worth optimizing
+ bool isWorthOptimizing() const {
+ // Always optimize if in loop with reuse potential
+ if (ContainingLoop && getFieldReuseRatio() >= 0.5f)
+ return true;
+
+ // Optimize if significant field reuse
+ if (getFieldReuseRatio() >= 0.75f)
+ return true;
+
+ // Optimize address descriptors (common case)
+    if (isAddressDescriptor() && !ConstantFields.empty())
+ return true;
+
+ return false;
+ }
+
+ /// Check if this is an address descriptor (<4 x i32>)
+ bool isAddressDescriptor() const {
+ auto *VecTy = cast<FixedVectorType>(DescType);
+ return VecTy->getNumElements() == 4 &&
+ VecTy->getElementType()->isIntegerTy(32);
+ }
+
+ /// Check if this is a tensor descriptor (<8 x i32>)
+ bool isTensorDescriptor() const {
+ auto *VecTy = cast<FixedVectorType>(DescType);
+ return VecTy->getNumElements() == 8 &&
+ VecTy->getElementType()->isIntegerTy(32);
+ }
+};
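+
+// Worked example (hypothetical values): a <4 x i32> descriptor whose chain
+// rewrites only element 3 has VariableFields = {3} and ConstantFields =
+// {0, 1, 2}; getFieldReuseRatio() is 3.0f / 4 = 0.75, so isWorthOptimizing()
+// accepts it even outside a loop.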
+
+/// Groups similar descriptor patterns for optimization
+struct DescriptorGroup {
+ SmallVector<DescriptorPattern, 4> Patterns;
+ Type *SharedType;
+ Value *SharedBase; ///< Common base value (if any)
+ SmallVector<unsigned, 8> SharedConstantFields;
+
+ /// Calculate total optimization benefit
+  unsigned getOptimizationBenefit() const {
+    unsigned Benefit = 0;
+    for (const auto &Pattern : Patterns) {
+      // Base benefit from field reuse
+      unsigned PatternBenefit = Pattern.ConstantFields.size() * 2;
+
+      // Extra benefit for loop patterns; scale only this pattern's
+      // contribution so it does not compound across the whole group
+      if (Pattern.ContainingLoop)
+        PatternBenefit *= 5;
+
+      Benefit += PatternBenefit;
+    }
+    return Benefit;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AMDGPUTDMOptimization Pass
+//===----------------------------------------------------------------------===//
+
+class AMDGPUTDMOptimization : public FunctionPass {
+private:
+ LoopInfo *LI = nullptr;
+
+ /// Detected patterns in the function
+ SmallVector<DescriptorPattern, 16> DetectedPatterns;
+
+ /// Groups of optimizable patterns
+ SmallVector<DescriptorGroup, 8> OptimizationGroups;
+
+public:
+ static char ID;
+
+ AMDGPUTDMOptimization() : FunctionPass(ID) {
+ initializeAMDGPUTDMOptimizationPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override;
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+ /// Main optimization phases
+ bool detectDescriptorPatterns(Function &F);
+ void groupSimilarPatterns();
+ bool transformPatterns(Function &F);
+
+ /// Pattern detection helpers
+ bool isDescriptorType(Type *Ty) const;
+ DescriptorPattern analyzeInsertChain(InsertElementInst *FinalInsert);
+ Value *extractBaseValue(const DescriptorPattern &Pattern);
+
+ /// Transformation helpers
+ bool transformDescriptorGroup(DescriptorGroup &Group, Function &F);
+ Value *createSharedStorage(DescriptorGroup &Group, IRBuilder<> &Builder);
+ void transformSinglePattern(DescriptorPattern &Pattern, Value *SharedStorage,
+ IRBuilder<> &Builder);
+
+ /// Utility functions
+ Loop *getContainingLoop(BasicBlock *BB);
+ bool arePatternsSimilar(const DescriptorPattern &A,
+ const DescriptorPattern &B);
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Implementation
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::runOnFunction(Function &F) {
+ LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+ LLVM_DEBUG(dbgs() << "Running TDM optimization on function: " << F.getName()
+ << "\n");
+
+ // Phase 1: Detect descriptor patterns
+ if (!detectDescriptorPatterns(F)) {
+ LLVM_DEBUG(dbgs() << "No descriptor patterns found\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << "Found " << DetectedPatterns.size()
+ << " descriptor patterns\n");
+
+ // Phase 2: Group similar patterns for optimization
+ groupSimilarPatterns();
+
+ LLVM_DEBUG(dbgs() << "Created " << OptimizationGroups.size()
+ << " optimization groups\n");
+
+ // Phase 3: Transform patterns
+ bool Changed = transformPatterns(F);
+
+ // Cleanup for next function
+ DetectedPatterns.clear();
+ OptimizationGroups.clear();
+
+ return Changed;
+}
+
+void AMDGPUTDMOptimization::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.setPreservesCFG();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Detection
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::detectDescriptorPatterns(Function &F) {
+ bool FoundPatterns = false;
+
+ // Scan function for insertelement instructions that create descriptors
+ for (auto &BB : F) {
+ for (auto &I : BB) {
+ auto *IE = dyn_cast<InsertElementInst>(&I);
+ if (!IE || !isDescriptorType(IE->getType()))
+ continue;
+
+ // Check if this is the final insert in a descriptor creation chain
+ if (!IE->hasOneUse() || isa<InsertElementInst>(*IE->user_begin()))
+ continue;
+
+ // Analyze the complete chain
+ DescriptorPattern Pattern = analyzeInsertChain(IE);
+ if (Pattern.Chain.empty())
+ continue;
+
+ // Check if worth optimizing
+ if (!Pattern.isWorthOptimizing()) {
+ LLVM_DEBUG(
+ dbgs() << "Pattern not worth optimizing: field reuse ratio = "
+ << Pattern.getFieldReuseRatio() << "\n");
+ continue;
+ }
+
+ LLVM_DEBUG(
+ dbgs() << "Found optimizable pattern: "
+ << (Pattern.isAddressDescriptor() ? "Address" : "Tensor")
+ << " descriptor with " << Pattern.ConstantFields.size()
+ << " constant fields\n");
+
+ DetectedPatterns.push_back(std::move(Pattern));
+ FoundPatterns = true;
+ }
+ }
+
+ return FoundPatterns;
+}
+
+bool AMDGPUTDMOptimization::isDescriptorType(Type *Ty) const {
+ auto *VecTy = dyn_cast<FixedVectorType>(Ty);
+ if (!VecTy || !VecTy->getElementType()->isIntegerTy(32))
+ return false;
+
+ unsigned NumElements = VecTy->getNumElements();
+ return NumElements == 4 || NumElements == 8; // Address or tensor descriptors
+}
+
+DescriptorPattern
+AMDGPUTDMOptimization::analyzeInsertChain(InsertElementInst *FinalInsert) {
+ DescriptorPattern Pattern;
+ Pattern.DescType = FinalInsert->getType();
+ Pattern.Location = FinalInsert->getParent();
+ Pattern.ContainingLoop = getContainingLoop(Pattern.Location);
+
+ // Trace back the insertelement chain
+ SmallVector<InsertElementInst *, 8> Chain;
+ Value *CurrentVal = FinalInsert;
+
+ while (auto *IE = dyn_cast<InsertElementInst>(CurrentVal)) {
+ Chain.push_back(IE);
+ CurrentVal = IE->getOperand(0); // Vector being inserted into
+ }
+
+ // Reverse to get forward order
+ std::reverse(Chain.begin(), Chain.end());
+ Pattern.Chain = Chain;
+
+ // Extract base value (the initial vector)
+ Pattern.BaseValue = extractBaseValue(Pattern);
+
+ // Analyze which fields are constant vs variable
+ unsigned NumElements =
+ cast<FixedVectorType>(Pattern.DescType)->getNumElements();
+ SmallBitVector FieldSet(NumElements, false);
+
+  for (auto *IE : Chain) {
+    if (auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) {
+      unsigned Idx = CI->getZExtValue();
+      // Record each index once, even if the chain overwrites it repeatedly
+      if (Idx < NumElements && !FieldSet.test(Idx)) {
+        FieldSet.set(Idx);
+        Pattern.VariableFields.push_back(Idx);
+      }
+    }
+  }
+
+ // Fields not in chain are constant
+ for (unsigned i = 0; i < NumElements; ++i) {
+ if (!FieldSet[i])
+ Pattern.ConstantFields.push_back(i);
+ }
+
+ return Pattern;
+}
+
+Value *
+AMDGPUTDMOptimization::extractBaseValue(const DescriptorPattern &Pattern) {
+ if (Pattern.Chain.empty())
+ return nullptr;
+
+ // Get the vector being inserted into by the first insert
+ Value *Base = Pattern.Chain[0]->getOperand(0);
+
+ // If base is a constant vector or another recognizable pattern, return it
+ if (isa<Constant>(Base))
+ return Base;
+
+ // For shufflevector results, we might want to trace further back
+ if (auto *SV = dyn_cast<ShuffleVectorInst>(Base))
+ return SV; // Keep shufflevector as base for now
+
+ return Base;
+}
+
+Loop *AMDGPUTDMOptimization::getContainingLoop(BasicBlock *BB) {
+ return LI ? LI->getLoopFor(BB) : nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Grouping
+//===----------------------------------------------------------------------===//
+
+void AMDGPUTDMOptimization::groupSimilarPatterns() {
+ // Simple grouping strategy: group by type and base similarity
+ for (auto &Pattern : DetectedPatterns) {
+ bool Added = false;
+
+ // Try to add to existing group
+ for (auto &Group : OptimizationGroups) {
+ if (Group.SharedType == Pattern.DescType &&
+ arePatternsSimilar(Group.Patterns[0], Pattern)) {
+ Group.Patterns.push_back(Pattern);
+ Added = true;
+ break;
+ }
+ }
+
+ // Create new group if needed
+ if (!Added) {
+ DescriptorGroup NewGroup;
+ NewGroup.SharedType = Pattern.DescType;
+ NewGroup.SharedBase = Pattern.BaseValue;
+ NewGroup.Patterns.push_back(Pattern);
+ OptimizationGroups.push_back(std::move(NewGroup));
+ }
+ }
+
+  // Remove groups that don't meet the optimization benefit threshold
+  llvm::erase_if(OptimizationGroups, [](const DescriptorGroup &Group) {
+    return Group.getOptimizationBenefit() < TDMOptBenefitThreshold;
+  });
+}
+
+bool AMDGPUTDMOptimization::arePatternsSimilar(const DescriptorPattern &A,
+ const DescriptorPattern &B) {
+ // Patterns are similar if they have same type and similar field usage
+ if (A.DescType != B.DescType)
+ return false;
+
+ // Check if constant fields overlap significantly
+ SmallBitVector AConstants(
+ cast<FixedVectorType>(A.DescType)->getNumElements());
+ SmallBitVector BConstants(
+ cast<FixedVectorType>(B.DescType)->getNumElements());
+
+ for (unsigned Field : A.ConstantFields)
+ AConstants.set(Field);
+ for (unsigned Field : B.ConstantFields)
+ BConstants.set(Field);
+
+  // Count overlapping constant fields. SmallBitVector has no binary
+  // operator&, so copy one side and intersect in place.
+  SmallBitVector Intersection = AConstants;
+  Intersection &= BConstants;
+  unsigned OverlapCount = Intersection.count();
+  unsigned TotalConstants = std::max(AConstants.count(), BConstants.count());
+
+  // Similar if at least half of the larger constant set overlaps,
+  // e.g. {0,1,2} vs {1,2,3} gives 2/3.
+  return TotalConstants > 0 && (float)OverlapCount / TotalConstants >= 0.5f;
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Transformation
+//===----------------------------------------------------------------------===//
+
+bool AMDGPUTDMOptimization::transformPatterns(Function &F) {
+ bool Changed = false;
+
+ for (auto &Group : OptimizationGroups) {
+ LLVM_DEBUG(dbgs() << "Transforming group with " << Group.Patterns.size()
+ << " patterns, benefit = "
+ << Group.getOptimizationBenefit() << "\n");
+
+ if (transformDescriptorGroup(Group, F))
+ Changed = true;
+ }
+
+ return Changed;
+}
+
+bool AMDGPUTDMOptimization::transformDescriptorGroup(DescriptorGroup &Group,
+ Function &F) {
+ if (Group.Patterns.empty())
+ return false;
+
+ // Find the best location to place shared storage
+ BasicBlock *StorageLocation = Group.Patterns[0].Location;
+
+ // If patterns are in a loop, try to hoist storage outside loop
+ if (auto *Loop = Group.Patterns[0].ContainingLoop) {
+ if (auto *Preheader = Loop->getLoopPreheader()) {
+ StorageLocation = Preheader;
+ LLVM_DEBUG(dbgs() << "Hoisting storage outside loop\n");
+ }
+ }
+
+  // Create shared storage at the first valid insertion point of the storage
+  // block (front() may be a PHI, and non-PHIs must not be inserted before it)
+  IRBuilder<> Builder(StorageLocation, StorageLocation->getFirstInsertionPt());
+ Value *SharedStorage = createSharedStorage(Group, Builder);
+
+ if (!SharedStorage)
+ return false;
+
+ // Transform each pattern in the group
+ for (auto &Pattern : Group.Patterns) {
+ IRBuilder<> PatternBuilder(Pattern.Chain.back());
+ transformSinglePattern(Pattern, SharedStorage, PatternBuilder);
+ }
+
+ return true;
+}
+
+Value *AMDGPUTDMOptimization::createSharedStorage(DescriptorGroup &Group,
+ IRBuilder<> &Builder) {
+ // Create alloca in address space 5 (AMDGPU private memory)
+ auto *StorageType = Group.SharedType;
+ auto *Storage = Builder.CreateAlloca(
+ StorageType, /*AddrSpace=*/5, /*ArraySize=*/nullptr, "tdm_desc_storage");
----------------
arsenm wrote:
And don't hardcode the address space; take it from the datalayout (or at least use the enum)
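
For reference, a minimal sketch of what that could look like (an illustrative
rewrite, assuming the builder's insert point is set as in the patch and using
the existing DataLayout/IRBuilder APIs):

    // Query the alloca address space from the module's datalayout rather
    // than hardcoding 5 (AMDGPU private memory).
    const DataLayout &DL =
        Builder.GetInsertBlock()->getModule()->getDataLayout();
    auto *Storage = Builder.CreateAlloca(StorageType, DL.getAllocaAddrSpace(),
                                         /*ArraySize=*/nullptr,
                                         "tdm_desc_storage");

Alternatively, the AMDGPUAS::PRIVATE_ADDRESS enumerator names the same address
space without a magic number.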
https://github.com/llvm/llvm-project/pull/173324
More information about the llvm-commits mailing list