[llvm] bcec4cc - [X86] [AMX] Replace bitcast with specific AMX intrinsics with X86 specific cast.

Bing1 Yu via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 17 02:04:37 PDT 2021


Author: Bing1 Yu
Date: 2021-08-17T17:04:26+08:00
New Revision: bcec4ccd04ae678a0d17b8fe8170e04221bf1959

URL: https://github.com/llvm/llvm-project/commit/bcec4ccd04ae678a0d17b8fe8170e04221bf1959
DIFF: https://github.com/llvm/llvm-project/commit/bcec4ccd04ae678a0d17b8fe8170e04221bf1959.diff

LOG: [X86] [AMX] Replace bitcast with specific AMX intrinsics with X86 specific cast.

There is some discussion on the bitcast for vector and x86_amx at https://reviews.llvm.org/D99152. This patch introduces an x86-specific cast between vector and x86_amx, so that it avoids some unnecessary optimization by the middle-end. On the other hand, we have to optimize the x86-specific cast ourselves. This patch also optimizes the cast operation to eliminate redundant code.

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D107544

Added: 
    llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
    llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll

Modified: 
    llvm/include/llvm/IR/IntrinsicsX86.td
    llvm/lib/Target/X86/X86LowerAMXType.cpp
    llvm/test/CodeGen/X86/AMX/amx-type.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index ae0a416175f9e..eba83493e686d 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5093,6 +5093,10 @@ let TargetPrefix = "x86" in {
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  def int_x86_cast_vector_to_tile:
+              Intrinsic<[llvm_x86amx_ty], [llvm_anyvector_ty], [IntrNoMem]>;
+  def int_x86_cast_tile_to_vector:
+              Intrinsic<[llvm_anyvector_ty], [llvm_x86amx_ty], [IntrNoMem]>;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 4ba44ccb6c160..a2bcc98f3d5b6 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -40,8 +40,10 @@
 //
 #include "X86.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
@@ -56,66 +58,44 @@
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
+#include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
 using namespace PatternMatch;
 
 #define DEBUG_TYPE "lower-amx-type"
 
-static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
-                                           BasicBlock *BB) {
+static bool isAMXCast(Instruction *II) {
+  return match(II,
+               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
+         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
+}
+
+static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
+                                           Type *Ty) {
   Function &F = *BB->getParent();
   Module *M = BB->getModule();
   const DataLayout &DL = M->getDataLayout();
 
-  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
   LLVMContext &Ctx = Builder.getContext();
   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
   unsigned AllocaAS = DL.getAllocaAddrSpace();
   AllocaInst *AllocaRes =
-      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
+      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
   AllocaRes->setAlignment(AllocaAlignment);
   return AllocaRes;
 }
 
-namespace {
-class X86LowerAMXType {
-  Function &Func;
-  TargetMachine *TM = nullptr;
-
-  // In AMX intrinsics we let Shape = {Row, Col}, but the
-  // RealCol = Col / ElementSize. We may use the RealCol
-  // as a new Row for other new created AMX intrinsics.
-  std::map<Value *, Value *> Col2Row;
-
-public:
-  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
-  bool visit();
-  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
-  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
-  bool transformBitcast(BitCastInst *Bitcast);
-  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
-  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
-};
-
-Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
-                                      unsigned Granularity) {
-  if (Col2Row.count(V))
-    return Col2Row[V];
-  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
-  if (auto *I = dyn_cast<Instruction>(V)) {
-    BasicBlock::iterator Iter = I->getIterator();
-    ++Iter;
-    Builder.SetInsertPoint(&*Iter);
-  }
-  ConstantInt *Gran = Builder.getInt16(Granularity);
-  Value *RealRow = Builder.CreateUDiv(V, Gran);
-  Col2Row[V] = RealRow;
-  return RealRow;
+static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
+  for (Instruction &I : F.getEntryBlock())
+    if (!isa<AllocaInst>(&I))
+      return &I;
+  llvm_unreachable("No terminator in the entry block!");
 }
 
-std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
-                                                      unsigned OpNo) {
+static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+  IRBuilder<> Builder(II);
   Value *Row = nullptr, *Col = nullptr;
   switch (II->getIntrinsicID()) {
   default:
@@ -144,14 +124,32 @@ std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
       Col = II->getArgOperand(2);
       break;
     case 5:
-      Row = II->getArgOperand(2);
-      // FIXME: There is a design bug for AMX shape, which the Col should be
-      // Col/4 if it will be used as Row, but current Greedy RA can't handle
-      // this case well, it may failed if we generate a new Shape definition.
-      // So Let's just do it in O0 first.
-      // Row = Row / 4
-      if (TM->getOptLevel() == CodeGenOpt::None)
-        Row = getRowFromCol(II, Row, 4);
+      if (isa<ConstantInt>(II->getArgOperand(2)))
+        Row = Builder.getInt16(
+            (dyn_cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
+      else if (isa<Instruction>(II->getArgOperand(2))) {
+        // When it is not a const value and it is not a function argument, we
+        // create Row after the definition of II->getOperand(2) instead of
+        // before II. For example, II is %118, we try to getshape for %117:
+        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
+        //   i32> %115).
+        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
+        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
+        //   %117).
+        // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
+        // definition is after its user(new tileload for %117).
+        // So, the best choice is to create %row right after the definition of
+        // %106.
+        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
+        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
+        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
+      } else {
+        // When it is not a const value and it is a function argument, we create
+        // Row at the entry bb.
+        IRBuilder<> NewBuilder(
+            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
+        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
+      }
       Col = II->getArgOperand(1);
       break;
     }
@@ -162,6 +160,40 @@ std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
   return std::make_pair(Row, Col);
 }
 
+namespace {
+class X86LowerAMXType {
+  Function &Func;
+
+  // In AMX intrinsics we let Shape = {Row, Col}, but the
+  // RealCol = Col / ElementSize. We may use the RealCol
+  // as a new Row for other new created AMX intrinsics.
+  std::map<Value *, Value *> Col2Row;
+
+public:
+  X86LowerAMXType(Function &F) : Func(F) {}
+  bool visit();
+  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
+  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
+  bool transformBitcast(BitCastInst *Bitcast);
+  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
+};
+
+Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
+                                      unsigned Granularity) {
+  if (Col2Row.count(V))
+    return Col2Row[V];
+  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
+  if (auto *I = dyn_cast<Instruction>(V)) {
+    BasicBlock::iterator Iter = I->getIterator();
+    ++Iter;
+    Builder.SetInsertPoint(&*Iter);
+  }
+  ConstantInt *Gran = Builder.getInt16(Granularity);
+  Value *RealRow = Builder.CreateUDiv(V, Gran);
+  Col2Row[V] = RealRow;
+  return RealRow;
+}
+
 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
 // %2 = bitcast <256 x i32> %src to x86_amx
 // -->
@@ -230,8 +262,8 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
   Value *I8Ptr, *Stride;
   auto *Src = Bitcast->getOperand(0);
 
-  auto Prepare = [&]() {
-    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
+  auto Prepare = [&](Type *MemTy) {
+    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
     Stride = Builder.getInt64(64);
   };
@@ -250,7 +282,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
     if (!II)
       return false; // May be bitcast from x86amx to <256 x i32>.
-    Prepare();
+    Prepare(Bitcast->getOperand(0)->getType());
     Builder.CreateStore(Src, AllocaAddr);
     // TODO we can pick an constant operand for the shape.
     Value *Row = nullptr, *Col = nullptr;
@@ -270,7 +302,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
     auto *II = dyn_cast<IntrinsicInst>(Src);
     if (!II)
       return false; // May be bitcast from <256 x i32> to x86amx.
-    Prepare();
+    Prepare(Bitcast->getType());
     Value *Row = II->getOperand(0);
     Value *Col = II->getOperand(1);
     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
@@ -637,6 +669,364 @@ bool X86VolatileTileData::volatileTileData() {
 
 namespace {
 
+class X86LowerAMXCast {
+  Function &Func;
+
+public:
+  X86LowerAMXCast(Function &F) : Func(F) {}
+  bool combineAMXcast(TargetLibraryInfo *TLI);
+  bool transformAMXCast(IntrinsicInst *AMXCast);
+  bool transformAllAMXCast();
+  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
+                              SmallSetVector<Instruction *, 16> &DeadInst);
+};
+
+static bool DCEInstruction(Instruction *I,
+                           SmallSetVector<Instruction *, 16> &WorkList,
+                           const TargetLibraryInfo *TLI) {
+  if (isInstructionTriviallyDead(I, TLI)) {
+    salvageDebugInfo(*I);
+    salvageKnowledge(I);
+
+    // Null out all of the instruction's operands to see if any operand becomes
+    // dead as we go.
+    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+      Value *OpV = I->getOperand(i);
+      I->setOperand(i, nullptr);
+
+      if (!OpV->use_empty() || I == OpV)
+        continue;
+
+      // If the operand is an instruction that became dead as we nulled out the
+      // operand, and if it is 'trivially' dead, delete it in a future loop
+      // iteration.
+      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
+        if (isInstructionTriviallyDead(OpI, TLI)) {
+          WorkList.insert(OpI);
+        }
+      }
+    }
+    I->eraseFromParent();
+    return true;
+  }
+  return false;
+}
+
+/// This function handles following case
+///
+///     A  ->  B    amxcast
+///     PHI
+///     B  ->  A    amxcast
+///
+/// All the related PHI nodes can be replaced by new PHI nodes with type A.
+/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
+bool X86LowerAMXCast::optimizeAMXCastFromPhi(
+    IntrinsicInst *CI, PHINode *PN,
+    SmallSetVector<Instruction *, 16> &DeadInst) {
+  IRBuilder<> Builder(CI);
+  Value *Src = CI->getOperand(0);
+  Type *SrcTy = Src->getType(); // Type B
+  Type *DestTy = CI->getType(); // Type A
+
+  SmallVector<PHINode *, 4> PhiWorklist;
+  SmallSetVector<PHINode *, 4> OldPhiNodes;
+
+  // Find all of the A->B casts and PHI nodes.
+  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
+  // OldPhiNodes is used to track all known PHI nodes, before adding a new
+  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
+  PhiWorklist.push_back(PN);
+  OldPhiNodes.insert(PN);
+  while (!PhiWorklist.empty()) {
+    auto *OldPN = PhiWorklist.pop_back_val();
+    for (Value *IncValue : OldPN->incoming_values()) {
+      // TODO: currently, We ignore cases where it is a const. In the future, we
+      // might support const.
+      if (isa<Constant>(IncValue))
+        return false;
+
+      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
+        if (OldPhiNodes.insert(PNode))
+          PhiWorklist.push_back(PNode);
+        continue;
+      }
+      Instruction *ACI = dyn_cast<Instruction>(IncValue);
+      if (ACI && isAMXCast(ACI)) {
+        // Verify it's a A->B cast.
+        Type *TyA = ACI->getOperand(0)->getType();
+        Type *TyB = ACI->getType();
+        if (TyA != DestTy || TyB != SrcTy)
+          return false;
+        continue;
+      }
+      return false;
+    }
+  }
+
+  // Check that each user of each old PHI node is something that we can
+  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
+  for (auto *OldPN : OldPhiNodes) {
+    for (User *V : OldPN->users()) {
+      Instruction *ACI = dyn_cast<Instruction>(V);
+      if (ACI && isAMXCast(ACI)) {
+        // Verify it's a B->A cast.
+        Type *TyB = ACI->getOperand(0)->getType();
+        Type *TyA = ACI->getType();
+        if (TyA != DestTy || TyB != SrcTy)
+          return false;
+      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        // As long as the user is another old PHI node, then even if we don't
+        // rewrite it, the PHI web we're considering won't have any users
+        // outside itself, so it'll be dead.
+        // example:
+        //   bb.0:
+        //      %0 = amxcast ...
+        //   bb.1:
+        //      %1 = amxcast ...
+        //   bb.2:
+        //      %goodphi = phi %0, %1
+        //      %3 = amxcast %goodphi
+        //   bb.3:
+        //      %goodphi2 = phi %0, %goodphi
+        //      %4 = amxcast %goodphi2
+        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
+        // outside the phi-web, so the combination stops. When
+        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
+        // will be done.
+        if (OldPhiNodes.count(PHI) == 0)
+          return false;
+      } else
+        return false;
+    }
+  }
+
+  // For each old PHI node, create a corresponding new PHI node with a type A.
+  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
+  for (auto *OldPN : OldPhiNodes) {
+    Builder.SetInsertPoint(OldPN);
+    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
+    NewPNodes[OldPN] = NewPN;
+  }
+
+  // Fill in the operands of new PHI nodes.
+  for (auto *OldPN : OldPhiNodes) {
+    PHINode *NewPN = NewPNodes[OldPN];
+    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
+      Value *V = OldPN->getOperand(j);
+      Value *NewV = nullptr;
+      Instruction *ACI = dyn_cast<Instruction>(V);
+      // There should not be a AMXcast from a const.
+      if (ACI && isAMXCast(ACI))
+        NewV = ACI->getOperand(0);
+      else if (auto *PrevPN = dyn_cast<PHINode>(V))
+        NewV = NewPNodes[PrevPN];
+      assert(NewV);
+      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
+    }
+  }
+
+  // Traverse all accumulated PHI nodes and process their users,
+  // which are Stores and BitCasts. Without this processing
+  // NewPHI nodes could be replicated and could lead to extra
+  // moves generated after DeSSA.
+  // If there is a store with type B, change it to type A.
+
+  // Replace users of BitCast B->A with NewPHI. These will help
+  // later to get rid of a closure formed by OldPHI nodes.
+  for (auto *OldPN : OldPhiNodes) {
+    PHINode *NewPN = NewPNodes[OldPN];
+    for (User *V : make_early_inc_range(OldPN->users())) {
+      Instruction *ACI = dyn_cast<Instruction>(V);
+      if (ACI && isAMXCast(ACI)) {
+        Type *TyB = ACI->getOperand(0)->getType();
+        Type *TyA = ACI->getType();
+        assert(TyA == DestTy && TyB == SrcTy);
+        (void)TyA;
+        (void)TyB;
+        ACI->replaceAllUsesWith(NewPN);
+        DeadInst.insert(ACI);
+      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        // We don't need to push PHINode into DeadInst since they are operands
+        // of rootPN. DCE can safely delete rootPN's operands if rootPN is dead.
+        assert(OldPhiNodes.contains(PHI));
+        (void)PHI;
+      } else
+        llvm_unreachable("all uses should be handled");
+    }
+  }
+  return true;
+}
+
+bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
+  bool Change = false;
+  // Collect tile cast instruction.
+  SmallVector<Instruction *, 8> Vec2TileInsts;
+  SmallVector<Instruction *, 8> Tile2VecInsts;
+  SmallVector<Instruction *, 8> PhiCastWorkList;
+  SmallSetVector<Instruction *, 16> DeadInst;
+  for (BasicBlock &BB : Func) {
+    for (Instruction &I : BB) {
+      Value *Vec;
+      if (match(&I,
+                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
+        Vec2TileInsts.push_back(&I);
+      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
+                             m_Value(Vec))))
+        Tile2VecInsts.push_back(&I);
+    }
+  }
+
+  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
+    for (auto *Inst : Insts) {
+      for (User *U : Inst->users()) {
+        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+        if (!II || II->getIntrinsicID() != IID)
+          continue;
+        // T1 = vec2tile V0
+        // V2 = tile2vec T1
+        // V3 = OP V2
+        // -->
+        // T1 = vec2tile V0
+        // V2 = tile2vec T1
+        // V3 = OP V0
+        II->replaceAllUsesWith(Inst->getOperand(0));
+        Change = true;
+      }
+    }
+  };
+
+  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
+  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
+
+  auto EraseInst = [](SmallVectorImpl<Instruction *> &Insts) {
+    for (auto *Inst : Insts) {
+      if (Inst->use_empty())
+        Inst->eraseFromParent();
+    }
+  };
+
+  EraseInst(Vec2TileInsts);
+  EraseInst(Tile2VecInsts);
+
+  // Handle the A->B->A cast, and there is an intervening PHI node.
+  for (BasicBlock &BB : Func) {
+    for (Instruction &I : BB) {
+      if (isAMXCast(&I)) {
+        if (PHINode *PN = dyn_cast<PHINode>(I.getOperand(0)))
+          PhiCastWorkList.push_back(&I);
+      }
+    }
+  }
+  for (auto *I : PhiCastWorkList) {
+    // We skip the dead Amxcast.
+    if (DeadInst.contains(I))
+      continue;
+    PHINode *PN = cast<PHINode>(I->getOperand(0));
+    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
+      DeadInst.insert(PN);
+      Change = true;
+    }
+  }
+
+  // Since we create new phi and merge AMXCast, some old phis and AMXCast might
+  // have no uses. We do some DeadCodeElimination for them.
+  while (!DeadInst.empty()) {
+    Instruction *I = DeadInst.pop_back_val();
+    Change |= DCEInstruction(I, DeadInst, TLI);
+  }
+  return Change;
+}
+
+// There might be remaining AMXcast after combineAMXcast and they should be
+// handled elegantly.
+bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
+  IRBuilder<> Builder(AMXCast);
+  AllocaInst *AllocaAddr;
+  Value *I8Ptr, *Stride;
+  auto *Src = AMXCast->getOperand(0);
+
+  auto Prepare = [&](Type *MemTy) {
+    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
+    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+    Stride = Builder.getInt64(64);
+  };
+
+  if (AMXCast->getType()->isX86_AMXTy()) {
+    // %2 = amxcast <225 x i32> %src to x86_amx
+    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
+    //                                           i8* %addr3, i64 60, x86_amx %2)
+    // -->
+    // %addr = alloca <225 x i32>, align 64
+    // store <225 x i32> %src, <225 x i32>* %addr, align 64
+    // %addr2 = bitcast <225 x i32>* %addr to i8*
+    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
+    //                                                  i8* %addr2,
+    //                                                  i64 60)
+    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
+    //                                           i8* %addr3, i64 60, x86_amx %2)
+    Use &U = *(AMXCast->use_begin());
+    unsigned OpNo = U.getOperandNo();
+    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
+    if (!II)
+      return false; // May be bitcast from x86amx to <256 x i32>.
+    Prepare(AMXCast->getOperand(0)->getType());
+    Builder.CreateStore(Src, AllocaAddr);
+    // TODO we can pick an constant operand for the shape.
+    Value *Row = nullptr, *Col = nullptr;
+    std::tie(Row, Col) = getShape(II, OpNo);
+    std::array<Value *, 4> Args = {
+        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
+    Value *NewInst = Builder.CreateIntrinsic(
+        Intrinsic::x86_tileloadd64_internal, None, Args);
+    AMXCast->replaceAllUsesWith(NewInst);
+    AMXCast->eraseFromParent();
+  } else {
+    // %2 = amxcast x86_amx %src to <225 x i32>
+    // -->
+    // %addr = alloca <225 x i32>, align 64
+    // %addr2 = bitcast <225 x i32>* to i8*
+    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
+    //                                           i8* %addr2, i64 %stride)
+    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
+    auto *II = dyn_cast<IntrinsicInst>(Src);
+    if (!II)
+      return false; // May be bitcast from <256 x i32> to x86amx.
+    Prepare(AMXCast->getType());
+    Value *Row = II->getOperand(0);
+    Value *Col = II->getOperand(1);
+    std::array<Value *, 5> Args = {
+        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
+    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
+    AMXCast->replaceAllUsesWith(NewInst);
+    AMXCast->eraseFromParent();
+  }
+
+  return true;
+}
+
+bool X86LowerAMXCast::transformAllAMXCast() {
+  bool Change = false;
+  // Collect tile cast instruction.
+  SmallVector<Instruction *, 8> WorkLists;
+  for (BasicBlock &BB : Func) {
+    for (Instruction &I : BB) {
+      if (isAMXCast(&I))
+        WorkLists.push_back(&I);
+    }
+  }
+
+  for (auto *Inst : WorkLists) {
+    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
+  }
+
+  return Change;
+}
+
+} // anonymous namespace
+
+namespace {
+
 class X86LowerAMXTypeLegacyPass : public FunctionPass {
 public:
   static char ID;
@@ -647,8 +1037,15 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
 
   bool runOnFunction(Function &F) override {
     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    TargetLibraryInfo *TLI =
+        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+    X86LowerAMXCast LAC(F);
+    LAC.combineAMXcast(TLI);
+    // There might be remaining AMXcast after combineAMXcast and they should be
+    // handled elegantly.
+    LAC.transformAllAMXCast();
 
-    X86LowerAMXType LAT(F, TM);
+    X86LowerAMXType LAT(F);
     bool C = LAT.visit();
 
     // Prepare for fast register allocation at O0.
@@ -671,6 +1068,7 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
     AU.addRequired<TargetPassConfig>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
   }
 };
 
@@ -681,6 +1079,7 @@ char X86LowerAMXTypeLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                       false)
 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                     false)
 

diff  --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
index 989a1076ce7a6..ddf650525baaa 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -163,18 +163,19 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._
 ; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
-; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
-; CHECK-NEXT:    [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
-; CHECK-NEXT:    [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]])
-; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, x86_amx [[TMP19]])
+; CHECK-NEXT:    [[TMP10:%.*]] = udiv i16 [[TMP9]], 4
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64)
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP14]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP15]], i64 64)
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP18:%.*]] = bitcast <256 x i32>* [[TMP17]] to i8*
+; CHECK-NEXT:    [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP10]], i16 [[TMP7]], i8* [[TMP18]], i64 64)
+; CHECK-NEXT:    [[TMP20:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP13]], x86_amx [[TMP16]], x86_amx [[TMP19]])
+; CHECK-NEXT:    [[TMP21:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP21]], i64 64, x86_amx [[TMP20]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
@@ -200,15 +201,16 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._
 
 define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbsud(
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -225,15 +227,16 @@ define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
 
 define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbusd(
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -250,15 +253,16 @@ define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
 
 define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbuud(
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -275,15 +279,16 @@ define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
 
 define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbf16ps(
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64)
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]])
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64

diff  --git a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
new file mode 100644
index 0000000000000..4aa5c7e3e1b9a
--- /dev/null
+++ b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
@@ -0,0 +1,412 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
+
+define void @combine_amx_cast_inside_bb() {
+; CHECK-LABEL: @combine_amx_cast_inside_bb(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP0]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+  %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %tmp)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %1)
+  ret void
+}
+
+; Case where amxcast can be combined across basic blocks.
+; %5 and %6 are combined because every incoming value of %goodphi is a phi or an amxcast.
+define void @combine_amx_cast_and_phi() {
+; CHECK-LABEL: @combine_amx_cast_and_phi(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK:       for.body.i.lr.ph.i:
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK:       for.cond.cleanup.i.i:
+; CHECK-NEXT:    [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+  %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+  br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i:                               ; preds = %wrapper_entry
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+  %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+  %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+  %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+  br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i:                             ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+  %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+  %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+  ret void
+}
+
+; Case where amxcast can't be combined across basic blocks.
+; %5 and %6 are not combined because an incoming value of %evilphi (%0, an add) is neither a phi nor an amxcast.
+define void @fail_to_combine_amx_cast_and_phi(<110 x i32> %tmp) {
+; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = add <110 x i32> [[TMP:%.*]], [[TMP]]
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK:       for.body.i.lr.ph.i:
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP6]], i64 40)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP8]], i64 56)
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP10]], i64 40)
+; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP7]], x86_amx [[TMP9]], x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP13]], i64 40, x86_amx [[TMP12]])
+; CHECK-NEXT:    [[TMP14:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK:       for.cond.cleanup.i.i:
+; CHECK-NEXT:    [[EVILPHI:%.*]] = phi <110 x i32> [ [[TMP5]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP14]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT:    [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP15]], i64 40)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP16]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = add <110 x i32> %tmp, %tmp
+  br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i:                               ; preds = %wrapper_entry
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+  %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+  %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+  %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+  br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i:                             ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+  %evilphi = phi <110 x i32> [ %0, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+  %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+  ret void
+}
+
+; Case where amxcast can't be combined across basic blocks.
+; %5 and %6 are not combined because %goodphi's user, %evilphi2, is not inside the phi web.
+define void @fail_to_combine_amx_cast_and_phi2() {
+; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi2(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <110 x i32>* [[TMP5]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP7]], i64 40, x86_amx [[TMP6]])
+; CHECK-NEXT:    [[TMP8:%.*]] = load <110 x i32>, <110 x i32>* [[TMP5]], align 512
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK:       for.body.i.lr.ph.i:
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP9]], i64 40)
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP11]], i64 56)
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP13]], i64 40)
+; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP10]], x86_amx [[TMP12]], x86_amx [[TMP14]])
+; CHECK-NEXT:    [[TMP16:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP16]], i64 40, x86_amx [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I]], label [[EXIT:%.*]]
+; CHECK:       for.cond.cleanup.i.i:
+; CHECK-NEXT:    [[GOODPHI:%.*]] = phi <110 x i32> [ [[TMP8]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP17]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    [[TMP18:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <110 x i32> [[GOODPHI]], <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT:    [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP18]], i64 40)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP19]])
+; CHECK-NEXT:    br i1 undef, label [[EXIT]], label [[FOR_BODY_I_LR_PH_I]]
+; CHECK:       exit:
+; CHECK-NEXT:    [[EVILPHI2:%.*]] = phi <110 x i32> [ [[GOODPHI]], [[FOR_COND_CLEANUP_I_I]] ], [ [[TMP17]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    store <110 x i32> [[EVILPHI2]], <110 x i32>* undef, align 512
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+  %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+  br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i:                               ; preds = %wrapper_entry
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+  %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+  %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+  %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+  br i1 undef, label %for.cond.cleanup.i.i, label %exit
+
+for.cond.cleanup.i.i:                             ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+  %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+  %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+  br i1 undef, label %exit, label %for.body.i.lr.ph.i
+exit:
+  %evilphi2 = phi <110 x i32> [ %goodphi, %for.cond.cleanup.i.i ], [ %5, %for.body.i.lr.ph.i ]
+  store <110 x i32> %evilphi2, <110 x i32>* undef, align 512
+  ret void
+}
+
+define void @fail_to_combine_amx_cast_and_phi_due_to_const_value() {
+; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi_due_to_const_value(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK:       for.body.i.lr.ph.i:
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK:       for.cond.cleanup.i.i:
+; CHECK-NEXT:    [[EVILPHI:%.*]] = phi <110 x i32> [ undef, [[WRAPPER_ENTRY:%.*]] ], [ [[TMP13]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP14]], i64 40)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i:                               ; preds = %wrapper_entry
+  %0 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+  %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %3 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %0, x86_amx %1, x86_amx %2)
+  %4 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %3)
+  br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i:                             ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+  %evilphi = phi <110 x i32> [ undef, %wrapper_entry ], [ %4, %for.body.i.lr.ph.i ]
+  %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %5)
+  ret void
+}
+
+; Case where amxcast can be combined across basic blocks.
+; When optimizeAMXCastFromPhi processes %6 and %goodphi, %evilphi2 is outside the phi web, so the optimization stops.
+; When optimizeAMXCastFromPhi processes %7 and %evilphi2, the optimization continues.
+define void @combine_amx_cast_and_multiple_phi() {
+; CHECK-LABEL: @combine_amx_cast_and_multiple_phi(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK:       for.body.i.lr.ph.i:
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I]], label [[EXIT:%.*]]
+; CHECK:       for.cond.cleanup.i.i:
+; CHECK-NEXT:    [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
+; CHECK-NEXT:    br i1 undef, label [[EXIT]], label [[FOR_BODY_I_LR_PH_I]]
+; CHECK:       exit:
+; CHECK-NEXT:    [[TMP12:%.*]] = phi x86_amx [ [[TMP11]], [[FOR_COND_CLEANUP_I_I]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP12]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+  %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+  br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i:                               ; preds = %wrapper_entry
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+  %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+  %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+  %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+  br i1 undef, label %for.cond.cleanup.i.i, label %exit
+
+for.cond.cleanup.i.i:                             ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+  %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ]
+  %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+  br i1 undef, label %exit, label %for.body.i.lr.ph.i
+exit:
+  %evilphi2 = phi <110 x i32> [ %goodphi, %for.cond.cleanup.i.i ], [ %5, %for.body.i.lr.ph.i ]
+  %7 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %7)
+  ret void
+}
+
+; Currently we are not able to delete dead PHI cycles; they will be handled later.
+define void @combine_amx_cast_and_phi_in_a_circle() {
+; CHECK-LABEL: @combine_amx_cast_and_phi_in_a_circle(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    br label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP3]], align 512
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP0]], align 512
+; CHECK-NEXT:    br i1 undef, label [[BB2:%.*]], label [[BB3:%.*]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[TMP14:%.*]] = phi x86_amx [ [[TMP15:%.*]], [[BB3]] ], [ [[TMP11]], [[BB1]] ]
+; CHECK-NEXT:    [[GOODPHI:%.*]] = phi <110 x i32> [ [[EVILPHI2:%.*]], [[BB3]] ], [ [[TMP13]], [[BB1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP14]])
+; CHECK-NEXT:    br label [[BB3]]
+; CHECK:       bb3:
+; CHECK-NEXT:    [[TMP15]] = phi x86_amx [ [[TMP14]], [[BB2]] ], [ [[TMP11]], [[BB1]] ]
+; CHECK-NEXT:    [[EVILPHI2]] = phi <110 x i32> [ [[GOODPHI]], [[BB2]] ], [ [[TMP13]], [[BB1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT:    br i1 undef, label [[BB2]], label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+  %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+  br label %bb1
+
+bb1:                               ; preds = %wrapper_entry
+  %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef)
+  %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef)
+  %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3)
+  %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4)
+  br i1 undef, label %bb2, label %bb3
+
+bb2:                             ; preds = %bb3, %bb1
+  %goodphi = phi <110 x i32> [ %evilphi2, %bb3], [ %5, %bb1 ]
+  %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6)
+  br label %bb3
+bb3:
+  %evilphi2 = phi <110 x i32> [ %goodphi, %bb2 ], [ %5, %bb1 ]
+  %7 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %7)
+  br i1 undef, label %bb2, label %exit
+exit:
+  %8 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %8)
+  ret void
+}
+
+define void @eliminate_unused_phi_and_cast() {
+; CHECK-LABEL: @eliminate_unused_phi_and_cast(
+; CHECK-NEXT:  wrapper_entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
+; CHECK:       for.body.i.lr.ph.i:
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* undef, i64 undef)
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* undef, i64 undef)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP2]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
+; CHECK:       for.cond.cleanup.i.i:
+; CHECK-NEXT:    [[TMP7:%.*]] = phi x86_amx [ [[TMP1]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP6]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP7]])
+; CHECK-NEXT:    ret void
+;
+wrapper_entry:
+  %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef)
+  %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0)
+  br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i
+
+for.body.i.lr.ph.i:                               ; preds = %wrapper_entry
+  %1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* undef, i64 undef)
+  %v1 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %1)
+  %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* undef, i64 undef)
+  %v2 = call <616 x i8> @llvm.x86.cast.tile.to.vector.v616i8(x86_amx %2)
+  %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %v1)
+  %4 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> %v2)
+  %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef)
+  %6 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %3, x86_amx %4, x86_amx %5)
+  %7 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %6)
+  br label %for.cond.cleanup.i.i
+
+for.cond.cleanup.i.i:                             ; preds = %for.body.i.lr.ph.i, %wrapper_entry
+  %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %7, %for.body.i.lr.ph.i ]
+  %8 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi)
+  call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %8)
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx)
+declare <616 x i8> @llvm.x86.cast.tile.to.vector.v616i8(x86_amx)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32>)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8>)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8>)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)

diff  --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
new file mode 100644
index 0000000000000..98a820197bbd6
--- /dev/null
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -0,0 +1,429 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
+
+%struct.__tile_str = type { i16, i16, <256 x i32> }
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 64
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 64
+
+; test that an unused cast of x86_amx to <256 x i32> (llvm.x86.cast.tile.to.vector) is removed
+define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_user_empty(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s)
+  %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+  ret void
+}
+
+; test that an unused cast of <256 x i32> to x86_amx (llvm.x86.cast.vector.to.tile) is removed
+define dso_local void @test_user_empty2(<256 x i32> %in) {
+; CHECK-LABEL: @test_user_empty2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %in)
+  ret void
+}
+
+define dso_local <256 x i32> @test_amx_load_bitcast_v256i32(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_amx_load_bitcast_v256i32(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[IN:%.*]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T1]], <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N]], i8* [[TMP1]], i64 [[TMP2]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]])
+; CHECK-NEXT:    ret <256 x i32> [[T1]]
+;
+entry:
+  %t1 = load <256 x i32>, <256 x i32>* %in, align 64
+  %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
+  call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2)
+  ret <256 x i32> %t1
+}
+
+define dso_local <225 x i32> @test_amx_load_bitcast_v225i32(<225 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_amx_load_bitcast_v225i32(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <225 x i32>, align 64
+; CHECK-NEXT:    [[T1:%.*]] = load <225 x i32>, <225 x i32>* [[IN:%.*]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <225 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <225 x i32> [[T1]], <225 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N]], i8* [[TMP1]], i64 [[TMP2]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]])
+; CHECK-NEXT:    ret <225 x i32> [[T1]]
+;
+entry:
+  %t1 = load <225 x i32>, <225 x i32>* %in, align 64
+  %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v225i32(<225 x i32> %t1)
+  call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2)
+  ret <225 x i32> %t1
+}
+
+define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m, i16 %n, i8 *%buf, i64 %s) {
+; CHECK-LABEL: @test_amx_bitcast_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[M]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[M]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[TMP3]], <256 x i32>* [[OUT:%.*]], align 1024
+; CHECK-NEXT:    ret <256 x i32> [[TMP3]]
+;
+entry:
+  %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %m, i8* %buf, i64 %s)
+  %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+  store <256 x i32> %t2, <256 x i32>* %out
+  ret <256 x i32> %t2
+}
+
+define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) {
+; CHECK-LABEL: @test_src_add(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[C:%.*]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C]], i8* [[TMP1]], i64 [[TMP2]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %add = add <256 x i32> %y, %x
+  %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %add)
+  call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t)
+  ret void
+}
+
+define dso_local void @test_src_add2(<256 x i32> %x, i16 %r, i16 %c, i8* %buf, i64 %s) {
+; CHECK-LABEL: @test_src_add2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[C]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[ADD:%.*]] = add <256 x i32> [[TMP3]], [[X:%.*]]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, i8* %buf, i64 %s)
+  %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+  %add = add <256 x i32> %t2, %x
+  ret void
+}
+
+define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_loadd(
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2
+; CHECK-NEXT:    [[TMP9:%.*]] = shl i64 [[TMP2:%.*]], 32
+; CHECK-NEXT:    [[TMP10:%.*]] = ashr exact i64 [[TMP9]], 32
+; CHECK-NEXT:    [[TMP11:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP1:%.*]], i64 [[TMP10]])
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[TMP8]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP12]], i64 [[TMP13]], x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP14:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
+; CHECK-NEXT:    store <256 x i32> [[TMP14]], <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2
+  %8 = shl i64 %2, 32
+  %9 = ashr exact i64 %8, 32
+  %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9)
+  %11 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %10)
+  %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  store <256 x i32> %11, <256 x i32>* %12, align 64
+  ret void
+}
+
+define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_dpbssd(
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP7:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 64
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP11:%.*]] = load i16, i16* [[TMP10]], align 2
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP13:%.*]] = load i16, i16* [[TMP12]], align 2
+; CHECK-NEXT:    [[TMP14:%.*]] = udiv i16 [[TMP13]], 4
+; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP16:%.*]] = load <256 x i32>, <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP7]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[TMP16]], <256 x i32>* [[TMP7]], align 1024
+; CHECK-NEXT:    [[TMP18:%.*]] = sext i16 [[TMP11]] to i64
+; CHECK-NEXT:    [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP17]], i64 [[TMP18]])
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP21:%.*]] = load <256 x i32>, <256 x i32>* [[TMP20]], align 64
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <256 x i32>* [[TMP6]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[TMP21]], <256 x i32>* [[TMP6]], align 1024
+; CHECK-NEXT:    [[TMP23:%.*]] = sext i16 [[TMP13]] to i64
+; CHECK-NEXT:    [[TMP24:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP13]], i8* [[TMP22]], i64 [[TMP23]])
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP26:%.*]] = load <256 x i32>, <256 x i32>* [[TMP25]], align 64
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <256 x i32>* [[TMP5]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[TMP26]], <256 x i32>* [[TMP5]], align 1024
+; CHECK-NEXT:    [[TMP28:%.*]] = sext i16 [[TMP11]] to i64
+; CHECK-NEXT:    [[TMP29:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP14]], i16 [[TMP11]], i8* [[TMP27]], i64 [[TMP28]])
+; CHECK-NEXT:    [[TMP30:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP9]], i16 [[TMP11]], i16 [[TMP13]], x86_amx [[TMP19]], x86_amx [[TMP24]], x86_amx [[TMP29]])
+; CHECK-NEXT:    [[TMP31:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    [[TMP32:%.*]] = sext i16 [[TMP11]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP31]], i64 [[TMP32]], x86_amx [[TMP30]])
+; CHECK-NEXT:    [[TMP33:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[TMP33]], <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 1
+  %9 = load i16, i16* %8, align 2
+  %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  %11 = load <256 x i32>, <256 x i32>* %10, align 64
+  %12 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %11)
+  %13 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
+  %14 = load <256 x i32>, <256 x i32>* %13, align 64
+  %15 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %14)
+  %16 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %17 = load <256 x i32>, <256 x i32>* %16, align 64
+  %18 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %17)
+  %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, x86_amx %12, x86_amx %15, x86_amx %18)
+  %20 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %19)
+  store <256 x i32> %20, <256 x i32>* %10, align 64
+  ret void
+}
+
+define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbsud(
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    ret void
+;
+  %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+  %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+  %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+  %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+  %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+  %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+  %t6 = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+  %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+  store <256 x i32> %t7, <256 x i32>* %pc, align 64
+  ret void
+}
+
+define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbusd(
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    ret void
+;
+  %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+  %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+  %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+  %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+  %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+  %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+  %t6 = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+  %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+  store <256 x i32> %t7, <256 x i32>* %pc, align 64
+  ret void
+}
+
+define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbuud(
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    ret void
+;
+  %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+  %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+  %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+  %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+  %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+  %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+  %t6 = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+  %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+  store <256 x i32> %t7, <256 x i32>* %pc, align 64
+  ret void
+}
+
+define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbf16ps(
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
+; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
+; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
+; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    ret void
+;
+  %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+  %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0)
+  %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+  %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2)
+  %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+  %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4)
+  %t6 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3)
+  %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6)
+  store <256 x i32> %t7, <256 x i32>* %pc, align 64
+  ret void
+}
+
+define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_stored(
+; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP10:%.*]] = load <256 x i32>, <256 x i32>* [[TMP9]], align 64
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[TMP10]], <256 x i32>* [[TMP4]], align 1024
+; CHECK-NEXT:    [[TMP12:%.*]] = sext i16 [[TMP8]] to i64
+; CHECK-NEXT:    [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP11]], i64 [[TMP12]])
+; CHECK-NEXT:    [[TMP14:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP15:%.*]] = ashr exact i64 [[TMP14]], 32
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP0:%.*]], i64 [[TMP15]], x86_amx [[TMP13]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %9 = load <256 x i32>, <256 x i32>* %8, align 64
+  %10 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %9)
+  %11 = shl i64 %1, 32
+  %12 = ashr exact i64 %11, 32
+  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx %10)
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v225i32(<225 x i32>)
+declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
+declare <225 x i32> @llvm.x86.cast.tile.to.vector.v225i32(x86_amx)


        


More information about the llvm-commits mailing list