[Mlir-commits] [mlir] [mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (PR #78510)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 17 15:16:57 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Dimple Prajapati (drprajap)
<details>
<summary>Changes</summary>
This change contains the following:
- Adds lowering of `gpu.printf` to `spirv.CL.printf` in the GPUToSPIRV pass.
- Fixes serialization and deserialization of the `Constant` decoration on `spirv.GlobalVariable`.
- Makes a minor change to the `spirv.CL.printf` assembly format.
---
Full diff: https://github.com/llvm/llvm-project/pull/78510.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td (+2-2)
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+113-1)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+1)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+1)
- (added) mlir/test/Conversion/GPUToSPIRV/printf.mlir (+71)
- (modified) mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index c7c2fe8bc742c1..b5ca27d7d75316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
#### Example:
```mlir
- %0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+ %3 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
```
}];
@@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
);
let assemblyFormat = [{
- $format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
+ $format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
}];
let hasVerifier = 0;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d7885e0359592d..8d9f4554d8d799 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -135,6 +135,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Conversion pattern that lowers `gpu.printf` to `spirv.CL.printf`; the
+/// rewrite itself is implemented out of line in matchAndRewrite.
+class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
+public:
+  // Reuse the base-class constructors (populateGPUToSPIRVPatterns builds this
+  // pattern from a type converter and a context).
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Replaces `gpuPrintfOp` with SPIR-V ops; see the definition for details.
+  LogicalResult
+  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -607,6 +616,108 @@ class GPUSubgroupReduceConversion final
}
};
+/// Lowers `gpu.printf` to `spirv.CL.printf`.
+///
+/// The NUL-terminated format string is materialized at the start of the
+/// enclosing `spirv.module` as:
+///   * one i8 `spirv.SpecConstant` per byte,
+///   * a `spirv.SpecConstantComposite` of type `!spirv.array<N x i8>` built
+///     from those spec constants, and
+///   * a `spirv.GlobalVariable` in the UniformConstant storage class, marked
+///     `Constant` and initialized from that composite.
+/// At the printf site, the global's address is bitcast to
+/// `!spirv.ptr<i8, UniformConstant>` and passed, together with the
+/// type-converted printf operands, to `spirv.CL.printf`.
+///
+/// Fails to match when the op is not nested inside a `spirv.module`.
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  mlir::Location loc = gpuPrintfOp.getLoc();
+
+  // All generated symbols are inserted into, and uniqued against, the
+  // spirv.module enclosing the printf. Fail the match instead of
+  // dereferencing a null parent when there is none.
+  auto moduleOp = gpuPrintfOp->getParentOfType<mlir::spirv::ModuleOp>();
+  if (!moduleOp)
+    return rewriter.notifyMatchFailure(gpuPrintfOp,
+                                       "expected op inside a spirv.module");
+
+  // Returns `prefix` followed by the first counter value (starting at, and
+  // advancing, `counter`) that does not name an existing symbol in moduleOp.
+  auto makeUniqueName = [&](llvm::StringRef prefix, unsigned &counter) {
+    llvm::SmallString<32> name;
+    do {
+      name.clear();
+      (llvm::Twine(prefix) + llvm::Twine(counter++)).toVector(name);
+    } while (moduleOp.lookupSymbol(name));
+    return name;
+  };
+
+  // Global holding the format string: first free name among printfMsg0,
+  // printfMsg1, ...
+  unsigned msgNum = 0;
+  llvm::SmallString<32> globalVarName = makeUniqueName("printfMsg", msgNum);
+
+  auto i8Type = rewriter.getI8Type();
+  auto i32Type = rewriter.getI32Type();
+
+  // Spec constants are named <globalVarName>_sc0, _sc1, ... in creation
+  // order, skipping any name that is already taken in the module.
+  llvm::SmallString<32> specCstPrefix(globalVarName);
+  specCstPrefix += "_sc";
+  unsigned scNum = 0;
+  auto createSpecConstant = [&](char c) {
+    return rewriter.create<mlir::spirv::SpecConstantOp>(
+        loc, rewriter.getStringAttr(makeUniqueName(specCstPrefix, scNum)),
+        rewriter.getI8IntegerAttr(c));
+  };
+
+  mlir::spirv::GlobalVariableOp globalVar;
+  {
+    // Emit the module-level ops at the start of the module body; the guard
+    // restores the insertion point (the printf site) on scope exit.
+    mlir::OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
+
+    // Each byte of the NUL-terminated format string becomes a spec constant,
+    // referenced by symbol from the composite below.
+    llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
+    formatString.push_back('\0'); // Null terminate for C.
+    mlir::SmallVector<mlir::Attribute, 4> constituents;
+    constituents.reserve(formatString.size());
+    for (char c : formatString)
+      constituents.push_back(mlir::SymbolRefAttr::get(createSpecConstant(c)));
+
+    // Composite of the spec constants, typed !spirv.array<N x i8>. It is named
+    // <globalVarName>_scc unless that name is taken, in which case a numeric
+    // suffix is appended.
+    auto globalType = mlir::spirv::ArrayType::get(i8Type, constituents.size());
+    llvm::SmallString<32> compositeName(globalVarName);
+    compositeName += "_scc";
+    if (moduleOp.lookupSymbol(compositeName)) {
+      unsigned sccNum = 0;
+      llvm::SmallString<32> base(compositeName);
+      compositeName = makeUniqueName(base, sccNum);
+    }
+    auto specCstComposite =
+        rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
+            loc, mlir::TypeAttr::get(globalType),
+            rewriter.getStringAttr(compositeName),
+            rewriter.getArrayAttr(constituents));
+
+    // UniformConstant global initialized from the composite. The `Constant`
+    // unit attribute is intended to round-trip as the SPIR-V Constant
+    // decoration (handled in the (de)serializer).
+    globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
+        loc,
+        mlir::spirv::PointerType::get(
+            globalType, mlir::spirv::StorageClass::UniformConstant),
+        globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
+    globalVar->setAttr("Constant", rewriter.getUnitAttr());
+  }
+
+  // Take the global's address and reinterpret it as a pointer to its first
+  // byte, which is the format-operand type spirv.CL.printf is used with here.
+  mlir::Value globalPtr =
+      rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
+  mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
+      loc,
+      mlir::spirv::PointerType::get(i8Type,
+                                    mlir::spirv::StorageClass::UniformConstant),
+      globalPtr);
+
+  // Operands come from the adaptor, i.e. already type-converted (e.g. index
+  // becomes i32 under the 32-bit addressing used in the tests). gpu.printf has
+  // no results, so spirv.CL.printf's i32 result is left unused.
+  rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr,
+                                           adaptor.getArgs());
+  rewriter.eraseOp(gpuPrintfOp);
+  return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
@@ -630,5 +741,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
WorkGroupSizeConversion, GPUAllReduceConversion,
- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 02d03b3a0faeee..89a72260290e22 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -309,6 +309,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
+ case spirv::Decoration::Constant:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 40337e007bbf74..2252c339af0a75 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -272,6 +272,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
+ case spirv::Decoration::Constant:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
new file mode 100644
index 00000000000000..4c77195f916014
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// Lowering of gpu.printf without arguments: the format string becomes a
+// UniformConstant global and spirv.CL.printf receives only the bitcast pointer.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+
+    gpu.launch_func @kernels::@printf
+        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+        args()
+    return
+  }
+
+  gpu.module @kernels {
+    // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+    // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+    // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+    // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+    gpu.func @printf() kernel
+      attributes
+        {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+      // CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+      // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : !spirv.ptr<i8, UniformConstant> -> i32
+      gpu.printf "\nHello\n"
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+// Lowering of gpu.printf with arguments: the converted operands follow the
+// format-string pointer, and the index-typed thread id is expected as i32.
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+ func.func @main() {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100: i32
+ %cst_f32 = arith.constant 314.4: f32
+
+ gpu.launch_func @kernels1::@printf_args
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args(%c100: i32, %cst_f32: f32)
+ return
+ }
+
+ gpu.module @kernels1 {
+ // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+ // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+ // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ // %0 and %1 are not printed; only the thread id %2 is passed to gpu.printf.
+ %0 = gpu.block_id x
+ %1 = gpu.block_id y
+ %2 = gpu.thread_id x
+
+ // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+ // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}}, {{%.*}} : i32, f32, i32) -> i32
+ gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
+
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 81ba471d3f51e3..171087a167850f 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -275,8 +275,8 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @printf(
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
- // CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
- %0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+ // CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
+ %0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
return %0 : i32
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/78510
More information about the Mlir-commits
mailing list