[Mlir-commits] [mlir] [MLIR][NVVM][Refactor] Refactor intrinsic lowering for NVVM Ops (PR #157079)

Rajat Bajpai llvmlistbot at llvm.org
Tue Sep 16 04:35:45 PDT 2025


================
@@ -2054,25 +2137,628 @@ PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
   switch (addressSpace) {
   case MemSpace::kGenericMemorySpace:
     return *cacheLevel == CacheLevel::L1
-               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
-               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
+               ? NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_L1, args,
+                                        {})
+               : NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_L2, args,
+                                        {});
   case MemSpace::kGlobalMemorySpace:
     return *cacheLevel == CacheLevel::L1
-               ? NVVM::IDArgPair(
-                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
-               : NVVM::IDArgPair(
-                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
+               ? NVVM::IIDArgsWithTypes(
+                     llvm::Intrinsic::nvvm_prefetch_global_L1, args, {})
+               : NVVM::IIDArgsWithTypes(
+                     llvm::Intrinsic::nvvm_prefetch_global_L2, args, {});
   case MemSpace::kLocalMemorySpace:
     return *cacheLevel == CacheLevel::L1
-               ? NVVM::IDArgPair(
-                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
-               : NVVM::IDArgPair(
-                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
+               ? NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_local_L1,
+                                        args, {})
+               : NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_local_L2,
+                                        args, {});
   default:
     llvm_unreachable("Invalid pointer address space");
   }
 }
 
+// Helpers that pick the f32 `redux.sync` intrinsic ID by token pasting:
+// `op` names the base operation (e.g. `min` / `max`), `abs` is either `_abs`
+// or empty, and the `_NaN` variant is chosen when `hasNaN` is true.
+// Arguments and the whole expansion are parenthesized so the macros stay
+// correct when used inside a larger expression, not only as the full
+// right-hand side of an assignment.
+#define REDUX_F32_ID_IMPL(op, abs, hasNaN)                                     \
+  ((hasNaN) ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN                \
+            : llvm::Intrinsic::nvvm_redux_sync_f##op##abs)
+
+#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)                                   \
+  ((hasAbs) ? REDUX_F32_ID_IMPL(op, _abs, hasNaN)                              \
+            : REDUX_F32_ID_IMPL(op, , hasNaN))
+
+/// Selects the `nvvm.redux.sync.*` intrinsic for an NVVM::ReduxOp and
+/// collects its LLVM operands.
+///
+/// \param op      Must be an NVVM::ReduxOp (it is cast unconditionally).
+/// \param mt      Translation state used to map MLIR values to LLVM values.
+/// \param builder Not used by this lowering.
+/// \returns The intrinsic ID, the operands {val, mask_and_clamp}, and an
+///          empty type list.
+NVVM::IIDArgsWithTypes
+ReduxOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+                                llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::ReduxOp>(op);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getVal()));
+  args.push_back(mt.lookupValue(thisOp.getMaskAndClamp()));
+
+  // The abs/NaN modifiers are only consulted for the FMIN/FMAX kinds, where
+  // they select among the f32 intrinsic variants.
+  bool hasAbs = thisOp.getAbs();
+  bool hasNaN = thisOp.getNan();
+  NVVM::ReduxKind kind = thisOp.getKind();
+
+  llvm::Intrinsic::ID id;
+
+  switch (kind) {
+  case NVVM::ReduxKind::ADD:
+    id = llvm::Intrinsic::nvvm_redux_sync_add;
+    break;
+  case NVVM::ReduxKind::UMAX:
+    id = llvm::Intrinsic::nvvm_redux_sync_umax;
+    break;
+  case NVVM::ReduxKind::UMIN:
+    id = llvm::Intrinsic::nvvm_redux_sync_umin;
+    break;
+  case NVVM::ReduxKind::AND:
+    id = llvm::Intrinsic::nvvm_redux_sync_and;
+    break;
+  case NVVM::ReduxKind::OR:
+    id = llvm::Intrinsic::nvvm_redux_sync_or;
+    break;
+  case NVVM::ReduxKind::XOR:
+    id = llvm::Intrinsic::nvvm_redux_sync_xor;
+    break;
+  case NVVM::ReduxKind::MAX:
+    id = llvm::Intrinsic::nvvm_redux_sync_max;
+    break;
+  case NVVM::ReduxKind::MIN:
+    id = llvm::Intrinsic::nvvm_redux_sync_min;
+    break;
+  case NVVM::ReduxKind::FMIN:
+    id = GET_REDUX_F32_ID(min, hasAbs, hasNaN);
+    break;
+  case NVVM::ReduxKind::FMAX:
+    id = GET_REDUX_F32_ID(max, hasAbs, hasNaN);
+    break;
+  default:
+    // Guards against out-of-range enum values and keeps `id` from being
+    // (or being diagnosed as) used uninitialized.
+    llvm_unreachable("Invalid redux kind");
+  }
+
+  return {id, std::move(args), {}};
+}
+
+/// Selects the `nvvm.shfl.sync.*` intrinsic for an NVVM::ShflOp and collects
+/// its LLVM operands in the order {thread_mask, val, offset, mask_and_clamp}.
+///
+/// The intrinsic is chosen by the shuffle kind, by whether the op returns a
+/// {value, predicate} struct (the `*p` variants), and by whether the
+/// shuffled value has a float type (`*_f32*`) or not (`*_i32*`).
+///
+/// \param op      Must be an NVVM::ShflOp (it is cast unconditionally).
+/// \param mt      Translation state used to map MLIR values to LLVM values.
+/// \param builder Not used by this lowering.
+/// \returns The intrinsic ID, the four operands above, and an empty type
+///          list.
+NVVM::IIDArgsWithTypes
+ShflOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+                               llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::ShflOp>(op);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getThreadMask()));
+  args.push_back(mt.lookupValue(thisOp.getVal()));
+  args.push_back(mt.lookupValue(thisOp.getOffset()));
+  args.push_back(mt.lookupValue(thisOp.getMaskAndClamp()));
+
+  mlir::Type resultType = thisOp.getResult().getType();
+  NVVM::ShflKind kind = thisOp.getKind();
+  bool withPredicate = static_cast<bool>(thisOp.getReturnValueAndIsValid());
+
+  // In the predicate form the result is a struct whose first member is the
+  // shuffled value, so classify on that member's type.
+  if (withPredicate)
+    resultType = cast<LLVM::LLVMStructType>(resultType).getBody()[0];
+  bool isFloat = resultType.isFloat();
+
+  llvm::Intrinsic::ID id;
+
+  if (withPredicate) {
+    switch (kind) {
+    case NVVM::ShflKind::bfly:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
+                   : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+      break;
+    case NVVM::ShflKind::up:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
+                   : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
+      break;
+    case NVVM::ShflKind::down:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
+                   : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
+      break;
+    case NVVM::ShflKind::idx:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
+                   : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
+      break;
+    default:
+      llvm_unreachable("Invalid shfl kind");
+    }
+  } else {
+    switch (kind) {
+    case NVVM::ShflKind::bfly:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
+                   : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+      break;
+    case NVVM::ShflKind::up:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
+                   : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
+      break;
+    case NVVM::ShflKind::down:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
+                   : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
+      break;
+    case NVVM::ShflKind::idx:
+      id = isFloat ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
+                   : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
+      break;
+    default:
+      llvm_unreachable("Invalid shfl kind");
+    }
+  }
+
+  return {id, std::move(args), {}};
+}
+
+NVVM::IIDArgsWithTypes
+MatchSyncOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+                                    llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::MatchSyncOp>(op);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getThreadMask()));
+  args.push_back(mt.lookupValue(thisOp.getVal()));
+
+  llvm::Intrinsic::ID id;
+
+  mlir::Type valType = thisOp.getVal().getType();
+  NVVM::MatchSyncKind kind = thisOp.getKind();
+
+  switch (kind) {
+  case NVVM::MatchSyncKind::any:
+    id = valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
+                               : llvm::Intrinsic::nvvm_match_any_sync_i64;
+    break;
+  case NVVM::MatchSyncKind::all:
+    // match.all instruction has two variants -- one returns a single value,
+    // another returns a pair {value, predicate}. We currently only implement
+    // the latter as that's the variant exposed by CUDA API.
+    id = valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
+                               : llvm::Intrinsic::nvvm_match_all_sync_i64p;
+    break;
----------------
rajatbajpai wrote:

Please add a default case with llvm_unreachable. See other places as well.

https://github.com/llvm/llvm-project/pull/157079


More information about the Mlir-commits mailing list