[PATCH] D91927: [X86] Add x86_amx type for intel AMX.

Craig Topper via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 24 10:45:21 PST 2020


craig.topper added inline comments.


================
Comment at: llvm/lib/Target/X86/X86LowerAMXType.cpp:72
   LLVMContext &Ctx = Builder.getContext();
-  Type *Ty = LD->getType();
-  EVT VT = EVT::getEVT(Ty);
-  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
-  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
-
-  Value *Ptr = LD->getPointerOperand();
-  PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace());
-  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
-  // The HW require the alignment for AMX tile is 64, but front-end generate
-  // code for the vector alignment which is the vector size.
-  uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
-  Align Alignment = std::min(LD->getAlign(), Align(HalfTySize));
-  auto *Lo =
-      Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
-
-  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
-  auto *Hi =
-      Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
-
-  LoadMap[Inst] = std::make_pair(Lo, Hi);
-}
-
-bool X86LowerAMXType::visitLD() {
-  if (LDSet.empty())
-    return false;
-  for (auto &Inst : LDSet) {
-    int Count = 0;
-    Value *NewInst = nullptr;
-    // The user should be all AMX intrinsics or all LLVM instruction.
-    // Don't support it is used by both AMX intrinsics and LLVM instructions.
-    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
-      Use &U = *I++;
-      const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U.getUser());
-      if (!II) {
-        Count++;
-        continue;
-      }
-      if (NewInst)
-        continue;
-      Value *Row, *Col;
-      switch (II->getIntrinsicID()) {
-      default:
-        report_fatal_error("Non-AMX intrinsic use tile type.");
-        break;
-      case Intrinsic::x86_tdpbssd_internal: {
-        unsigned OpNo = U.getOperandNo();
-        switch (OpNo) {
-        case 3:
-          Row = II->getArgOperand(0);
-          Col = II->getArgOperand(1);
-          break;
-        case 4:
-          Row = II->getArgOperand(0);
-          Col = II->getArgOperand(2);
-          break;
-        case 5:
-          Row = II->getArgOperand(2);
-          Col = II->getArgOperand(1);
-          break;
-        }
-        break;
-      }
-      case Intrinsic::x86_tilestored64_internal: {
-        Row = II->getArgOperand(0);
-        Col = II->getArgOperand(1);
-        break;
-      }
-      }
-      assert(Count == 0 && "Can NOT mix amx intrinsic and LLVM instruction");
-      // FIXME: The shape def should be ahead of load.
-      IRBuilder<> Builder(Inst);
-      LLVMContext &Ctx = Builder.getContext();
-      // Use the maximun column as stride.
-      Value *Stride = Builder.getInt64(64);
-      Value *I8Ptr =
-          Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx));
-      std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
-
-      NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
-                                        None, Args);
-
-      Inst->replaceAllUsesWith(NewInst);
-    }
-    if (!NewInst)
-      splitLD(Inst);
+  AllocaInst *AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
+  Value *I8Ptr =
----------------
Shouldn't this alloca be created in the function's entry block rather than in the bitcast's parent block?


================
Comment at: llvm/lib/Target/X86/X86LowerAMXType.cpp:89
+    // TODO we can pick an constant operand for the shape.
+    auto *Row = AMXIntrinsic->getOperand(0);
+    auto *Col = AMXIntrinsic->getOperand(1);
----------------
Just spell the type out as Value here. auto doesn't add anything other than shortening the line by 1 character.


================
Comment at: llvm/lib/Target/X86/X86LowerAMXType.cpp:178
+        LLVMContext &Ctx = Builder.getContext();
+        // Use the maximun column as stride. It must be the same with load
+        // stride.
----------------
maximun->maximum


================
Comment at: llvm/lib/Target/X86/X86LowerAMXType.cpp:182
+        Value *I8Ptr =
+            Builder.CreateBitCast(ST->getOperand(1), Type::getInt8PtrTy(Ctx));
+        std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
----------------
Use Builder.getInt8PtrTy(); then you don't need Ctx.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D91927/new/

https://reviews.llvm.org/D91927



More information about the llvm-commits mailing list