[Mlir-commits] [mlir] [MLIR][NVVM] Add dot.accumulate.2way Op (PR #140518)
Durgadoss R
llvmlistbot at llvm.org
Tue Jun 3 04:07:20 PDT 2025
================
@@ -1627,26 +1620,65 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
-llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type) {
- bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
- bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
+/// Reinterprets `arg` as a single i32 value by emitting a bitcast to the
+/// context's 32-bit integer type through `builder`, and returns the result.
+/// NOTE(review): an LLVM bitcast needs the source type to be exactly 32 bits
+/// wide. This presumably expects packed operands such as the A/B inputs of
+/// DotAccumulate4WayOp (e.g. a vector of four i8). Confirm against the op's
+/// operand type constraints.
+static llvm::Value *getAsPackedI32(llvm::Value *arg,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateBitCast(arg,
+ llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
+ args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
+ args.push_back(mt.lookupValue(curOp.getC()));
+
+ bool is_a_siext = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+ bool is_b_siext = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
unsigned type = (is_a_siext << 1) | is_b_siext;
switch (type) {
case 0:
- return llvm::Intrinsic::nvvm_idp4a_u_u;
+ return {llvm::Intrinsic::nvvm_idp4a_u_u, args};
case 1:
- return llvm::Intrinsic::nvvm_idp4a_u_s;
+ return {llvm::Intrinsic::nvvm_idp4a_u_s, args};
case 2:
- return llvm::Intrinsic::nvvm_idp4a_s_u;
+ return {llvm::Intrinsic::nvvm_idp4a_s_u, args};
case 3:
- return llvm::Intrinsic::nvvm_idp4a_s_s;
+ return {llvm::Intrinsic::nvvm_idp4a_s_s, args};
default:
----------------
durga4github wrote:
OK, the latest array implementation looks neat.
https://github.com/llvm/llvm-project/pull/140518
More information about the Mlir-commits
mailing list