[clang] [llvm] [X86][AMX] Support AMX-TRANSPOSE (PR #113532)
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 1 00:16:07 PDT 2024
================
@@ -121,12 +137,96 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
llvm_unreachable("No terminator in the entry block!");
}
-static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+class ShapeCalculator {
+private:
+ TargetMachine *TM = nullptr;
+
+ // In AMX intrinsics we let Shape = {Row, Col}, but the
+ // RealCol = Col / ElementSize. We may use the RealCol
+ // as a new Row for other new created AMX intrinsics.
+ std::map<Value *, Value *> Col2Row, Row2Col;
+
+public:
+ ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {}
+ std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
+ std::pair<Value *, Value *> getShape(PHINode *Phi);
+ Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
+ Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
+};
+
+Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
+ unsigned Granularity) {
+ if (Col2Row.count(V))
+ return Col2Row[V];
+ IRBuilder<> Builder(II);
+ Value *RealRow = nullptr;
+ if (isa<ConstantInt>(V))
+ RealRow =
+ Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
+ else if (isa<Instruction>(V)) {
+ // When it is not a const value and it is not a function argument, we
+ // create Row after the definition of V instead of
+ // before II. For example, II is %118, we try to getshape for %117:
+ // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
+ // i32> %115).
+ // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
+ // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
+ // %117).
+ // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
+ // definition is after its user(new tileload for %117).
+ // So, the best choice is to create %row right after the definition of
+ // %106.
+ Builder.SetInsertPoint(cast<Instruction>(V));
+ RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
+ cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
+ } else {
+ // When it is not a const value and it is a function argument, we create
+ // Row at the entry bb.
+ IRBuilder<> NewBuilder(
+ getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
+ RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
+ }
+ Col2Row[V] = RealRow;
+ return RealRow;
+}
+
+Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V,
+ unsigned Granularity) {
+ if (Row2Col.count(V))
+ return Row2Col[V];
+ IRBuilder<> Builder(II);
+ Value *RealCol = nullptr;
+ if (isa<ConstantInt>(V))
+ RealCol =
+ Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
+ else if (isa<Instruction>(V)) {
+ Builder.SetInsertPoint(cast<Instruction>(V));
+ RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
+ cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
+ } else {
+ // When it is not a const value and it is a function argument, we create
+ // Row at the entry bb.
----------------
phoebewang wrote:
Row is correct.
https://github.com/llvm/llvm-project/pull/113532
More information about the llvm-commits
mailing list