[llvm] [NFC][NVVM][NVPTX] Moved common code for tcgen05.mma to the base class (PR #176327)
Pradeep Kumar via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 18 22:07:02 PST 2026
================
@@ -6071,195 +6060,142 @@ foreach sp = [0, 1] in {
!eq(kind, "tf32")), [0, 1], [0]) in {
foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
def :
- Tcgen05MMADisableOutputLaneInst<sp, space, kind, cta_group,
- collector_usage, scale_input_d,
- ashift>;
- }
- }
- }
- }
- }
- }
-}
+ Tcgen05MMADisableOutputLaneInst<sparse, space, kind, cta_group,
+ collector_usage, scale_input_d,
+ ashift>;
+ } // ashift
+ } // scale_input_d
+ } // collector_usage
+ } // cta_group
+ } // kind
+ } // space
+} // sparse
-class Tcgen05MMABlockScaleInst<bit Sp, string ASpace, string KindStr,
- int CtaGroup, string ScaleVecSize, string CollectorUsageStr>:
- NVPTXInst<(outs), (ins), "?", []>, Requires<[]> {
+//
+// tcgen05.mma.block_scale Instructions
+//
+
+class Tcgen05MMABlockScaleInst<bit IsSparse, string ASpace, string Kind,
+ int CtaGroup, string ScaleVecSize,
+ string CollectorUsage>:
+ Tcgen05MMABase<IsSparse, ASpace, Kind, CtaGroup, CollectorUsage> {
let Predicates = !cond(
- !and(!eq(Sp, 1),
- !eq(KindStr, "mxf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4">],
- !and(!eq(Sp, 1),
- !eq(KindStr, "mxf4nvf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4nvf4">],
+ !and(IsSparse,
+ !eq(Kind, "mxf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4">],
+ !and(IsSparse,
+ !eq(Kind, "mxf4nvf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4nvf4">],
!ne(ScaleVecSize, "") : [callSubtarget<"hasTcgen05InstSupport">, hasPTX<88>],
true : [callSubtarget<"hasTcgen05InstSupport">]
);
Intrinsic Intrin = !cast<Intrinsic>(
- NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record_name);
-
- dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins));
- dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin i32:$spmetadata), (Intrin));
- string SparseMetadataStr = !if(!eq(Sp, 1), ", [$spmetadata]", "");
+ NVVM_TCGEN05_MMA_BLOCKSCALE<IsSparse, ASpace, Kind, ScaleVecSize>.record_name);
- int KindVal = !cond(
- !eq(KindStr, "mxf8f6f4") : 0,
- !eq(KindStr, "mxf4") : 1,
- !eq(KindStr, "mxf4nvf4") : 2,
+ let KindVal = !cond(
+ !eq(Kind, "mxf8f6f4") : 0,
+ !eq(Kind, "mxf4") : 1,
+ !eq(Kind, "mxf4nvf4") : 2,
);
- int CollectorUsage = !cond(
- !eq(CollectorUsageStr, "discard") : 0,
- !eq(CollectorUsageStr, "lastuse") : 1,
- !eq(CollectorUsageStr, "fill") : 2,
- !eq(CollectorUsageStr, "use") : 3,
- );
-
- string AOperandStr = !if(!eq(ASpace, "tensor"), "[$a]", "$a");
- NVPTXRegClass ARegClass = !if(!eq(ASpace, "tensor"), B32, B64);
+ let InOperandList = !con(BaseInOperandList,
+ (ins B32:$scale_a,
+ B32:$scale_b));
+ let AsmString = Prefix
+ # SpCtaKindStr
+ # ".block_scale" # ScaleVecSize
+ # ".collector::a::" # CollectorUsage
+ # BaseOperandsStr
+ # ", [$scale_a], [$scale_b]"
+ # InputDStr
+ # ";";
- dag input = !con((ins B32:$dtmem, ARegClass:$a, B64:$b,
- B32:$idesc, B1:$enable_inp_d),
- SparseMetadataIns,
- (ins B32:$scale_a,
- B32:$scale_b));
+ dag IntrinsicPattern = !con(!foreach(tmp, BasePatternArgs, !subst(ins, Intrin, tmp)),
+ (Intrin i32:$scale_a,
+ i32:$scale_b));
- let InOperandList = input;
- let OutOperandList = (outs);
- let AsmString = "tcgen05.mma"
- # !if(!eq(Sp, 1), ".sp", "")
- # ".cta_group::" # CtaGroup
- # ".kind::" # KindStr
- # ".block_scale" # ScaleVecSize
- # ".collector::a::" # CollectorUsageStr
- # " [$dtmem], " # AOperandStr # ", $b"
- # SparseMetadataStr
- # ", $idesc, [$scale_a], [$scale_b], $enable_inp_d;";
-
- dag IntrinsicPattern = !con((Intrin i32:$dtmem,
- ARegClass:$a, i64:$b,
- i32:$idesc,
- i1:$enable_inp_d),
- SparseMetadataIntr,
- (Intrin i32:$scale_a,
- i32:$scale_b));
-
- dag FlagOperands = (Intrin (i32 CtaGroup), (i32 CollectorUsage));
+ dag FlagOperands = (Intrin (i32 CtaGroup), (i32 CollectorUsageVal));
let Pattern = [!con(IntrinsicPattern, FlagOperands)];
}
// tcgen05.mma.block_scale
-foreach sp = [0, 1] in {
+foreach sparse = [0, 1] in {
foreach space = ["tensor", "shared"] in {
foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
foreach scale_vec_size = ["", ".block16", ".block32"] in {
foreach cta_group = [1, 2] in {
foreach collector_usage = ["fill", "use", "lastuse", "discard"] in {
if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
- def : Tcgen05MMABlockScaleInst<sp, space, kind, cta_group, scale_vec_size,
- collector_usage>;
+ def : Tcgen05MMABlockScaleInst<sparse, space, kind, cta_group,
+ scale_vec_size, collector_usage>;
}
- }
- }
- }
- }
- }
-}
+ } // collector_usage
+ } // cta_group
+ } // scale_vec_size
+ } // kind
+ } // space
+} // sparse
//
// tcgen05.mma.ws Instructions
//
-class Tcgen05MMAWSInst<bit Sp, string ASpace, string KindStr,
- int CollectorBufferB, string CollectorUsageOpStr,
- bit HasZeroColMask> :
- NVPTXInst<(outs), (ins), "?", []>, Requires<[]> {
+class Tcgen05MMAWSInst<bit IsSparse, string ASpace, string Kind,
+ int CollectorBufferB, string CollectorUsage,
+ bit IsZeroColMask> :
+ Tcgen05MMABase<IsSparse, ASpace, Kind, /*CtaGroup=*/ 1, CollectorUsage> {
let Predicates = !cond(
- !eq(KindStr, "i8") : [callSubtarget<"hasTcgen05MMAI8Kind">],
+ !eq(Kind, "i8") : [callSubtarget<"hasTcgen05MMAI8Kind">],
true : [callSubtarget<"hasTcgen05InstSupport">]
);
Intrinsic Intrin = !cast<Intrinsic>(
- NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record_name);
-
- dag ZeroColMaskIns = !if(!eq(HasZeroColMask, 1),
- (ins B64:$zero_col_mask), (ins));
- string ZeroColMaskStr = !if(!eq(HasZeroColMask, 1), ", $zero_col_mask", "");
- dag ZeroColMaskIntr = !if(!eq(HasZeroColMask, 1),
- (Intrin i64:$zero_col_mask), (Intrin));
-
- dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins));
- dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin B32:$spmetadata), (Intrin));
- string SparseMetadataStr = !if(!eq(Sp, 1), ", [$spmetadata]", "");
-
- int KindVal = !cond(
- !eq(KindStr, "f16") : 0,
- !eq(KindStr, "tf32") : 1,
- !eq(KindStr, "f8f6f4"): 2,
- !eq(KindStr, "i8") : 3,
- );
-
- int CollectorUsageOp = !cond(
- !eq(CollectorUsageOpStr, "discard"): 0,
- !eq(CollectorUsageOpStr, "lastuse"): 1,
- !eq(CollectorUsageOpStr, "fill") : 2,
- !eq(CollectorUsageOpStr, "use") : 3,
- );
-
- string AOperandStr = !if(!eq(ASpace, "tensor"), "[$a]", "$a");
- NVPTXRegClass ARegClass = !if(!eq(ASpace, "tensor"), B32, B64);
-
- dag input = !con((ins B32:$dtmem,
- ARegClass:$a, B64:$b,
- B32:$idesc,
- B1:$enable_inp_d),
- SparseMetadataIns,
- ZeroColMaskIns);
+ NVVM_TCGEN05_MMA_WS<IsSparse, ASpace, IsZeroColMask>.record_name);
+
+ dag ZeroColMaskIns = !if(IsZeroColMask, (ins B64:$zero_col_mask), (ins));
+ string ZeroColMaskStr = !if(IsZeroColMask, ", $zero_col_mask", "");
+ dag ZeroColMaskIntr = !if(IsZeroColMask,
+ (Intrin i64:$zero_col_mask), (Intrin));
+
+ let InOperandList = !con(BaseInOperandList,
+ ZeroColMaskIns);
+
+ let AsmString = Prefix
+ # ".ws"
+ # SpCtaKindStr
+ # ".collector::b" # CollectorBufferB
+ # "::" # CollectorUsage
+ # BaseOperandsStr
+ # InputDStr
+ # ZeroColMaskStr
+ # ";";
- let InOperandList = input;
- let OutOperandList = (outs);
- let AsmString = "tcgen05.mma.ws"
- # !if(!eq(Sp, 1), ".sp", "")
- # ".cta_group::1"
- # ".kind::" # KindStr
- # ".collector::b" # CollectorBufferB
- # "::" # CollectorUsageOpStr
- # " [$dtmem], " # AOperandStr # ", $b"
- # SparseMetadataStr
- # ", $idesc, $enable_inp_d"
- # ZeroColMaskStr
- # ";";
-
- dag IntrinsicPattern = !con((Intrin i32:$dtmem,
- ARegClass:$a, i64:$b,
- i32:$idesc,
- i1:$enable_inp_d),
- SparseMetadataIntr,
- ZeroColMaskIntr);
+ dag IntrinsicPattern = !con(!foreach(tmp, BasePatternArgs, !subst(ins, Intrin, tmp)),
+ ZeroColMaskIntr);
dag FlagOperands = (Intrin (i32 KindVal), (i32 CollectorBufferB),
- (i32 CollectorUsageOp));
+ (i32 CollectorUsageVal));
let Pattern = [!con(IntrinsicPattern, FlagOperands)];
}
// tcgen05.mma.ws
-foreach sp = [0, 1] in {
+foreach sparse = [0, 1] in {
foreach space = ["shared", "tensor"] in {
foreach kind = ["f16", "tf32", "f8f6f4", "i8"] in {
foreach collector_buffer_b = [0, 1, 2, 3] in {
foreach collector_usage_op = ["discard", "fill", "use", "lastuse"] in {
foreach zero_col_mask = [0, 1] in {
- def : Tcgen05MMAWSInst<sp, space, kind, collector_buffer_b,
+ def : Tcgen05MMAWSInst<sparse, space, kind, collector_buffer_b,
collector_usage_op, zero_col_mask>;
- }
- }
- }
- }
- }
-}
+ } // zero_col_mask
+ } // collector_usage_op
+ } // collector_buffer_b
+ } //kind
----------------
schwarzschild-radius wrote:
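Missing space here? `} //kind` should be `} // kind`, matching the other closing-brace comments in this hunk (`} // zero_col_mask`, `} // collector_usage_op`, ...).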
https://github.com/llvm/llvm-project/pull/176327
More information about the llvm-commits
mailing list