[llvm] 6753eb0 - [X86][AMX] Materialize undef or zero value to tilezero

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 31 04:11:14 PDT 2022


Author: Luo, Yuanke
Date: 2022-03-31T19:10:28+08:00
New Revision: 6753eb0c90b93e3f864d657f7de5f9880a6f1d58

URL: https://github.com/llvm/llvm-project/commit/6753eb0c90b93e3f864d657f7de5f9880a6f1d58
DIFF: https://github.com/llvm/llvm-project/commit/6753eb0c90b93e3f864d657f7de5f9880a6f1d58.diff

LOG: [X86][AMX] Materialize undef or zero value to tilezero

The AMX combiner would store undef or zero to stack and invoke tileload
to load the data to tile register. To avoid the store/load, we can
materialize undef or zero value to tilezero.

Differential Revision: https://reviews.llvm.org/D122714

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86LowerAMXType.cpp
    llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
    llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index b85bd824fdc33..0fb8d07b6b61e 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -74,6 +74,24 @@ static bool isAMXCast(Instruction *II) {
          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
 }
 
+static bool isAMXInstrinsic(User *I) {
+  // Returns true when I is an intrinsic call that produces or consumes an
+  // x86_amx value. Non-intrinsic users and the AMX cast intrinsics are
+  // rejected up front.
+  auto *Intr = dyn_cast<IntrinsicInst>(I);
+  if (!Intr || isAMXCast(Intr))
+    return false;
+
+  // If the return type or any parameter is x86_amx, the intrinsic must be an
+  // x86 AMX intrinsic. Check the result first, then scan the arguments.
+  bool TouchesTile = Intr->getType()->isX86_AMXTy();
+  for (auto Arg = Intr->arg_begin(), End = Intr->arg_end();
+       !TouchesTile && Arg != End; ++Arg)
+    TouchesTile = Arg->get()->getType()->isX86_AMXTy();
+
+  return TouchesTile;
+}
+
 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                            Type *Ty) {
   Function &F = *BB->getParent();
@@ -162,6 +180,36 @@ static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
   return std::make_pair(Row, Col);
 }
 
+static std::pair<Value *, Value *> getShape(PHINode *Phi) {
+  // Find the tile shape for Phi by walking forward along the FIRST use of each
+  // value, looking through AMX casts and PHIs until an AMX intrinsic is
+  // reached; the shape of the operand we arrived through is returned.
+  // Returns {nullptr, nullptr} when no shape is found, in which case the
+  // optimization for undef/zero is abandoned.
+  // TODO We don't traverse all users. To make the algorithm simple, here we
+  // just traverse the first user.
+  if (Phi->use_empty())
+    return std::make_pair(nullptr, nullptr);
+
+  Use *U = &*Phi->use_begin();
+  while (true) {
+    User *V = U->getUser();
+    if (isAMXInstrinsic(V))
+      return getShape(cast<IntrinsicInst>(V), U->getOperandNo());
+    // Only AMX casts and PHIs are looked through; other users end the walk.
+    if (!isAMXCast(dyn_cast<Instruction>(V)) && !isa<PHINode>(V))
+      break;
+    if (V->use_empty())
+      break;
+    // Step to the first use of V itself, so that the operand number always
+    // refers to the user inspected next. Restarting from Phi's first use here
+    // would hand back the same user again and never terminate.
+    U = &*V->use_begin();
+  }
+
+  return std::make_pair(nullptr, nullptr);
+}
+
 namespace {
 class X86LowerAMXType {
   Function &Func;
@@ -720,11 +768,33 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(
   OldPhiNodes.insert(PN);
   while (!PhiWorklist.empty()) {
     auto *OldPN = PhiWorklist.pop_back_val();
-    for (Value *IncValue : OldPN->incoming_values()) {
+    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
+      Value *IncValue = OldPN->getIncomingValue(I);
       // TODO: currently, We ignore cases where it is a const. In the future, we
       // might support const.
-      if (isa<Constant>(IncValue))
-        return false;
+      if (isa<Constant>(IncValue)) {
+        auto *IncConst = dyn_cast<Constant>(IncValue);
+        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
+          return false;
+        Value *Row = nullptr, *Col = nullptr;
+        std::tie(Row, Col) = getShape(OldPN);
+        // TODO: If they are not constants, Row and Col must dominate the
+        // tilezero that we are going to create.
+        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
+          return false;
+        // Create tilezero at the end of incoming block.
+        auto *Block = OldPN->getIncomingBlock(I);
+        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
+        Instruction *NewInst = Builder.CreateIntrinsic(
+            Intrinsic::x86_tilezero_internal, None, {Row, Col});
+        NewInst->moveBefore(&*Iter);
+        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
+                                          {IncValue->getType()}, {NewInst});
+        NewInst->moveBefore(&*Iter);
+        // Replace IncValue with the new value.
+        OldPN->setIncomingValue(I, NewInst);
+        IncValue = NewInst;
+      }
 
       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
         if (OldPhiNodes.insert(PNode))

diff  --git a/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll b/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
index 5a2d55c355579..67e4f3eea42bc 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
@@ -4,21 +4,14 @@
 define void @foo_undef(i8 *%buf) {
 ; CHECK-LABEL: @foo_undef(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
 ; CHECK:       l1:
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP2]], i64 32, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
 ; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ undef, [[ENTRY:%.*]] ], [ [[TMP3]], [[L1]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T3]], <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* [[TMP4]], i64 32)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = phi x86_amx [ [[TMP0]], [[ENTRY:%.*]] ], [ [[T1]], [[L1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP1]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -44,21 +37,14 @@ exit:
 define void @foo_zero(i8 *%buf) {
 ; CHECK-LABEL: @foo_zero(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
 ; CHECK:       l1:
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP2]], i64 32, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
 ; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP3]], [[L1]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T3]], <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* [[TMP4]], i64 32)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = phi x86_amx [ [[TMP0]], [[ENTRY:%.*]] ], [ [[T1]], [[L1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP1]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -163,8 +149,8 @@ exit:
   ret void
 }
 
-define void @foo_noshape(i8 *%buf) {
-; CHECK-LABEL: @foo_noshape(
+define void @noshape(i8 *%buf) {
+; CHECK-LABEL: @noshape(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
@@ -202,6 +188,48 @@ exit:
   ret void
 }
 
+define void @noshape2(i8 *%buf) {
+; CHECK-LABEL: @noshape2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
+; CHECK:       l1:
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP1]], i64 32, x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
+; CHECK:       l2:
+; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ undef, [[ENTRY:%.*]] ], [ [[TMP2]], [[L1]] ]
+; CHECK-NEXT:    [[T6:%.*]] = call <256 x i32> @llvm.abs.v256i32(<256 x i32> [[T3]], i1 true)
+; CHECK-NEXT:    [[P:%.*]] = bitcast i8* [[BUF:%.*]] to <256 x i32>*
+; CHECK-NEXT:    store <256 x i32> [[T6]], <256 x i32>* [[P]], align 1024
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br i1 undef, label %l1, label %l2
+
+l1:
+  %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
+  %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+  br i1 undef, label %l2, label %exit
+
+l2:
+  %t3 = phi <256 x i32> [ undef, %entry ], [ %t2, %l1 ]
+  %t4 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t3)
+  %t5 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t4)
+  %t6 = call <256 x i32> @llvm.abs.v256i32(<256 x i32> %t5, i1 1)
+  %p = bitcast i8* %buf to <256 x i32>*
+  store <256 x i32> %t6, <256 x i32>* %p
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare <256 x i32> @llvm.abs.v256i32(<256 x i32>, i1)
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)

diff  --git a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
index 4aa5c7e3e1b9a..e62d8ebfc70b0 100644
--- a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
@@ -187,33 +187,26 @@ exit:
 define void @fail_to_combine_amx_cast_and_phi_due_to_const_value() {
 ; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi_due_to_const_value(
 ; CHECK-NEXT:  wrapper_entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <110 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <560 x i8>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <616 x i8>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 11, i16 40)
 ; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
 ; CHECK:       for.body.i.lr.ph.i:
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
 ; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
 ; CHECK:       for.cond.cleanup.i.i:
-; CHECK-NEXT:    [[EVILPHI:%.*]] = phi <110 x i32> [ undef, [[WRAPPER_ENTRY:%.*]] ], [ [[TMP13]], [[FOR_BODY_I_LR_PH_I]] ]
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
-; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP14]], i64 40)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT:    [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
 ; CHECK-NEXT:    ret void
 ;
 wrapper_entry:


        


More information about the llvm-commits mailing list