[Mlir-commits] [mlir] [MLIR][NVVM] Update dot.accumulate NVVM Ops (PR #140518)
Durgadoss R
llvmlistbot at llvm.org
Tue May 20 02:39:20 PDT 2025
================
@@ -3508,6 +3513,84 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
}];
}
+def DotAccumulate2WayModeLo : I32EnumAttrCase<"LO", 0, "lo">;
+def DotAccumulate2WayModeHi : I32EnumAttrCase<"HI", 1, "hi">;
+
+def DotAccumulate2WayMode : I32EnumAttr<"DotAccumulate2WayMode",
+ "NVVM DotAccumulate2WayMode",
+ [DotAccumulate2WayModeLo, DotAccumulate2WayModeHi]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def DotAccumulate2WayModeAttr : EnumAttr<NVVM_Dialect, DotAccumulate2WayMode, "dot_accumulate_2way_mode"> {
+ let assemblyFormat = "$value";
+}
+
+def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
+ let description = [{
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
+ 32-bit result.
+ Operand `a` is a vector of two 16-bit elements and operand `b` is a vector
+ of four 8-bit elements; the dot product is computed between them.
+
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
+ and `b` respectively.
+ If `a_type` is `s16`, then the elements in `a` are sign-extended to
+ 32 bits before the dot product is computed.
+ If `a_type` is `u16`, then the elements in `a` are zero-extended to
+ 32 bits instead.
+ If `b_type` is `s8`, then the elements in `b` are sign-extended to
+ 32 bits before the dot product is computed.
+ If `b_type` is `u8`, then the elements in `b` are zero-extended to
+ 32 bits instead.
+
+ The `mode` attribute specifies which two bytes of `b` are used for the dot
+ product. If `mode` is `lo`, then the dot product is computed between `a`
+ and elements at indices 0 and 1 of `b`. If `mode` is `hi`, then the dot
+ product is computed between `a` and elements at indices 2 and 3 of `b`.
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
+ treated as holding a signed integer if either `a_type` or `b_type` is
+ signed.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
+ }];
+
+ let hasVerifier = 1;
+
+ let arguments = (ins
+ DotAccumulate2WayModeAttr:$mode,
+ VectorOfLengthAndType<[2], [I16]>:$a,
+ DotAccumulateTypeAttr:$a_type,
+ VectorOfLengthAndType<[4], [I8]>:$b,
+ DotAccumulateTypeAttr:$b_type,
+ I32:$c
+ );
+
+ let results = (outs I32:$res);
+
+ let assemblyFormat = "$mode $a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID
+ getIntrinsicID(NVVM::DotAccumulateType a_type,
+ NVVM::DotAccumulateType b_type);
+ llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+ llvm::Value* isHi(NVVM::DotAccumulate2WayMode mode,
+ llvm::IRBuilderBase& builder);
+ }];
+
+ string llvmBuilder = [{
+ llvm::Intrinsic::ID id = NVVM::DotAccumulate2WayOp::getIntrinsicID($a_type, $b_type);
----------------
durga4github wrote:
We can use `getIntrinsicIDAndArgs()` here to simplify the llvmBuilder.
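The suggestion is illustrated here with a plain-C++ sketch of the pattern. A single helper returns both the intrinsic to call and its fully prepared argument list, so the `llvmBuilder` body only has to emit the call. Everything below is a hypothetical stand-in and not the MLIR API. The enums mirror `DotAccumulateType` and `DotAccumulate2WayMode`. Plain `uint32_t` values replace `llvm::Value*`, and the intrinsic is identified by name rather than by `llvm::Intrinsic::ID`. The sketch also assumes the intrinsics are named `llvm.nvvm.idp2a.<a>.<b>` and take `(a, b, is_hi, c)`. The two packing helpers play the role of `getPackedArg`, and the flag plays the role of `isHi`.

```cpp
#include <array>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for NVVM::DotAccumulateType and
// NVVM::DotAccumulate2WayMode; these are not the real MLIR types.
enum class DotAccumulateType { Signed, Unsigned };
enum class DotAccumulate2WayMode { Lo = 0, Hi = 1 };

// Hypothetical result bundle: intrinsic name plus its argument list,
// with plain integers standing in for llvm::Value*.
struct IntrinsicCall {
  std::string name;            // assumed llvm.nvvm.idp2a.<a>.<b> naming
  std::vector<uint32_t> args;  // packed a, packed b, is_hi, c
};

// Pack two 16-bit lanes into one 32-bit word, lane 0 in the low bits
// (the layout a <2 x i16> -> i32 bitcast gives on a little-endian target).
uint32_t packA(std::array<uint16_t, 2> a) {
  return uint32_t(a[0]) | (uint32_t(a[1]) << 16);
}

// Pack four 8-bit lanes into one 32-bit word, lane 0 in the low byte.
uint32_t packB(std::array<uint8_t, 4> b) {
  return uint32_t(b[0]) | (uint32_t(b[1]) << 8) | (uint32_t(b[2]) << 16) |
         (uint32_t(b[3]) << 24);
}

// One entry point that selects the intrinsic from the signedness
// attributes and assembles its arguments in call order.
IntrinsicCall getIntrinsicIDAndArgs(DotAccumulate2WayMode mode,
                                    std::array<uint16_t, 2> a,
                                    DotAccumulateType aType,
                                    std::array<uint8_t, 4> b,
                                    DotAccumulateType bType, uint32_t c) {
  auto suffix = [](DotAccumulateType t) {
    return t == DotAccumulateType::Signed ? "s" : "u";
  };
  std::string name =
      std::string("llvm.nvvm.idp2a.") + suffix(aType) + "." + suffix(bType);
  uint32_t isHi = (mode == DotAccumulate2WayMode::Hi) ? 1u : 0u;
  return {name, {packA(a), packB(b), isHi, c}};
}
```

With this shape, the per-op builder reduces to one lookup and one call. The packing and mode handling stay in a single place instead of being spread across three helpers.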
https://github.com/llvm/llvm-project/pull/140518
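A small reference model of the two-way dot product is handy for checking the lowering. The sketch below follows the `dp2a` pseudocode in the PTX ISA, where `.lo` selects bytes 0 and 1 of `b` and `.hi` selects bytes 2 and 3. The function `dp2aReference` and its enums are hypothetical names used only for illustration.

Lanes are passed unpacked and extended according to their signedness. The products are then accumulated into `c` with 32-bit wrap-around. The result is returned as raw 32 bits, to be read as signed whenever either input type is signed.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for the signedness and mode attributes.
enum class ElemType { Signed, Unsigned };
enum class Mode { Lo, Hi };

// Hypothetical reference model of PTX dp2a (not part of MLIR or LLVM).
// a: two 16-bit lanes, b: four 8-bit lanes, c: 32-bit accumulator bits.
uint32_t dp2aReference(std::array<uint16_t, 2> a, ElemType aType,
                       std::array<uint8_t, 4> b, ElemType bType, Mode mode,
                       uint32_t c) {
  // Sign- or zero-extend each lane according to its declared type.
  auto extA = [aType](uint16_t v) -> int64_t {
    return aType == ElemType::Signed ? int64_t(int16_t(v)) : int64_t(v);
  };
  auto extB = [bType](uint8_t v) -> int64_t {
    return bType == ElemType::Signed ? int64_t(int8_t(v)) : int64_t(v);
  };

  // .lo uses bytes 0 and 1 of b; .hi uses bytes 2 and 3.
  const std::size_t bSelect = (mode == Mode::Lo) ? 0 : 2;

  // Accumulate modulo 2^32, matching a 32-bit destination register.
  uint32_t d = c;
  for (std::size_t i = 0; i < 2; ++i)
    d += static_cast<uint32_t>(extA(a[i]) * extB(b[bSelect + i]));
  return d;
}
```

For example, take `a = {3, -2}` (signed) and `b = {1, 2, 3, 4}` (unsigned) with `c = 10`. The `lo` mode gives 10 + 3*1 + (-2)*2 = 9, while `hi` gives 10 + 3*3 + (-2)*4 = 11.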