[clang] [llvm] [X86][AMX] Support AMX-TRANSPOSE (PR #113532)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 1 00:16:07 PDT 2024


================
@@ -121,12 +137,96 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
   llvm_unreachable("No terminator in the entry block!");
 }
 
-static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+/// Computes {Row, Col} shape values for AMX intrinsics and PHI nodes, and
+/// converts between column values and row values at a given granularity.
+/// Conversions are memoized per input value, so each value is converted at
+/// most once by a given calculator instance.
+class ShapeCalculator {
+public:
+  /// Stores \p TargetM for later use; a null pointer is accepted here.
+  ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {}
+
+  /// Shape of operand \p OpNo of the intrinsic call \p II.
+  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
+  /// Shape associated with the PHI node \p Phi.
+  std::pair<Value *, Value *> getShape(PHINode *Phi);
+
+  /// Row value derived from the column value \p V using \p Granularity.
+  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
+  /// Column value derived from the row value \p V using \p Granularity.
+  Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
+
+private:
+  TargetMachine *TM = nullptr;
+
+  // In AMX intrinsics Shape = {Row, Col}, but RealCol = Col / ElementSize.
+  // RealCol may be reused as the Row of other newly created AMX intrinsics,
+  // so both directions of the conversion are cached, keyed by the input value.
+  std::map<Value *, Value *> Col2Row;
+  std::map<Value *, Value *> Row2Col;
+};
+
+/// Returns an i16 value equal to \p V divided by \p Granularity, creating it
+/// on the first request for \p V and returning the cached value afterwards.
+///
+/// \param II          Instruction whose shape is being computed. It provides
+///                    the initial builder position and, for non-instruction
+///                    values, the function whose entry block receives the
+///                    new division.
+/// \param V           The column value to convert.
+/// \param Granularity Divisor applied to \p V in every case.
+/// \returns A constant when \p V is a ConstantInt, otherwise a newly created
+///          udiv.
+Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
+                                      unsigned Granularity) {
+  // Single lookup instead of count() followed by operator[].
+  auto Cached = Col2Row.find(V);
+  if (Cached != Col2Row.end())
+    return Cached->second;
+  IRBuilder<> Builder(II);
+  Value *RealRow = nullptr;
+  if (auto *ConstCol = dyn_cast<ConstantInt>(V)) {
+    RealRow = Builder.getInt16(ConstCol->getSExtValue() / Granularity);
+  } else if (auto *ColDef = dyn_cast<Instruction>(V)) {
+    // When it is not a const value and it is not a function argument, we
+    // create Row after the definition of V instead of
+    // before II. For example, II is %118, we try to getshape for %117:
+    //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
+    //   i32> %115).
+    //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
+    //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
+    //   %117).
+    // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
+    // definition is after its user(new tileload for %117).
+    // So, the best choice is to create %row right after the definition of
+    // %106.
+    Builder.SetInsertPoint(ColDef);
+    // Divide by Granularity (previously a hard-coded 4) so this path agrees
+    // with the constant and function-argument paths.
+    RealRow = Builder.CreateUDiv(V, Builder.getInt16(Granularity));
+    cast<Instruction>(RealRow)->moveAfter(ColDef);
+  } else {
+    // When it is not a const value and it is a function argument, we create
+    // Row at the entry bb.
+    IRBuilder<> NewBuilder(
+        getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
+    RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
+  }
+  Col2Row[V] = RealRow;
+  return RealRow;
+}
+
+Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V,
+                                      unsigned Granularity) {
+  if (Row2Col.count(V))
+    return Row2Col[V];
+  IRBuilder<> Builder(II);
+  Value *RealCol = nullptr;
+  if (isa<ConstantInt>(V))
+    RealCol =
+        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
+  else if (isa<Instruction>(V)) {
+    Builder.SetInsertPoint(cast<Instruction>(V));
+    RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
+    cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
+  } else {
+    // When it is not a const value and it is a function argument, we create
+    // Row at the entry bb.
----------------
phoebewang wrote:

Row is correct.

https://github.com/llvm/llvm-project/pull/113532


More information about the llvm-commits mailing list