[Mlir-commits] [mlir] [mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (PR #78510)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Sep 18 06:50:46 PDT 2024
================
@@ -597,6 +606,120 @@ class GPUSubgroupReduceConversion final
}
};
+/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
+
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+
+ Location loc = gpuPrintfOp.getLoc();
+
+ auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
+
+ if (!moduleOp) {
+ return success();
+ }
+
+ const char formatStringPrefix[] = "printfMsg";
+ unsigned stringNumber = 0;
+ SmallString<16> globalVarName;
+ spirv::GlobalVariableOp globalVar;
+
+ // SPIR-V global variable is used to initialize printf
+ // format string value, if there are multiple printf messages,
+ // each global var needs to be created with a unique name.
+ // like printfMsg0, printfMsg1, ...
+ // Formulate unique global variable name after
+ // searching in the module for existing global variable names.
+ // This is to avoid name collision with existing global variables.
+ do {
+ globalVarName.clear();
+ (formatStringPrefix + llvm::Twine(stringNumber++))
+ .toStringRef(globalVarName);
+ } while (moduleOp.lookupSymbol(globalVarName));
+
+ IntegerType i8Type = rewriter.getI8Type();
+ IntegerType i32Type = rewriter.getI32Type();
+
+ // Each character of printf format string is
+ // stored as a spec constant. We need to create
+ // unique name for this spec constant like
+ // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
+ // for existing spec constant names.
+ unsigned specConstantNum = 0;
+ auto createSpecConstant = [&](unsigned value) {
+ auto attr = rewriter.getI8IntegerAttr(value);
+ SmallString<16> specCstName;
+ (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(specConstantNum++))
+ .toStringRef(specCstName);
+
+ return rewriter.create<spirv::SpecConstantOp>(
+ loc, rewriter.getStringAttr(specCstName), attr);
+ };
+ {
+ Operation *parent =
+ SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+
+ Block &entryBlock = *parent->getRegion(0).begin();
+ rewriter.setInsertionPointToStart(
+ &entryBlock); // insertion point at module level
+
+ // Create Constituents with SpecConstant by scanning format string
+ // Each character of format string is stored as a spec constant
+ // and then these spec constants are used to create a
+ // SpecConstantCompositeOp
+ llvm::SmallString<20> formatString(adaptor.getFormat());
+ formatString.push_back('\0'); // Null terminate for C
+ SmallVector<Attribute, 4> constituents;
+ for (auto c : formatString) {
+ auto cSpecConstantOp = createSpecConstant(c);
+ constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
+ }
+
+ // Create SpecConstantCompositeOp to initialize the global variable
+ size_t contentSize = constituents.size();
+ auto globalType = spirv::ArrayType::get(i8Type, contentSize);
+ spirv::SpecConstantCompositeOp specCstComposite;
+ SmallString<16> specCstCompositeName;
+ (llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
+ specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
+ loc, TypeAttr::get(globalType),
+ rewriter.getStringAttr(specCstCompositeName),
+ rewriter.getArrayAttr(constituents));
+
+ auto ptrType = spirv::PointerType::get(
+ globalType, spirv::StorageClass::UniformConstant);
+
+ // Define a GlobalVarOp initialized using specialized constants
+ // that is used to specify the printf format string
+ // to be passed to the SPIRV CLPrintfOp.
+ globalVar = rewriter.create<spirv::GlobalVariableOp>(
+ loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
+
+ globalVar->setAttr("Constant", rewriter.getUnitAttr());
+ }
+ // Get SSA value of Global variable
+ Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
+ Value fmtStr = rewriter.create<spirv::BitcastOp>(
+ loc,
+ spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
+ globalPtr);
+
+ // Get printf arguments
+ auto argsRange = adaptor.getArgs();
+ SmallVector<Value, 4> printfArgs;
+ printfArgs.reserve(argsRange.size() + 1);
+ printfArgs.append(argsRange.begin(), argsRange.end());
+
+ rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+
+ rewriter.eraseOp(gpuPrintfOp);
----------------
kuhar wrote:
Could you add a comment explaining that we need this because the SPIR-V op produces a result that we need to discard?
https://github.com/llvm/llvm-project/pull/78510
More information about the Mlir-commits
mailing list