[Mlir-commits] [mlir] [mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (PR #78510)
Dimple Prajapati
llvmlistbot at llvm.org
Tue Sep 24 09:36:56 PDT 2024
https://github.com/drprajap updated https://github.com/llvm/llvm-project/pull/78510
>From 97971903d54e3cc7c77eae3a69bd227ea1f38f86 Mon Sep 17 00:00:00 2001
From: "Prajapati, Dimpalben R" <dimpalben.r.prajapati at intel.com>
Date: Mon, 11 Dec 2023 22:06:22 +0000
Subject: [PATCH 1/5] [mlir][spirv] Add gpu printf op lowering to
spirv.CL.printf op
This change contains the following:
- Adds lowering of the gpu printf op to the spirv.CL.printf op in the GPUToSPIRV pass.
- Fixes Constant decoration parsing for the spirv GlobalVariable.
- Makes a minor modification to the spirv.CL.printf op assembly format.
---
.../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 4 +-
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 114 +++++++++++++++++-
.../SPIRV/Deserialization/Deserializer.cpp | 1 +
.../Target/SPIRV/Serialization/Serializer.cpp | 1 +
mlir/test/Conversion/GPUToSPIRV/printf.mlir | 71 +++++++++++
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 4 +-
6 files changed, 190 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Conversion/GPUToSPIRV/printf.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index c7c2fe8bc742c1..b5ca27d7d75316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
#### Example:
```mlir
- %0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+ %0 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
```
}];
@@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
);
let assemblyFormat = [{
- $format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
+ $format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
}];
let hasVerifier = 0;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index a8ff9247e796ab..e1902868c21184 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};
+class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
+public:
+ using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -597,6 +606,108 @@ class GPUSubgroupReduceConversion final
}
};
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+
+ auto loc = gpuPrintfOp.getLoc();
+
+ auto funcOp =
+ rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
+
+ auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
+
+ const char formatStringPrefix[] = "printfMsg";
+ unsigned stringNumber = 0;
+ mlir::SmallString<16> globalVarName;
+ mlir::spirv::GlobalVariableOp globalVar;
+
+ // formulate spirv global variable name
+ do {
+ globalVarName.clear();
+ (formatStringPrefix + llvm::Twine(stringNumber++))
+ .toStringRef(globalVarName);
+ } while (moduleOp.lookupSymbol(globalVarName));
+
+ auto i8Type = rewriter.getI8Type();
+ auto i32Type = rewriter.getI32Type();
+
+ unsigned scNum = 0;
+ auto createSpecConstant = [&](unsigned value) {
+ auto attr = rewriter.getI8IntegerAttr(value);
+ mlir::SmallString<16> specCstName;
+ (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
+ .toStringRef(specCstName);
+
+ return rewriter.create<mlir::spirv::SpecConstantOp>(
+ loc, rewriter.getStringAttr(specCstName), attr);
+ };
+
+ // define GlobalVarOp with printf format string using SpecConstants
+ // and make composite of SpecConstants
+ {
+ mlir::Operation *parent =
+ mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+
+ mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+
+ mlir::Block &entryBlock = *parent->getRegion(0).begin();
+ rewriter.setInsertionPointToStart(
+ &entryBlock); // insertion point at module level
+
+ // Create Constituents with SpecConstant to construct
+ // SpecConstantCompositeOp
+ llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
+ formatString.push_back('\0'); // Null terminate for C
+ mlir::SmallVector<mlir::Attribute, 4> constituents;
+ for (auto c : formatString) {
+ auto cSpecConstantOp = createSpecConstant(c);
+ constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
+ }
+
+ // Create specialization constant composite defined via spirv.SpecConstant
+ size_t contentSize = constituents.size();
+ auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
+ mlir::spirv::SpecConstantCompositeOp specCstComposite;
+ mlir::SmallString<16> specCstCompositeName;
+ (llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
+ specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
+ loc, mlir::TypeAttr::get(globalType),
+ rewriter.getStringAttr(specCstCompositeName),
+ rewriter.getArrayAttr(constituents));
+
+ // Define GlobalVariable initialized from Constant Composite
+ globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
+ loc,
+ mlir::spirv::PointerType::get(
+ globalType, mlir::spirv::StorageClass::UniformConstant),
+ globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
+
+ globalVar->setAttr("Constant", rewriter.getUnitAttr());
+ }
+
+ // Get SSA value of Global variable
+ mlir::Value globalPtr =
+ rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
+ mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
+ loc,
+ mlir::spirv::PointerType::get(i8Type,
+ mlir::spirv::StorageClass::UniformConstant),
+ globalPtr);
+
+ // Get printf arguments
+ auto argsRange = adaptor.getArgs();
+ mlir::SmallVector<mlir::Value, 4> printfArgs;
+ printfArgs.reserve(argsRange.size() + 1);
+ printfArgs.append(argsRange.begin(), argsRange.end());
+
+ rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+
+ rewriter.eraseOp(gpuPrintfOp);
+
+ return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
@@ -620,5 +731,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
WorkGroupSizeConversion, GPUAllReduceConversion,
- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 38293f7106a05a..6c7fe41069824f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -319,6 +319,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
+ case spirv::Decoration::Constant:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b0feda0517caa6..f4e6f677e2fb70 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -286,6 +286,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
+ case spirv::Decoration::Constant:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
new file mode 100644
index 00000000000000..1b951e7dad5e8d
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+ func.func @main() {
+ %c1 = arith.constant 1 : index
+
+ gpu.launch_func @kernels::@printf
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args()
+ return
+ }
+
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+ // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+ // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ gpu.func @printf() kernel
+ attributes
+ {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+ // CHECK-NEXT {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : (!spirv.ptr<i8, UniformConstant>) -> i32
+ gpu.printf "\nHello\n"
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+ func.func @main() {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100: i32
+ %cst_f32 = arith.constant 314.4: f32
+
+ gpu.launch_func @kernels1::@printf_args
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args(%c100: i32, %cst_f32: f32)
+ return
+ }
+
+ gpu.module @kernels1 {
+ // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+ // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+ // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ %0 = gpu.block_id x
+ %1 = gpu.block_id y
+ %2 = gpu.thread_id x
+
+ // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+ // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]], {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, f32, i32)) -> i32
+ gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
+
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 81ba471d3f51e3..171087a167850f 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -275,8 +275,8 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @printf(
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
- // CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
- %0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+ // CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
+ %0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
return %0 : i32
}
>From 9408b46857ab4bc4274735c33675f1404f7616d3 Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Wed, 17 Jan 2024 22:19:26 +0000
Subject: [PATCH 2/5] clang format fix
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index e1902868c21184..b6abc8aa31e534 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -642,7 +642,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
return rewriter.create<mlir::spirv::SpecConstantOp>(
loc, rewriter.getStringAttr(specCstName), attr);
};
-
+
// define GlobalVarOp with printf format string using SpecConstants
// and make composite of SpecConstants
{
@@ -685,7 +685,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
-
+
// Get SSA value of Global variable
mlir::Value globalPtr =
rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
>From 7b1f78eece4918cc520f2c47ca2c8ce855c9974c Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Wed, 17 Jan 2024 22:33:03 +0000
Subject: [PATCH 3/5] Fix test
---
mlir/test/Conversion/GPUToSPIRV/printf.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
index 1b951e7dad5e8d..4c77195f916014 100644
--- a/mlir/test/Conversion/GPUToSPIRV/printf.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
@@ -17,7 +17,7 @@ module attributes {
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
- // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
gpu.func @printf() kernel
attributes
{spirv.entry_point_abi = #spirv.entry_point_abi<>} {
@@ -52,7 +52,7 @@ module attributes {
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
- // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
@@ -61,7 +61,7 @@ module attributes {
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
- // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]], {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, f32, i32)) -> i32
+ // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}}, {{%.*}} : i32, f32, i32) -> i32
gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
// CHECK: spirv.Return
>From b0e9c5f05c0036b78c37c89cd37c050abdb09dec Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Wed, 18 Sep 2024 01:02:35 +0000
Subject: [PATCH 4/5] address feedback
---
.../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 6 +-
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 106 ++++++++++--------
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 6 +-
3 files changed, 66 insertions(+), 52 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index b5ca27d7d75316..4771153dbd0274 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -875,7 +875,9 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
#### Example:
```mlir
- %0 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
+ %0 = spirv.CL.printf %fmt %1, %2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
+
+ $format `,` ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
```
}];
@@ -889,7 +891,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
);
let assemblyFormat = [{
- $format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
+ $format ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
}];
let hasVerifier = 0;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b6abc8aa31e534..b10b5fd2c80f1c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -123,7 +123,7 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
- using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
@@ -606,106 +606,118 @@ class GPUSubgroupReduceConversion final
}
};
+/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
+
LogicalResult GPUPrintfConversion::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto loc = gpuPrintfOp.getLoc();
+ Location loc = gpuPrintfOp.getLoc();
- auto funcOp =
- rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
+ auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
- auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
+ if (!moduleOp) {
+ return success();
+ }
const char formatStringPrefix[] = "printfMsg";
unsigned stringNumber = 0;
- mlir::SmallString<16> globalVarName;
- mlir::spirv::GlobalVariableOp globalVar;
-
- // formulate spirv global variable name
+ SmallString<16> globalVarName;
+ spirv::GlobalVariableOp globalVar;
+
+ // SPIR-V global variable is used to initialize printf
+ // format string value, if there are multiple printf messages,
+ // each global var needs to be created with a unique name.
+ // like printfMsg0, printfMsg1, ...
+ // Formulate unique global variable name after
+ // searching in the module for existing global variable names.
+ // This is to avoid name collision with existing global variables.
do {
globalVarName.clear();
(formatStringPrefix + llvm::Twine(stringNumber++))
.toStringRef(globalVarName);
} while (moduleOp.lookupSymbol(globalVarName));
- auto i8Type = rewriter.getI8Type();
- auto i32Type = rewriter.getI32Type();
+ IntegerType i8Type = rewriter.getI8Type();
+ IntegerType i32Type = rewriter.getI32Type();
- unsigned scNum = 0;
+ // Each character of printf format string is
+ // stored as a spec constant. We need to create
+ // unique name for this spec constant like
+ // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
+ // for existing spec constant names.
+ unsigned specConstantNum = 0;
auto createSpecConstant = [&](unsigned value) {
auto attr = rewriter.getI8IntegerAttr(value);
- mlir::SmallString<16> specCstName;
- (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
+ SmallString<16> specCstName;
+ (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(specConstantNum++))
.toStringRef(specCstName);
- return rewriter.create<mlir::spirv::SpecConstantOp>(
+ return rewriter.create<spirv::SpecConstantOp>(
loc, rewriter.getStringAttr(specCstName), attr);
};
-
- // define GlobalVarOp with printf format string using SpecConstants
- // and make composite of SpecConstants
{
- mlir::Operation *parent =
- mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+ Operation *parent =
+ SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
- mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
- mlir::Block &entryBlock = *parent->getRegion(0).begin();
+ Block &entryBlock = *parent->getRegion(0).begin();
rewriter.setInsertionPointToStart(
&entryBlock); // insertion point at module level
- // Create Constituents with SpecConstant to construct
+ // Create Constituents with SpecConstant by scanning format string
+ // Each character of format string is stored as a spec constant
+ // and then these spec constants are used to create a
// SpecConstantCompositeOp
- llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
+ llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C
- mlir::SmallVector<mlir::Attribute, 4> constituents;
+ SmallVector<Attribute, 4> constituents;
for (auto c : formatString) {
auto cSpecConstantOp = createSpecConstant(c);
- constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
+ constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
}
- // Create specialization constant composite defined via spirv.SpecConstant
+ // Create SpecConstantCompositeOp to initialize the global variable
size_t contentSize = constituents.size();
- auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
- mlir::spirv::SpecConstantCompositeOp specCstComposite;
- mlir::SmallString<16> specCstCompositeName;
+ auto globalType = spirv::ArrayType::get(i8Type, contentSize);
+ spirv::SpecConstantCompositeOp specCstComposite;
+ SmallString<16> specCstCompositeName;
(llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
- specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
- loc, mlir::TypeAttr::get(globalType),
+ specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
+ loc, TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));
- // Define GlobalVariable initialized from Constant Composite
- globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
- loc,
- mlir::spirv::PointerType::get(
- globalType, mlir::spirv::StorageClass::UniformConstant),
- globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
+ auto ptrType = spirv::PointerType::get(
+ globalType, spirv::StorageClass::UniformConstant);
+
+ // Define a GlobalVarOp initialized using specialized constants
+ // that is used to specify the printf format string
+ // to be passed to the SPIRV CLPrintfOp.
+ globalVar = rewriter.create<spirv::GlobalVariableOp>(
+ loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
-
// Get SSA value of Global variable
- mlir::Value globalPtr =
- rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
- mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
+ Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
+ Value fmtStr = rewriter.create<spirv::BitcastOp>(
loc,
- mlir::spirv::PointerType::get(i8Type,
- mlir::spirv::StorageClass::UniformConstant),
+ spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
globalPtr);
// Get printf arguments
auto argsRange = adaptor.getArgs();
- mlir::SmallVector<mlir::Value, 4> printfArgs;
+ SmallVector<Value, 4> printfArgs;
printfArgs.reserve(argsRange.size() + 1);
printfArgs.append(argsRange.begin(), argsRange.end());
- rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+ rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
- return mlir::success();
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 171087a167850f..106f978afc8ccd 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -274,9 +274,9 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
// spirv.CL.printf
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @printf(
-func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
- // CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
- %0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
+func.func @printf(%fmt : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
+ // CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
+ %0 = spirv.CL.printf %fmt, %arg1, %arg2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
return %0 : i32
}
>From b7a5ea0845c533b1333a514c67223b683f099210 Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Tue, 24 Sep 2024 09:36:45 -0700
Subject: [PATCH 5/5] Update mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
review feedback
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b10b5fd2c80f1c..a8542a003e967e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -614,11 +614,9 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
Location loc = gpuPrintfOp.getLoc();
- auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
-
- if (!moduleOp) {
- return success();
- }
+ auto moduleOp = gpuPrintfOp.getParentOfType<spirv::ModuleOp>();
+ if (!moduleOp)
+ return failure();
const char formatStringPrefix[] = "printfMsg";
unsigned stringNumber = 0;
More information about the Mlir-commits
mailing list