[llvm] bcec4cc - [X86] [AMX] Replace bitcast with specific AMX intrinsics with X86 specific cast.
Bing1 Yu via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 17 02:04:37 PDT 2021
Author: Bing1 Yu
Date: 2021-08-17T17:04:26+08:00
New Revision: bcec4ccd04ae678a0d17b8fe8170e04221bf1959
URL: https://github.com/llvm/llvm-project/commit/bcec4ccd04ae678a0d17b8fe8170e04221bf1959
DIFF: https://github.com/llvm/llvm-project/commit/bcec4ccd04ae678a0d17b8fe8170e04221bf1959.diff
LOG: [X86] [AMX] Replace bitcast with specific AMX intrinsics with X86 specific cast.
There is some discussion about bitcasts between vector and x86_amx types at https://reviews.llvm.org/D99152. This patch introduces x86-specific cast intrinsics for converting between vector and x86_amx, so that the middle end does not apply unwanted optimizations to these conversions. In exchange, we have to optimize the x86-specific casts ourselves, so this patch also optimizes the cast operations to eliminate redundant code.
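As a rough illustration (a minimal sketch: the <256 x i32> type and the value names %vec, %tile, %vec2 are made up for this example, while the intrinsic names come from this patch), a conversion that was previously written with plain bitcasts

  ; %vec is a hypothetical <256 x i32> value used only for illustration
  %tile = bitcast <256 x i32> %vec to x86_amx
  %vec2 = bitcast x86_amx %tile to <256 x i32>

is now expressed with the target-specific cast intrinsics

  %tile = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vec)
  %vec2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %tile)

and X86LowerAMXType optimizes the casts itself: for example, a vec2tile/tile2vec round trip over the same value is folded so that users see the original value directly, and cast pairs that only meet through a PHI node are rewritten as well.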
Reviewed By: LuoYuanke
Differential Revision: https://reviews.llvm.org/D107544
Added:
llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
Modified:
llvm/include/llvm/IR/IntrinsicsX86.td
llvm/lib/Target/X86/X86LowerAMXType.cpp
llvm/test/CodeGen/X86/AMX/amx-type.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index ae0a416175f9e..eba83493e686d 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5093,6 +5093,10 @@ let TargetPrefix = "x86" in {
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty,
llvm_x86amx_ty], []>;
+ def int_x86_cast_vector_to_tile:
+ Intrinsic<[llvm_x86amx_ty], [llvm_anyvector_ty], [IntrNoMem]>;
+ def int_x86_cast_tile_to_vector:
+ Intrinsic<[llvm_anyvector_ty], [llvm_x86amx_ty], [IntrNoMem]>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 4ba44ccb6c160..a2bcc98f3d5b6 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -40,8 +40,10 @@
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
@@ -56,66 +58,44 @@
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
+#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "lower-amx-type"
-static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
- BasicBlock *BB) {
+static bool isAMXCast(Instruction *II) {
+ return match(II,
+ m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
+ match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
+}
+
+static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
+ Type *Ty) {
Function &F = *BB->getParent();
Module *M = BB->getModule();
const DataLayout &DL = M->getDataLayout();
- Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
LLVMContext &Ctx = Builder.getContext();
auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
unsigned AllocaAS = DL.getAllocaAddrSpace();
AllocaInst *AllocaRes =
- new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
+ new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
AllocaRes->setAlignment(AllocaAlignment);
return AllocaRes;
}
-namespace {
-class X86LowerAMXType {
- Function &Func;
- TargetMachine *TM = nullptr;
-
- // In AMX intrinsics we let Shape = {Row, Col}, but the
- // RealCol = Col / ElementSize. We may use the RealCol
- // as a new Row for other new created AMX intrinsics.
- std::map<Value *, Value *> Col2Row;
-
-public:
- X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
- bool visit();
- void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
- void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
- bool transformBitcast(BitCastInst *Bitcast);
- std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
- Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
-};
-
-Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
- unsigned Granularity) {
- if (Col2Row.count(V))
- return Col2Row[V];
- IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
- if (auto *I = dyn_cast<Instruction>(V)) {
- BasicBlock::iterator Iter = I->getIterator();
- ++Iter;
- Builder.SetInsertPoint(&*Iter);
- }
- ConstantInt *Gran = Builder.getInt16(Granularity);
- Value *RealRow = Builder.CreateUDiv(V, Gran);
- Col2Row[V] = RealRow;
- return RealRow;
+static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
+ for (Instruction &I : F.getEntryBlock())
+ if (!isa<AllocaInst>(&I))
+ return &I;
+ llvm_unreachable("No terminator in the entry block!");
}
-std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
- unsigned OpNo) {
+static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+ IRBuilder<> Builder(II);
Value *Row = nullptr, *Col = nullptr;
switch (II->getIntrinsicID()) {
default:
@@ -144,14 +124,32 @@ std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
Col = II->getArgOperand(2);
break;
case 5:
- Row = II->getArgOperand(2);
- // FIXME: There is a design bug for AMX shape, which the Col should be
- // Col/4 if it will be used as Row, but current Greedy RA can't handle
- // this case well, it may failed if we generate a new Shape definition.
- // So Let's just do it in O0 first.
- // Row = Row / 4
- if (TM->getOptLevel() == CodeGenOpt::None)
- Row = getRowFromCol(II, Row, 4);
+ if (isa<ConstantInt>(II->getArgOperand(2)))
+ Row = Builder.getInt16(
+ (dyn_cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
+ else if (isa<Instruction>(II->getArgOperand(2))) {
+      // When the value is neither a constant nor a function argument, we
+      // create Row after the definition of II->getOperand(2) instead of
+      // before II. For example, II is %118 and we try to get the shape for
+      // %117:
+      // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
+      // i32> %115).
+      // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
+      // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
+      // %117).
+      // If we created %row = udiv i16 %106, 4 before %118 (aka II), its
+      // definition would come after its user (the new tileload for %117).
+      // So the best choice is to create %row right after the definition of
+      // %106.
+ Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
+ Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
+ cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
+ } else {
+      // When the value is not a constant but is a function argument, we
+      // create Row in the entry block.
+ IRBuilder<> NewBuilder(
+ getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
+ Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
+ }
Col = II->getArgOperand(1);
break;
}
@@ -162,6 +160,40 @@ std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
return std::make_pair(Row, Col);
}
+namespace {
+class X86LowerAMXType {
+ Function &Func;
+
+  // In AMX intrinsics we let Shape = {Row, Col}, but the
+  // RealCol = Col / ElementSize. We may use the RealCol
+  // as a new Row for other newly created AMX intrinsics.
+ std::map<Value *, Value *> Col2Row;
+
+public:
+ X86LowerAMXType(Function &F) : Func(F) {}
+ bool visit();
+ void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
+ void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
+ bool transformBitcast(BitCastInst *Bitcast);
+ Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
+};
+
+Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
+ unsigned Granularity) {
+ if (Col2Row.count(V))
+ return Col2Row[V];
+ IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ BasicBlock::iterator Iter = I->getIterator();
+ ++Iter;
+ Builder.SetInsertPoint(&*Iter);
+ }
+ ConstantInt *Gran = Builder.getInt16(Granularity);
+ Value *RealRow = Builder.CreateUDiv(V, Gran);
+ Col2Row[V] = RealRow;
+ return RealRow;
+}
+
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
@@ -230,8 +262,8 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
Value *I8Ptr, *Stride;
auto *Src = Bitcast->getOperand(0);
- auto Prepare = [&]() {
- AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
+ auto Prepare = [&](Type *MemTy) {
+ AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
Stride = Builder.getInt64(64);
};
@@ -250,7 +282,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
if (!II)
return false; // May be bitcast from x86amx to <256 x i32>.
- Prepare();
+ Prepare(Bitcast->getOperand(0)->getType());
Builder.CreateStore(Src, AllocaAddr);
// TODO we can pick an constant operand for the shape.
Value *Row = nullptr, *Col = nullptr;
@@ -270,7 +302,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
auto *II = dyn_cast<IntrinsicInst>(Src);
if (!II)
return false; // May be bitcast from <256 x i32> to x86amx.
- Prepare();
+ Prepare(Bitcast->getType());
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
@@ -637,6 +669,364 @@ bool X86VolatileTileData::volatileTileData() {
namespace {
+class X86LowerAMXCast {
+ Function &Func;
+
+public:
+ X86LowerAMXCast(Function &F) : Func(F) {}
+ bool combineAMXcast(TargetLibraryInfo *TLI);
+ bool transformAMXCast(IntrinsicInst *AMXCast);
+ bool transformAllAMXCast();
+ bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
+ SmallSetVector<Instruction *, 16> &DeadInst);
+};
+
+static bool DCEInstruction(Instruction *I,
+ SmallSetVector<Instruction *, 16> &WorkList,
+ const TargetLibraryInfo *TLI) {
+ if (isInstructionTriviallyDead(I, TLI)) {
+ salvageDebugInfo(*I);
+ salvageKnowledge(I);
+
+ // Null out all of the instruction's operands to see if any operand becomes
+ // dead as we go.
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+ Value *OpV = I->getOperand(i);
+ I->setOperand(i, nullptr);
+
+ if (!OpV->use_empty() || I == OpV)
+ continue;
+
+ // If the operand is an instruction that became dead as we nulled out the
+ // operand, and if it is 'trivially' dead, delete it in a future loop
+ // iteration.
+ if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
+ if (isInstructionTriviallyDead(OpI, TLI)) {
+ WorkList.insert(OpI);
+ }
+ }
+ }
+ I->eraseFromParent();
+ return true;
+ }
+ return false;
+}
+
+/// This function handles the following case:
+///
+/// A -> B amxcast
+/// PHI
+/// B -> A amxcast
+///
+/// All the related PHI nodes can be replaced by new PHI nodes with type A.
+/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
+bool X86LowerAMXCast::optimizeAMXCastFromPhi(
+ IntrinsicInst *CI, PHINode *PN,
+ SmallSetVector<Instruction *, 16> &DeadInst) {
+ IRBuilder<> Builder(CI);
+ Value *Src = CI->getOperand(0);
+ Type *SrcTy = Src->getType(); // Type B
+ Type *DestTy = CI->getType(); // Type A
+
+ SmallVector<PHINode *, 4> PhiWorklist;
+ SmallSetVector<PHINode *, 4> OldPhiNodes;
+
+ // Find all of the A->B casts and PHI nodes.
+ // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
+ // OldPhiNodes is used to track all known PHI nodes, before adding a new
+ // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
+ PhiWorklist.push_back(PN);
+ OldPhiNodes.insert(PN);
+ while (!PhiWorklist.empty()) {
+ auto *OldPN = PhiWorklist.pop_back_val();
+ for (Value *IncValue : OldPN->incoming_values()) {
+      // TODO: Currently we ignore cases where the incoming value is a
+      // constant. We might support constants in the future.
+ if (isa<Constant>(IncValue))
+ return false;
+
+ if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
+ if (OldPhiNodes.insert(PNode))
+ PhiWorklist.push_back(PNode);
+ continue;
+ }
+ Instruction *ACI = dyn_cast<Instruction>(IncValue);
+ if (ACI && isAMXCast(ACI)) {
+ // Verify it's a A->B cast.
+ Type *TyA = ACI->getOperand(0)->getType();
+ Type *TyB = ACI->getType();
+ if (TyA != DestTy || TyB != SrcTy)
+ return false;
+ continue;
+ }
+ return false;
+ }
+ }
+
+ // Check that each user of each old PHI node is something that we can
+ // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
+ for (auto *OldPN : OldPhiNodes) {
+ for (User *V : OldPN->users()) {
+ Instruction *ACI = dyn_cast<Instruction>(V);
+ if (ACI && isAMXCast(ACI)) {
+ // Verify it's a B->A cast.
+ Type *TyB = ACI->getOperand(0)->getType();
+ Type *TyA = ACI->getType();
+ if (TyA != DestTy || TyB != SrcTy)
+ return false;
+ } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+ // As long as the user is another old PHI node, then even if we don't
+ // rewrite it, the PHI web we're considering won't have any users
+ // outside itself, so it'll be dead.
+ // example:
+ // bb.0:
+ // %0 = amxcast ...
+ // bb.1:
+ // %1 = amxcast ...
+ // bb.2:
+ // %goodphi = phi %0, %1
+ // %3 = amxcast %goodphi
+ // bb.3:
+ // %goodphi2 = phi %0, %goodphi
+ // %4 = amxcast %goodphi2
+        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2
+        // is outside the phi web, so the combination stops. When
+        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the
+        // optimization will be done.
+ if (OldPhiNodes.count(PHI) == 0)
+ return false;
+ } else
+ return false;
+ }
+ }
+
+  // For each old PHI node, create a corresponding new PHI node with type A.
+ SmallDenseMap<PHINode *, PHINode *> NewPNodes;
+ for (auto *OldPN : OldPhiNodes) {
+ Builder.SetInsertPoint(OldPN);
+ PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
+ NewPNodes[OldPN] = NewPN;
+ }
+
+ // Fill in the operands of new PHI nodes.
+ for (auto *OldPN : OldPhiNodes) {
+ PHINode *NewPN = NewPNodes[OldPN];
+ for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
+ Value *V = OldPN->getOperand(j);
+ Value *NewV = nullptr;
+ Instruction *ACI = dyn_cast<Instruction>(V);
+      // There should not be an AMXCast from a constant.
+ if (ACI && isAMXCast(ACI))
+ NewV = ACI->getOperand(0);
+ else if (auto *PrevPN = dyn_cast<PHINode>(V))
+ NewV = NewPNodes[PrevPN];
+ assert(NewV);
+ NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
+ }
+ }
+
+  // Traverse all accumulated PHI nodes and process their users,
+  // which are Stores and BitCasts. Without this processing,
+  // NewPHI nodes could be replicated and could lead to extra
+  // moves generated after DeSSA.
+  // If there is a store with type B, change it to type A.
+
+ // Replace users of BitCast B->A with NewPHI. These will help
+ // later to get rid of a closure formed by OldPHI nodes.
+ for (auto *OldPN : OldPhiNodes) {
+ PHINode *NewPN = NewPNodes[OldPN];
+ for (User *V : make_early_inc_range(OldPN->users())) {
+ Instruction *ACI = dyn_cast<Instruction>(V);
+ if (ACI && isAMXCast(ACI)) {
+ Type *TyB = ACI->getOperand(0)->getType();
+ Type *TyA = ACI->getType();
+ assert(TyA == DestTy && TyB == SrcTy);
+ (void)TyA;
+ (void)TyB;
+ ACI->replaceAllUsesWith(NewPN);
+ DeadInst.insert(ACI);
+ } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        // We don't need to push the PHINode into DeadInst since it is an
+        // operand of rootPN; DCE can safely delete rootPN's operands if
+        // rootPN is dead.
+ assert(OldPhiNodes.contains(PHI));
+ (void)PHI;
+ } else
+ llvm_unreachable("all uses should be handled");
+ }
+ }
+ return true;
+}
+
+bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
+ bool Change = false;
+  // Collect tile cast instructions.
+ SmallVector<Instruction *, 8> Vec2TileInsts;
+ SmallVector<Instruction *, 8> Tile2VecInsts;
+ SmallVector<Instruction *, 8> PhiCastWorkList;
+ SmallSetVector<Instruction *, 16> DeadInst;
+ for (BasicBlock &BB : Func) {
+ for (Instruction &I : BB) {
+ Value *Vec;
+ if (match(&I,
+ m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
+ Vec2TileInsts.push_back(&I);
+ else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
+ m_Value(Vec))))
+ Tile2VecInsts.push_back(&I);
+ }
+ }
+
+ auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
+ for (auto *Inst : Insts) {
+ for (User *U : Inst->users()) {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+ if (!II || II->getIntrinsicID() != IID)
+ continue;
+ // T1 = vec2tile V0
+ // V2 = tile2vec T1
+ // V3 = OP V2
+ // -->
+ // T1 = vec2tile V0
+ // V2 = tile2vec T1
+ // V3 = OP V0
+ II->replaceAllUsesWith(Inst->getOperand(0));
+ Change = true;
+ }
+ }
+ };
+
+ Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
+ Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
+
+ auto EraseInst = [](SmallVectorImpl<Instruction *> &Insts) {
+ for (auto *Inst : Insts) {
+ if (Inst->use_empty())
+ Inst->eraseFromParent();
+ }
+ };
+
+ EraseInst(Vec2TileInsts);
+ EraseInst(Tile2VecInsts);
+
+ // Handle the A->B->A cast, and there is an intervening PHI node.
+ for (BasicBlock &BB : Func) {
+ for (Instruction &I : BB) {
+ if (isAMXCast(&I)) {
+ if (PHINode *PN = dyn_cast<PHINode>(I.getOperand(0)))
+ PhiCastWorkList.push_back(&I);
+ }
+ }
+ }
+ for (auto *I : PhiCastWorkList) {
+    // We skip dead AMXCasts.
+ if (DeadInst.contains(I))
+ continue;
+ PHINode *PN = cast<PHINode>(I->getOperand(0));
+ if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
+ DeadInst.insert(PN);
+ Change = true;
+ }
+ }
+
+  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
+  // might have no uses. We do some dead code elimination for them.
+ while (!DeadInst.empty()) {
+ Instruction *I = DeadInst.pop_back_val();
+ Change |= DCEInstruction(I, DeadInst, TLI);
+ }
+ return Change;
+}
+
+// There might be remaining AMXCasts after combineAMXcast, and they should be
+// handled elegantly.
+bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
+ IRBuilder<> Builder(AMXCast);
+ AllocaInst *AllocaAddr;
+ Value *I8Ptr, *Stride;
+ auto *Src = AMXCast->getOperand(0);
+
+ auto Prepare = [&](Type *MemTy) {
+ AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
+ I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+ Stride = Builder.getInt64(64);
+ };
+
+ if (AMXCast->getType()->isX86_AMXTy()) {
+ // %2 = amxcast <225 x i32> %src to x86_amx
+ // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
+ // i8* %addr3, i64 60, x86_amx %2)
+ // -->
+ // %addr = alloca <225 x i32>, align 64
+ // store <225 x i32> %src, <225 x i32>* %addr, align 64
+ // %addr2 = bitcast <225 x i32>* %addr to i8*
+ // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
+ // i8* %addr2,
+ // i64 60)
+ // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
+ // i8* %addr3, i64 60, x86_amx %2)
+ Use &U = *(AMXCast->use_begin());
+ unsigned OpNo = U.getOperandNo();
+ auto *II = dyn_cast<IntrinsicInst>(U.getUser());
+ if (!II)
+ return false; // May be bitcast from x86amx to <256 x i32>.
+ Prepare(AMXCast->getOperand(0)->getType());
+ Builder.CreateStore(Src, AllocaAddr);
+    // TODO: We can pick a constant operand for the shape.
+ Value *Row = nullptr, *Col = nullptr;
+ std::tie(Row, Col) = getShape(II, OpNo);
+ std::array<Value *, 4> Args = {
+ Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
+ Value *NewInst = Builder.CreateIntrinsic(
+ Intrinsic::x86_tileloadd64_internal, None, Args);
+ AMXCast->replaceAllUsesWith(NewInst);
+ AMXCast->eraseFromParent();
+ } else {
+ // %2 = amxcast x86_amx %src to <225 x i32>
+ // -->
+ // %addr = alloca <225 x i32>, align 64
+ // %addr2 = bitcast <225 x i32>* to i8*
+ // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
+ // i8* %addr2, i64 %stride)
+ // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
+ auto *II = dyn_cast<IntrinsicInst>(Src);
+ if (!II)
+ return false; // May be bitcast from <256 x i32> to x86amx.
+ Prepare(AMXCast->getType());
+ Value *Row = II->getOperand(0);
+ Value *Col = II->getOperand(1);
+ std::array<Value *, 5> Args = {
+ Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
+ Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+ Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
+ AMXCast->replaceAllUsesWith(NewInst);
+ AMXCast->eraseFromParent();
+ }
+
+ return true;
+}
+
+bool X86LowerAMXCast::transformAllAMXCast() {
+ bool Change = false;
+  // Collect tile cast instructions.
+ SmallVector<Instruction *, 8> WorkLists;
+ for (BasicBlock &BB : Func) {
+ for (Instruction &I : BB) {
+ if (isAMXCast(&I))
+ WorkLists.push_back(&I);
+ }
+ }
+
+ for (auto *Inst : WorkLists) {
+ Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
+ }
+
+ return Change;
+}
+
+} // anonymous namespace
+
+namespace {
+
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
static char ID;
@@ -647,8 +1037,15 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
bool runOnFunction(Function &F) override {
TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+ TargetLibraryInfo *TLI =
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ X86LowerAMXCast LAC(F);
+ LAC.combineAMXcast(TLI);
+    // There might be remaining AMXCasts after combineAMXcast, and they should
+    // be handled elegantly.
+ LAC.transformAllAMXCast();
- X86LowerAMXType LAT(F, TM);
+ X86LowerAMXType LAT(F);
bool C = LAT.visit();
// Prepare for fast register allocation at O0.
@@ -671,6 +1068,7 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<TargetPassConfig>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
@@ -681,6 +1079,7 @@ char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
false)
diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
index 989a1076ce7a6..ddf650525baaa 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -163,18 +163,19 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2
-; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
-; CHECK-NEXT: [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
-; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
-; CHECK-NEXT: [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
-; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
-; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
-; CHECK-NEXT: [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
-; CHECK-NEXT: [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
-; CHECK-NEXT: [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]])
-; CHECK-NEXT: [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, x86_amx [[TMP19]])
+; CHECK-NEXT: [[TMP10:%.*]] = udiv i16 [[TMP9]], 4
+; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT: [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64)
+; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP14]] to i8*
+; CHECK-NEXT: [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP15]], i64 64)
+; CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT: [[TMP18:%.*]] = bitcast <256 x i32>* [[TMP17]] to i8*
+; CHECK-NEXT: [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP10]], i16 [[TMP7]], i8* [[TMP18]], i64 64)
+; CHECK-NEXT: [[TMP20:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP13]], x86_amx [[TMP16]], x86_amx [[TMP19]])
+; CHECK-NEXT: [[TMP21:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP21]], i64 64, x86_amx [[TMP20]])
; CHECK-NEXT: ret void
;
%4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
@@ -200,15 +201,16 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._
define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
; CHECK-LABEL: @__tile_dpbsud(
-; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -225,15 +227,16 @@ define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
; CHECK-LABEL: @__tile_dpbusd(
-; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -250,15 +253,16 @@ define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
; CHECK-LABEL: @__tile_dpbuud(
-; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -275,15 +279,16 @@ define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
; CHECK-LABEL: @__tile_dpbf16ps(
-; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, <256 x i32>* %pa, align 64
diff --git a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
new file mode 100644
index 0000000000000..4aa5c7e3e1b9a
--- /dev/null
+++ b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
@@ -0,0 +1,412 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
+
+define void @combine_amx_cast_inside_bb() {
+; CHECK-LABEL: @combine_amx_cast_inside_bb(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP0]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+ %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %tmp)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %1)
+ ret void
+}
+
+; Cases where amxcast can be combined across bb
+; %5 and %6 are combined together since %goodphi's incoming values are phi or amxcast
+define void @combine_amx_cast_and_phi() {
+; CHECK-LABEL: @combine_amx_cast_and_phi(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK: for.body.i.lr.ph.i:
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
+; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK: for.cond.cleanup.i.i:
+; CHECK-NEXT: [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+ %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+ br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i: ; preds = %wrapper_entry
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+ %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+ %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+ %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+ br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+ %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+ %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+ ret void
+}
+
+; Cases where amxcast can't be combined across bb
+; %5 and %6 are not combined together since %evilphi's incoming values are not all phi or amxcast
+define void @fail_to_combine_amx_cast_and_phi(<110 x i32> %tmp) {
+; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = add <110 x i32> [[TMP:%.*]], [[TMP]]
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK: for.body.i.lr.ph.i:
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP6]], i64 40)
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
+; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP8]], i64 56)
+; CHECK-NEXT: [[TMP10:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP10]], i64 40)
+; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP7]], x86_amx [[TMP9]], x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP13:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP13]], i64 40, x86_amx [[TMP12]])
+; CHECK-NEXT: [[TMP14:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK: for.cond.cleanup.i.i:
+; CHECK-NEXT: [[EVILPHI:%.*]] = phi <110 x i32> [ [[TMP5]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP14]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT: [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP15]], i64 40)
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP16]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = add <110 x i32> %tmp, %tmp
+ br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i: ; preds = %wrapper_entry
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+ %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+ %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+ %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+ br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+ %evilphi = phi <110 x i32> [ %0, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+ %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+ ret void
+}
+
+; Cases where amxcast can't be combined across bb
+; %5 and %6 are not combined together since %evilphi's user (aka %evilphi2) is not inside the phi web.
+define void @fail_to_combine_amx_cast_and_phi2() {
+; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi2(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: [[TMP7:%.*]] = bitcast <110 x i32>* [[TMP5]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP7]], i64 40, x86_amx [[TMP6]])
+; CHECK-NEXT: [[TMP8:%.*]] = load <110 x i32>, <110 x i32>* [[TMP5]], align 512
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK: for.body.i.lr.ph.i:
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
+; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP9]], i64 40)
+; CHECK-NEXT: [[TMP11:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
+; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP11]], i64 56)
+; CHECK-NEXT: [[TMP13:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP13]], i64 40)
+; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP10]], x86_amx [[TMP12]], x86_amx [[TMP14]])
+; CHECK-NEXT: [[TMP16:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP16]], i64 40, x86_amx [[TMP15]])
+; CHECK-NEXT: [[TMP17:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I]], label [[EXIT:%.*]]
+; CHECK: for.cond.cleanup.i.i:
+; CHECK-NEXT: [[GOODPHI:%.*]] = phi <110 x i32> [ [[TMP8]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP17]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: [[TMP18:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: store <110 x i32> [[GOODPHI]], <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT: [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP18]], i64 40)
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP19]])
+; CHECK-NEXT: br i1 undef, label [[EXIT]], label [[FOR_BODY_I_LR_PH_I]]
+; CHECK: exit:
+; CHECK-NEXT: [[EVILPHI2:%.*]] = phi <110 x i32> [ [[GOODPHI]], [[FOR_COND_CLEANUP_I_I]] ], [ [[TMP17]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: store <110 x i32> [[EVILPHI2]], <110 x i32>* undef, align 512
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+ %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+ br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i: ; preds = %wrapper_entry
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+ %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+ %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+ %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+ br i1 undef, label %for.cond.cleanup.i.i, label %exit
+
+for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+ %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+ %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+ br i1 undef, label %exit, label %for.body.i.lr.ph.i
+exit:
+ %evilphi2 = phi <110 x i32> [ %goodphi, %for.cond.cleanup.i.i ], [ %5, %for.body.i.lr.ph.i ]
+ store <110 x i32> %evilphi2, <110 x i32>* undef, align 512
+ ret void
+}
+
+define void @fail_to_combine_amx_cast_and_phi_due_to_const_value() {
+; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi_due_to_const_value(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK: for.body.i.lr.ph.i:
+; CHECK-NEXT: [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
+; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
+; CHECK-NEXT: [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
+; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK: for.cond.cleanup.i.i:
+; CHECK-NEXT: [[EVILPHI:%.*]] = phi <110 x i32> [ undef, [[WRAPPER_ENTRY:%.*]] ], [ [[TMP13]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: [[TMP14:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP14]], i64 40)
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i: ; preds = %wrapper_entry
+ %0 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+ %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %3 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %0, x86_amx %1, x86_amx %2)
+ %4 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %3)
+ br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+ %evilphi = phi <110 x i32> [ undef, %wrapper_entry ], [ %4, %for.body.i.lr.ph.i ]
+ %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %5)
+ ret void
+}
+
+; Cases where amxcast can be combined across bb
+; When optimizeAMXCastFromPhi processes %6 and %goodphi, %goodphi2 is outside the phi web, so the optimization stops.
+; When optimizeAMXCastFromPhi processes %7 and %goodphi2, the optimization continues.
+define void @combine_amx_cast_and_multiple_phi() {
+; CHECK-LABEL: @combine_amx_cast_and_multiple_phi(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK: for.body.i.lr.ph.i:
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I]], label [[EXIT:%.*]]
+; CHECK: for.cond.cleanup.i.i:
+; CHECK-NEXT: [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
+; CHECK-NEXT: br i1 undef, label [[EXIT]], label [[FOR_BODY_I_LR_PH_I]]
+; CHECK: exit:
+; CHECK-NEXT: [[TMP12:%.*]] = phi x86_amx [ [[TMP11]], [[FOR_COND_CLEANUP_I_I]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP12]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+ %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+ br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i: ; preds = %wrapper_entry
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+ %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+ %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+ %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+ br i1 undef, label %for.cond.cleanup.i.i, label %exit
+
+for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+ %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+ %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+ br i1 undef, label %exit, label %for.body.i.lr.ph.i
+exit:
+ %evilphi2 = phi <110 x i32> [ %goodphi, %for.cond.cleanup.i.i ], [ %5, %for.body.i.lr.ph.i ]
+ %7 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %7)
+ ret void
+}
+
+; Currently we are not able to delete DeadPHICycle; we will handle it later.
+define void @combine_amx_cast_and_phi_in_a_circle() {
+; CHECK-LABEL: @combine_amx_cast_and_phi_in_a_circle(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: br label [[BB1:%.*]]
+; CHECK: bb1:
+; CHECK-NEXT: [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP3]], align 512
+; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
+; CHECK-NEXT: [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT: br i1 undef, label [[BB2:%.*]], label [[BB3:%.*]]
+; CHECK: bb2:
+; CHECK-NEXT: [[TMP14:%.*]] = phi x86_amx [ [[TMP15:%.*]], [[BB3]] ], [ [[TMP11]], [[BB1]] ]
+; CHECK-NEXT: [[GOODPHI:%.*]] = phi <110 x i32> [ [[EVILPHI2:%.*]], [[BB3]] ], [ [[TMP13]], [[BB1]] ]
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP14]])
+; CHECK-NEXT: br label [[BB3]]
+; CHECK: bb3:
+; CHECK-NEXT: [[TMP15]] = phi x86_amx [ [[TMP14]], [[BB2]] ], [ [[TMP11]], [[BB1]] ]
+; CHECK-NEXT: [[EVILPHI2]] = phi <110 x i32> [ [[GOODPHI]], [[BB2]] ], [ [[TMP13]], [[BB1]] ]
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT: br i1 undef, label [[BB2]], label [[EXIT:%.*]]
+; CHECK: exit:
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+ %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+ br label %bb1
+
+bb1: ; preds = %wrapper_entry
+ %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+ %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+ %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+ %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+ br i1 undef, label %bb2, label %bb3
+
+bb2:                                              ; preds = %bb1, %bb3
+  %goodphi = phi <110 x i32> [ %evilphi2, %bb3 ], [ %5, %bb1 ]
+ %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+ br label %bb3
+bb3:
+ %evilphi2 = phi <110 x i32> [ %goodphi, %bb2 ], [ %5, %bb1 ]
+ %7 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %7)
+ br i1 undef, label %bb2, label %exit
+exit:
+ %8 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %8)
+ ret void
+}
+
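+; The tile->vector / vector->tile round trips and the vector phi are all folded
+; to x86_amx values, so the unused casts disappear (see the CHECK lines below).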
+define void @eliminate_unused_phi_and_cast() {
+; CHECK-LABEL: @eliminate_unused_phi_and_cast(
+; CHECK-NEXT: wrapper_entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK: for.body.i.lr.ph.i:
+; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* undef, i64 undef)
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP2]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK: for.cond.cleanup.i.i:
+; CHECK-NEXT: [[TMP7:%.*]] = phi x86_amx [ [[TMP1]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP6]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP7]])
+; CHECK-NEXT: ret void
+;
+wrapper_entry:
+ %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+ %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+ br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i: ; preds = %wrapper_entry
+ %1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* undef, i64 undef)
+ %v1 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %1)
+ %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* undef, i64 undef)
+ %v2 = call <616 x i8> @llvm.x86.cast.tile.to.vector.v616i8(x86_amx %2)
+ %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %v1)
+ %4 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> %v2)
+ %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+ %6 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %3, x86_amx %4, x86_amx %5)
+ %7 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %6)
+ br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+ %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %7, %for.body.i.lr.ph.i ]
+ %8 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+ call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %8)
+ ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx)
+declare <616 x i8> @llvm.x86.cast.tile.to.vector.v616i8(x86_amx)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32>)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8>)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8>)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
new file mode 100644
index 0000000000000..98a820197bbd6
--- /dev/null
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -0,0 +1,429 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
+
+%struct.__tile_str = type { i16, i16, <256 x i32> }
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 64
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 64
+
+; test cast of x86_amx to <256 x i32> (the pattern that used to be a bitcast)
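+; The vector result of the tile-to-vector cast is unused, so the cast is dropped
+; and only the tileloadd call remains (see the CHECK lines below).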
+define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_user_empty(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s)
+ %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+ ret void
+}
+
+; test cast of <256 x i32> to x86_amx (the pattern that used to be a bitcast)
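+; The vector-to-tile cast is likewise unused, so the whole body folds to a plain ret.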
+define dso_local void @test_user_empty2(<256 x i32> %in) {
+; CHECK-LABEL: @test_user_empty2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret void
+;
+entry:
+ %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %in)
+ ret void
+}
+
+define dso_local <256 x i32> @test_amx_load_bitcast_v256i32(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_amx_load_bitcast_v256i32(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[IN:%.*]], align 64
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T1]], <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N]], i8* [[TMP1]], i64 [[TMP2]])
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]])
+; CHECK-NEXT: ret <256 x i32> [[T1]]
+;
+entry:
+ %t1 = load <256 x i32>, <256 x i32>* %in, align 64
+ %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
+ call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2)
+ ret <256 x i32> %t1
+}
+
+define dso_local <225 x i32> @test_amx_load_bitcast_v225i32(<225 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_amx_load_bitcast_v225i32(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <225 x i32>, align 64
+; CHECK-NEXT: [[T1:%.*]] = load <225 x i32>, <225 x i32>* [[IN:%.*]], align 64
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <225 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: store <225 x i32> [[T1]], <225 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N]], i8* [[TMP1]], i64 [[TMP2]])
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]])
+; CHECK-NEXT: ret <225 x i32> [[T1]]
+;
+entry:
+ %t1 = load <225 x i32>, <225 x i32>* %in, align 64
+ %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v225i32(<225 x i32> %t1)
+ call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2)
+ ret <225 x i32> %t1
+}
+
+define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_amx_bitcast_store(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[M]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[M]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]])
+; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT: store <256 x i32> [[TMP3]], <256 x i32>* [[OUT:%.*]], align 1024
+; CHECK-NEXT: ret <256 x i32> [[TMP3]]
+;
+entry:
+ %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %m, i8* %buf, i64 %s)
+ %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+ store <256 x i32> %t2, <256 x i32>* %out
+ ret <256 x i32> %t2
+}
+
+define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) {
+; CHECK-LABEL: @test_src_add(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[C:%.*]] to i64
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C]], i8* [[TMP1]], i64 [[TMP2]])
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %add = add <256 x i32> %y, %x
+ %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %add)
+ call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t)
+ ret void
+}
+
+define dso_local void @test_src_add2(<256 x i32> %x, i16 %r, i16 %c, i8* %buf, i64 %s) {
+; CHECK-LABEL: @test_src_add2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[C]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]])
+; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[TMP3]], [[X:%.*]]
+; CHECK-NEXT: ret void
+;
+entry:
+ %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, i8* %buf, i64 %s)
+ %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+ %add = add <256 x i32> %t2, %x
+ ret void
+}
+
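+; The __tile_* tests below check the fallback lowering through the stack: a
+; vector-to-tile cast becomes a store of the vector to a stack slot plus a
+; tileloadd from it, and a tile-to-vector cast becomes a tilestored into a
+; stack slot plus a vector load from it (see the CHECK lines in each test).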
+define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_loadd(
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
+; CHECK-NEXT: [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
+; CHECK-NEXT: [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2
+; CHECK-NEXT: [[TMP9:%.*]] = shl i64 [[TMP2:%.*]], 32
+; CHECK-NEXT: [[TMP10:%.*]] = ashr exact i64 [[TMP9]], 32
+; CHECK-NEXT: [[TMP11:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP1:%.*]], i64 [[TMP10]])
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[TMP8]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP12]], i64 [[TMP13]], x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP14:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
+; CHECK-NEXT: store <256 x i32> [[TMP14]], <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT: ret void
+;
+ %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
+ %5 = load i16, i16* %4, align 64
+ %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 1
+ %7 = load i16, i16* %6, align 2
+ %8 = shl i64 %2, 32
+ %9 = ashr exact i64 %8, 32
+ %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9)
+ %11 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %10)
+ %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+ store <256 x i32> %11, <256 x i32>* %12, align 64
+ ret void
+}
+
+define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_dpbssd(
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP6:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP7:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
+; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 64
+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[TMP11:%.*]] = load i16, i16* [[TMP10]], align 2
+; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
+; CHECK-NEXT: [[TMP13:%.*]] = load i16, i16* [[TMP12]], align 2
+; CHECK-NEXT: [[TMP14:%.*]] = udiv i16 [[TMP13]], 4
+; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT: [[TMP16:%.*]] = load <256 x i32>, <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT: [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP7]] to i8*
+; CHECK-NEXT: store <256 x i32> [[TMP16]], <256 x i32>* [[TMP7]], align 1024
+; CHECK-NEXT: [[TMP18:%.*]] = sext i16 [[TMP11]] to i64
+; CHECK-NEXT: [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP17]], i64 [[TMP18]])
+; CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT: [[TMP21:%.*]] = load <256 x i32>, <256 x i32>* [[TMP20]], align 64
+; CHECK-NEXT: [[TMP22:%.*]] = bitcast <256 x i32>* [[TMP6]] to i8*
+; CHECK-NEXT: store <256 x i32> [[TMP21]], <256 x i32>* [[TMP6]], align 1024
+; CHECK-NEXT: [[TMP23:%.*]] = sext i16 [[TMP13]] to i64
+; CHECK-NEXT: [[TMP24:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP13]], i8* [[TMP22]], i64 [[TMP23]])
+; CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT: [[TMP26:%.*]] = load <256 x i32>, <256 x i32>* [[TMP25]], align 64
+; CHECK-NEXT: [[TMP27:%.*]] = bitcast <256 x i32>* [[TMP5]] to i8*
+; CHECK-NEXT: store <256 x i32> [[TMP26]], <256 x i32>* [[TMP5]], align 1024
+; CHECK-NEXT: [[TMP28:%.*]] = sext i16 [[TMP11]] to i64
+; CHECK-NEXT: [[TMP29:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP14]], i16 [[TMP11]], i8* [[TMP27]], i64 [[TMP28]])
+; CHECK-NEXT: [[TMP30:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP9]], i16 [[TMP11]], i16 [[TMP13]], x86_amx [[TMP19]], x86_amx [[TMP24]], x86_amx [[TMP29]])
+; CHECK-NEXT: [[TMP31:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: [[TMP32:%.*]] = sext i16 [[TMP11]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP31]], i64 [[TMP32]], x86_amx [[TMP30]])
+; CHECK-NEXT: [[TMP33:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: store <256 x i32> [[TMP33]], <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT: ret void
+;
+ %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
+ %5 = load i16, i16* %4, align 64
+ %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+ %7 = load i16, i16* %6, align 2
+ %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 1
+ %9 = load i16, i16* %8, align 2
+ %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+ %11 = load <256 x i32>, <256 x i32>* %10, align 64
+ %12 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %11)
+ %13 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
+ %14 = load <256 x i32>, <256 x i32>* %13, align 64
+ %15 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %14)
+ %16 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+ %17 = load <256 x i32>, <256 x i32>* %16, align 64
+ %18 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %17)
+ %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, x86_amx %12, x86_amx %15, x86_amx %18)
+ %20 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %19)
+ store <256 x i32> %20, <256 x i32>* %10, align 64
+ ret void
+}
+
+define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbsud(
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT: ret void
+;
+ %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+ %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+ %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+ %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+ %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+ %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+ %t6 = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+ %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+ store <256 x i32> %t7, <256 x i32>* %pc, align 64
+ ret void
+}
+
+define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbusd(
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT: ret void
+;
+ %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+ %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+ %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+ %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+ %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+ %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+ %t6 = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+ %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+ store <256 x i32> %t7, <256 x i32>* %pc, align 64
+ ret void
+}
+
+define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbuud(
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT: ret void
+;
+ %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+ %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+ %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+ %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+ %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+ %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+ %t6 = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+ %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+ store <256 x i32> %t7, <256 x i32>* %pc, align 64
+ ret void
+}
+
+define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbf16ps(
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT: ret void
+;
+ %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+ %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+ %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+ %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+ %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+ %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+ %t6 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+ %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+ store <256 x i32> %t7, <256 x i32>* %pc, align 64
+ ret void
+}
+
+define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_stored(
+; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT: [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT: [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT: [[TMP10:%.*]] = load <256 x i32>, <256 x i32>* [[TMP9]], align 64
+; CHECK-NEXT: [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT: store <256 x i32> [[TMP10]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT: [[TMP12:%.*]] = sext i16 [[TMP8]] to i64
+; CHECK-NEXT: [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP11]], i64 [[TMP12]])
+; CHECK-NEXT: [[TMP14:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT: [[TMP15:%.*]] = ashr exact i64 [[TMP14]], 32
+; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP0:%.*]], i64 [[TMP15]], x86_amx [[TMP13]])
+; CHECK-NEXT: ret void
+;
+ %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
+ %5 = load i16, i16* %4, align 64
+ %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+ %7 = load i16, i16* %6, align 2
+ %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+ %9 = load <256 x i32>, <256 x i32>* %8, align 64
+ %10 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %9)
+ %11 = shl i64 %1, 32
+ %12 = ashr exact i64 %11, 32
+ tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx %10)
+ ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v225i32(<225 x i32>)
+declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
+declare <225 x i32> @llvm.x86.cast.tile.to.vector.v225i32(x86_amx)