[Mlir-commits] [mlir] [mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (PR #78510)

Jakub Kuderski llvmlistbot at llvm.org
Fri Sep 27 11:59:34 PDT 2024


================
@@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final
   }
 };
 
+/// Returns a symbol name of the form `<prefix><N>` that is not yet defined in
+/// `moduleOp`. N is the smallest non-negative integer for which
+/// `moduleOp.lookupSymbol` finds no existing symbol. This avoids name
+/// collisions with variables/constants already present in the module.
+/// Example: printfMsg0, printfMsg1, printfMsg2, ...
+///
+/// `prefix` is taken as `const llvm::Twine &`, per the Twine usage contract.
+/// Twines must not be stored or passed by value.
+///
+/// NOTE(review): every call restarts the search at 0. A caller that names K
+/// symbols with the same prefix (e.g. one spec constant per format-string
+/// character) therefore performs O(K^2) symbol lookups.
+static std::string makeVarName(spirv::ModuleOp moduleOp,
+                               const llvm::Twine &prefix) {
+  for (unsigned number = 0;; ++number) {
+    // The temporary Twine nodes live until the end of this full-expression,
+    // so materializing them with str() here is safe.
+    std::string name = (prefix + llvm::Twine(number)).str();
+    if (!moduleOp.lookupSymbol(name))
+      return name;
+  }
+}
+
+/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
+
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Location loc = gpuPrintfOp.getLoc();
+
+  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
+  if (!moduleOp)
+    return failure();
+
+  // SPIR-V global variable is used to initialize printf
+  // format string value, if there are multiple printf messages,
+  // each global var needs to be created with a unique name.
+  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
+  spirv::GlobalVariableOp globalVar;
+
+  IntegerType i8Type = rewriter.getI8Type();
+  IntegerType i32Type = rewriter.getI32Type();
+
+  // Each character of printf format string is
+  // stored as a spec constant. We need to create
+  // unique name for this spec constant like
+  // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
+  // for existing spec constant names.
+  auto createSpecConstant = [&](unsigned value) {
+    auto attr = rewriter.getI8IntegerAttr(value);
+    std::string specCstName =
+        makeVarName(moduleOp, (llvm::Twine(globalVarName) + "_sc"));
+
+    return rewriter.create<spirv::SpecConstantOp>(
+        loc, rewriter.getStringAttr(specCstName), attr);
+  };
+  {
+    Operation *parent =
+        SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+
+    Block &entryBlock = *parent->getRegion(0).begin();
+    rewriter.setInsertionPointToStart(
+        &entryBlock); // insertion point at module level
+
+    // Create Constituents with SpecConstant by scanning format string
+    // Each character of format string is stored as a spec constant
+    // and then these spec constants are used to create a
+    // SpecConstantCompositeOp.
+    llvm::SmallString<20> formatString(adaptor.getFormat());
+    formatString.push_back('\0'); // Null terminate for C.
+    SmallVector<Attribute, 4> constituents;
+    for (char c : formatString) {
+      spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
+      constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
+    }
+
+    // Create SpecConstantCompositeOp to initialize the global variable
+    size_t contentSize = constituents.size();
+    auto globalType = spirv::ArrayType::get(i8Type, contentSize);
+    spirv::SpecConstantCompositeOp specCstComposite;
+    // There will be one SpecConstantCompositeOp per printf message/global var,
+    // so no need do lookup for existing ones.
+    std::string specCstCompositeName =
+        (llvm::Twine(globalVarName) + "_scc").str();
+
+    specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
+        loc, TypeAttr::get(globalType),
+        rewriter.getStringAttr(specCstCompositeName),
+        rewriter.getArrayAttr(constituents));
+
+    auto ptrType = spirv::PointerType::get(
+        globalType, spirv::StorageClass::UniformConstant);
+
+    // Define a GlobalVarOp initialized using specialized constants
+    // that is used to specify the printf format string
+    // to be passed to the SPIRV CLPrintfOp.
+    globalVar = rewriter.create<spirv::GlobalVariableOp>(
+        loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
+
+    globalVar->setAttr("Constant", rewriter.getUnitAttr());
+  }
+  // Get SSA value of Global variable and create pointer to i8 to point to
+  // the format string.
+  Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
+  Value fmtStr = rewriter.create<spirv::BitcastOp>(
+      loc,
+      spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
+      globalPtr);
+
+  // Get printf arguments
----------------
kuhar wrote:

```suggestion
  // Get printf arguments.
```

https://github.com/llvm/llvm-project/pull/78510


More information about the Mlir-commits mailing list