[Mlir-commits] [mlir] c945022 - [MLIR][NVVM] Support packed registers in `inline_ptx` (#154904)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 2 05:15:59 PDT 2025
Author: Guray Ozen
Date: 2025-09-02T14:15:55+02:00
New Revision: c945022f2fd8321559d84e2272005487c5ced924
URL: https://github.com/llvm/llvm-project/commit/c945022f2fd8321559d84e2272005487c5ced924
DIFF: https://github.com/llvm/llvm-project/commit/c945022f2fd8321559d84e2272005487c5ced924.diff
LOG: [MLIR][NVVM] Support packed registers in `inline_ptx` (#154904)
Add support for packed registers with vectors.
Example:
```
%wo0 = nvvm.inline_ptx
"dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
-> i32
```
Here, `vector<4xi8>` is lowered to an `i32` register (i.e., an `r` in
PTX).
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 21331e5aa89f3..cb2489335a317 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
namespace NVVM {
@@ -82,7 +83,8 @@ class PtxBuilder {
needsManualRegisterMapping(needsManualRegisterMapping) {}
/// Add an operand with the read/write input type.
- void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+/// Returns failure if a PTX register type cannot be deduced for `v` (or for
+/// one of the members when `v` is a struct); callers must propagate it.
+ LogicalResult insertValue(Value v,
+ PTXRegisterMod itype = PTXRegisterMod::Read);
/// Builds the inline assembly Op and returns it. The `insertValue` needs to
/// be called to pass operands before building the PTX.
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index c67ec3642f121..314cbed2e4f33 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -26,6 +26,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
@@ -62,7 +63,8 @@ struct PtxLowering
PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
LDBG() << asmValue << "\t Modifier : " << modifier;
- generator.insertValue(asmValue, modifier);
+ if (failed(generator.insertValue(asmValue, modifier)))
+ return failure();
}
generator.buildAndReplaceOp();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 6d2a64f94e3ca..7220e10ea84d3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
@@ -31,35 +38,88 @@ using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
-static char getRegisterType(Type type) {
- if (type.isInteger(1))
- return 'b';
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == kSharedMemorySpace) {
+/// Returns the inline-asm constraint letter for an operand of `type`.
+/// Scalars: i1 -> 'b', i16 -> 'h', i32 -> 'r', i64 -> 'l', f32 -> 'f',
+/// f64 -> 'd'; pointers -> 'r' in the shared address space, 'l' otherwise.
+/// Fixed-length vectors with a supported lane count / element type are packed
+/// into one wider scalar register. Any other type (including unsupported
+/// vector packings) emits an error at `loc` and returns failure.
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+ MLIRContext *ctx = type.getContext();
+ auto i16 = IntegerType::get(ctx, 16);
+ auto i32 = IntegerType::get(ctx, 32);
+ auto f32 = Float32Type::get(ctx);
+
+ // Maps a scalar (non-vector) type to its constraint letter. The diagnostic
+ // reports the outer `type`, so a rejected vector is shown as written.
+ auto getRegisterTypeForScalar = [&](Type scalar) -> FailureOr<char> {
+ if (scalar.isInteger(1))
+ return 'b';
+ if (scalar.isInteger(16))
+ return 'h';
+ if (scalar.isInteger(32))
return 'r';
+ if (scalar.isInteger(64))
+ return 'l';
+ if (scalar.isF32())
+ return 'f';
+ if (scalar.isF64())
+ return 'd';
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(scalar)) {
+ // Shared address spaces is addressed with 32-bit pointers.
+ if (ptr.getAddressSpace() == kSharedMemorySpace) {
+ return 'r';
+ }
+ return 'l';
+ }
+ // register type for struct is not supported.
+ mlir::emitError(
+ loc, "The register type could not be deduced from MLIR type. The ")
+ << type
+ << " is not supported. Supported types are: "
+ "i1, i16, i32, i64, f32, f64, "
+ "pointers, and packed vectors.\nPlease use llvm.bitcast if you "
+ "have a different type. "
+ "\nSee the constraints from here: "
+ "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+ "index.html#constraints";
+ return failure();
+ };
+
+ // Packed registers: a short fixed-length vector is passed in a single
+ // scalar register wide enough to hold all of its lanes.
+ if (auto v = dyn_cast<VectorType>(type)) {
+ // Scalable vectors have no fixed bit width, so no register can hold them.
+ if (v.isScalable()) {
+ mlir::emitError(loc, "Scalable vectors cannot be passed in a PTX "
+ "register: ")
+ << type;
+ return failure();
+ }
+
+ int64_t lanes = v.getNumElements();
+ Type elem = v.getElementType();
+
+ // Case 1. Single-element vector: use the element's register type.
+ if (lanes <= 1)
+ return getRegisterTypeForScalar(elem);
+
+ // Case 2. Packed registers. `widened` stays null for any combination not
+ // listed here; such vectors are rejected below instead of silently using
+ // the (too narrow) element register.
+ Type widened;
+ switch (lanes) {
+
+ case 2:
+ if (elem.isF16() || elem.isBF16()) // vector<2xf16>, vector<2xbf16>
+ widened = f32;
+ else if (elem.isFloat(8)) // vector<2x 8-bit float>
+ widened = i16;
+ break;
+ case 4:
+ if (elem.isInteger(8)) // vector<4xi8>
+ widened = i32;
+ else if (elem.isFloat(8)) // vector<4x 8-bit float>
+ widened = f32;
+ else if (elem.isFloat(4)) // vector<4x 4-bit float>
+ widened = i16;
+ break;
+ // Other packing is not supported; rejected after the switch.
+ default:
+ break;
}
- return 'l';
+ if (!widened) {
+ mlir::emitError(loc, "The packed register type could not be deduced "
+ "from MLIR type. The ")
+ << type
+ << " is not supported. Supported packed vectors are: "
+ "vector<2xf16>, vector<2xbf16>, 2 or 4 lanes of an 8-bit float, "
+ "vector<4xi8>, and 4 lanes of a 4-bit float.\nPlease use "
+ "llvm.bitcast if you have a different type.";
+ return failure();
+ }
+ return getRegisterTypeForScalar(widened);
}
- // register type for struct is not supported.
- llvm_unreachable("The register type could not deduced from MLIR type");
- return '?';
+
+ return getRegisterTypeForScalar(type);
}
-static char getRegisterType(Value v) {
+/// Returns the inline-asm constraint letter for value `v`. Values defined by
+/// an `LLVM::ConstantOp` use 'n'; every other value uses the letter for its
+/// type, and failure (with an error emitted at `loc`) if that type is not
+/// supported.
+static FailureOr<char> getRegisterType(Value v, Location loc) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
- return getRegisterType(v.getType());
+ return getRegisterType(v.getType(), loc);
}
/// Extract every element of a struct value.
@@ -75,10 +135,11 @@ static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
return elems;
}
-void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
LDBG() << v << "\t Modifier : " << itype << "\n";
registerModifiers.push_back(itype);
+ Location loc = interfaceOp->getLoc();
auto getModifier = [&]() -> const char * {
switch (itype) {
case PTXRegisterMod::Read:
@@ -111,21 +172,29 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
- Value extractValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), v, idx);
+ Value extractValue =
+ LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
- ss << getModifier() << getRegisterType(t) << ",";
+ FailureOr<char> regType = getRegisterType(t, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc,
+ "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
}
- return;
+ return success();
}
// Handle Scalars
addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
+ FailureOr<char> regType = getRegisterType(v, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
+ return success();
}
/// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 92930f9cbaa49..bf80d9a1668a1 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -745,3 +745,38 @@ llvm.func @nvvm_pmevent() {
nvvm.pmevent id = 4
llvm.return
}
+
+// -----
+
+// A vector<4xi8> read operand is packed into a single 32-bit ('r') register.
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>, %mask : i32, %zero: i32) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,r,r" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+ %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
+ ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
+ -> i32
+ llvm.return
+}
+
+// A vector<2xbf16> result is widened to f32 and therefore uses an 'f' register.
+llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};"
+ ro(%a, %b : f32, f32)
+ -> vector<2xbf16>
+ llvm.return
+}
+
+// cvt.rz.satfinite.ue8m0x2.bf16x2 reads a packed bf16x2 (32-bit) source and
+// writes a 16-bit ue8m0x2 destination: the vector<2xbf16> is the read operand
+// (widened to f32, 'f') and the i16 is the result ('h').
+llvm.func @inline_ptx_cvt_rz_ue8m0x2_bf16x2(%a : vector<2xbf16>) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1;", "=h,f" %{{.*}} : (vector<2xbf16>) -> i16
+ %wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0};"
+ ro(%a : vector<2xbf16>)
+ -> i16
+ llvm.return
+}
+
+// The read operand {$r0} is the mov source and the write operand {$w0} the
+// destination; the PTX scope opened with '{' is closed at the end.
+llvm.func @cvt_i8_bf16(%a : i8) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $1;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $0, r;\0A}", "=h,h" %{{.*}} : (i16) -> i16
+ %za = llvm.zext %a : i8 to i16
+ %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$r0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$w0}, r;\n}"
+ ro(%za : i16)
+ -> i16
+ llvm.return
+}
More information about the Mlir-commits
mailing list