[Mlir-commits] [mlir] [MLIR][NVVM][Refactor] Refactor intrinsic lowering for NVVM Ops (PR #157079)
Srinivasa Ravi
llvmlistbot at llvm.org
Tue Sep 16 10:15:02 PDT 2025
================
@@ -2054,25 +2137,628 @@ PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
switch (addressSpace) {
case MemSpace::kGenericMemorySpace:
return *cacheLevel == CacheLevel::L1
- ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
- : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
+ ? NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_L1, args,
+ {})
+ : NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_L2, args,
+ {});
case MemSpace::kGlobalMemorySpace:
return *cacheLevel == CacheLevel::L1
- ? NVVM::IDArgPair(
- {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
- : NVVM::IDArgPair(
- {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
+ ? NVVM::IIDArgsWithTypes(
+ llvm::Intrinsic::nvvm_prefetch_global_L1, args, {})
+ : NVVM::IIDArgsWithTypes(
+ llvm::Intrinsic::nvvm_prefetch_global_L2, args, {});
case MemSpace::kLocalMemorySpace:
return *cacheLevel == CacheLevel::L1
- ? NVVM::IDArgPair(
- {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
- : NVVM::IDArgPair(
- {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
+ ? NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_local_L1,
+ args, {})
+ : NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_local_L2,
+ args, {});
default:
llvm_unreachable("Invalid pointer address space");
}
}
+#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
+ hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
+ : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
+
+#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
+ hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)
+
+NVVM::IIDArgsWithTypes
+ReduxOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::ReduxOp>(op);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getVal()));
+ args.push_back(mt.lookupValue(thisOp.getMaskAndClamp()));
+
+ bool hasAbs = thisOp.getAbs();
+ bool hasNaN = thisOp.getNan();
+ NVVM::ReduxKind kind = thisOp.getKind();
+
+ llvm::Intrinsic::ID id;
+
+ switch (kind) {
+ case NVVM::ReduxKind::ADD:
+ id = llvm::Intrinsic::nvvm_redux_sync_add;
+ break;
+ case NVVM::ReduxKind::UMAX:
+ id = llvm::Intrinsic::nvvm_redux_sync_umax;
+ break;
+ case NVVM::ReduxKind::UMIN:
+ id = llvm::Intrinsic::nvvm_redux_sync_umin;
+ break;
+ case NVVM::ReduxKind::AND:
+ id = llvm::Intrinsic::nvvm_redux_sync_and;
+ break;
+ case NVVM::ReduxKind::OR:
+ id = llvm::Intrinsic::nvvm_redux_sync_or;
+ break;
+ case NVVM::ReduxKind::XOR:
+ id = llvm::Intrinsic::nvvm_redux_sync_xor;
+ break;
+ case NVVM::ReduxKind::MAX:
+ id = llvm::Intrinsic::nvvm_redux_sync_max;
+ break;
+ case NVVM::ReduxKind::MIN:
+ id = llvm::Intrinsic::nvvm_redux_sync_min;
+ break;
+ case NVVM::ReduxKind::FMIN:
+ id = GET_REDUX_F32_ID(min, hasAbs, hasNaN);
+ break;
+ case NVVM::ReduxKind::FMAX:
+ id = GET_REDUX_F32_ID(max, hasAbs, hasNaN);
+ break;
+ }
+
+ return {id, std::move(args), {}};
+}
+
+NVVM::IIDArgsWithTypes
+ShflOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::ShflOp>(op);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getThreadMask()));
+ args.push_back(mt.lookupValue(thisOp.getVal()));
+ args.push_back(mt.lookupValue(thisOp.getOffset()));
+ args.push_back(mt.lookupValue(thisOp.getMaskAndClamp()));
+
+ mlir::Type resultType = thisOp.getResult().getType();
+ NVVM::ShflKind kind = thisOp.getKind();
+ bool withPredicate = static_cast<bool>(thisOp.getReturnValueAndIsValid());
+
+ llvm::Intrinsic::ID id;
+
+ if (withPredicate) {
+ resultType = cast<LLVM::LLVMStructType>(resultType).getBody()[0];
+ switch (kind) {
+ case NVVM::ShflKind::bfly:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+ break;
+ case NVVM::ShflKind::up:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
+ break;
+ case NVVM::ShflKind::down:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
+ break;
+ case NVVM::ShflKind::idx:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
+ break;
+ }
+ } else {
+ switch (kind) {
+ case NVVM::ShflKind::bfly:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+ break;
+ case NVVM::ShflKind::up:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
+ break;
+ case NVVM::ShflKind::down:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
+ break;
+ case NVVM::ShflKind::idx:
+ id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
+ break;
+ }
+ }
+
+ return {id, std::move(args), {}};
+}
+
+NVVM::IIDArgsWithTypes
+MatchSyncOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MatchSyncOp>(op);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getThreadMask()));
+ args.push_back(mt.lookupValue(thisOp.getVal()));
+
+ llvm::Intrinsic::ID id;
+
+ mlir::Type valType = thisOp.getVal().getType();
+ NVVM::MatchSyncKind kind = thisOp.getKind();
+
+ switch (kind) {
+ case NVVM::MatchSyncKind::any:
+ id = valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
+ : llvm::Intrinsic::nvvm_match_any_sync_i64;
+ break;
+ case NVVM::MatchSyncKind::all:
+    // The match.all instruction has two variants -- one returns a single
+    // value, the other returns a pair {value, predicate}. We currently only
+    // implement the latter, as that's the variant exposed by the CUDA API.
+ id = valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
+ : llvm::Intrinsic::nvvm_match_all_sync_i64p;
+ break;
----------------
Wolfram70 wrote:
Actually, this results in a warning (treated as an error) because a `default` label is used when all possible values of the enum are already covered by cases.
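For reference, a minimal sketch of the diagnostic (hypothetical `Kind` enum and function names, not from this PR): Clang's `-Wcovered-switch-default` flags a `default` label in a switch that already handles every enumerator, and `-Werror` builds promote it to an error. Dropping the `default`, as the hunk above does for `ReduxOp`, keeps `-Wswitch` able to catch any enumerator added later:

```cpp
#include "llvm/Support/ErrorHandling.h" // for llvm_unreachable

enum class Kind { A, B }; // hypothetical enum, for illustration only

int lowerWithDefault(Kind k) {
  switch (k) {
  case Kind::A:
    return 0;
  case Kind::B:
    return 1;
  default: // warning: default label in switch which covers all enumeration values
    return -1;
  }
}

int lowerCovered(Kind k) {
  switch (k) { // no default: -Wswitch flags any enumerator added later
  case Kind::A:
    return 0;
  case Kind::B:
    return 1;
  }
  llvm_unreachable("unknown Kind"); // avoids "control reaches end of non-void"
}
```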
https://github.com/llvm/llvm-project/pull/157079
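As a side note on the new return type: a minimal sketch (an assumption on my part, not shown in this hunk) of how the `IIDArgsWithTypes` triple could be consumed when emitting the intrinsic call, assuming it is an aggregate of {intrinsic ID, argument values, overload types} as the brace-initialized returns above suggest:

```cpp
// Hedged sketch; the member order {id, args, types} is inferred from the
// `return {id, std::move(args), {}};` statements in the hunk above.
auto [id, args, types] = NVVM::ReduxOp::getIIDAndArgsWithTypes(op, mt, builder);
// CreateIntrinsic resolves the overloaded signature from `types` and emits
// the call with `args`.
llvm::CallInst *call = builder.CreateIntrinsic(id, types, args);
```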