[Mlir-commits] [mlir] [mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (PR #78510)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jan 17 21:08:42 PST 2024


================
@@ -607,6 +616,108 @@ class GPUSubgroupReduceConversion final
   }
 };
 
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto loc = gpuPrintfOp.getLoc();
+
+  auto funcOp =
+      rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
+
+  auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
+
+  const char formatStringPrefix[] = "printfMsg";
+  unsigned stringNumber = 0;
+  mlir::SmallString<16> globalVarName;
+  mlir::spirv::GlobalVariableOp globalVar;
+
+  // formulate spirv global variable name
+  do {
+    globalVarName.clear();
+    (formatStringPrefix + llvm::Twine(stringNumber++))
+        .toStringRef(globalVarName);
+  } while (moduleOp.lookupSymbol(globalVarName));
+
+  auto i8Type = rewriter.getI8Type();
+  auto i32Type = rewriter.getI32Type();
+
+  unsigned scNum = 0;
+  auto createSpecConstant = [&](unsigned value) {
+    auto attr = rewriter.getI8IntegerAttr(value);
+    mlir::SmallString<16> specCstName;
+    (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
+        .toStringRef(specCstName);
+
+    return rewriter.create<mlir::spirv::SpecConstantOp>(
+        loc, rewriter.getStringAttr(specCstName), attr);
+  };
+
+  // define GlobalVarOp with printf format string using SpecConstants
+  // and make composite of SpecConstants
+  {
+    mlir::Operation *parent =
+        mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+
+    mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+
+    mlir::Block &entryBlock = *parent->getRegion(0).begin();
+    rewriter.setInsertionPointToStart(
+        &entryBlock); // insertion point at module level
+
+    // Create Constituents with SpecConstant to construct
+    // SpecConstantCompositeOp
+    llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
+    formatString.push_back('\0'); // Null terminate for C
+    mlir::SmallVector<mlir::Attribute, 4> constituents;
+    for (auto c : formatString) {
+      auto cSpecConstantOp = createSpecConstant(c);
+      constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
+    }
+
+    // Create specialization constant composite defined via spirv.SpecConstant
+    size_t contentSize = constituents.size();
+    auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
+    mlir::spirv::SpecConstantCompositeOp specCstComposite;
+    mlir::SmallString<16> specCstCompositeName;
+    (llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
+    specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
+        loc, mlir::TypeAttr::get(globalType),
+        rewriter.getStringAttr(specCstCompositeName),
+        rewriter.getArrayAttr(constituents));
+
+    // Define GlobalVariable initialized from Constant Composite
+    globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
+        loc,
+        mlir::spirv::PointerType::get(
+            globalType, mlir::spirv::StorageClass::UniformConstant),
+        globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
+
+    globalVar->setAttr("Constant", rewriter.getUnitAttr());
+  }
+
+  // Get SSA value of Global variable
+  mlir::Value globalPtr =
+      rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
+  mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
+      loc,
+      mlir::spirv::PointerType::get(i8Type,
+                                    mlir::spirv::StorageClass::UniformConstant),
+      globalPtr);
+
+  // Get printf arguments
+  auto argsRange = adaptor.getArgs();
+  mlir::SmallVector<mlir::Value, 4> printfArgs;
+  printfArgs.reserve(argsRange.size() + 1);
+  printfArgs.append(argsRange.begin(), argsRange.end());
+
+  rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+
+  rewriter.eraseOp(gpuPrintfOp);
----------------
kuhar wrote:

Use `rewriter.replaceOpWithNewOp`.

https://github.com/llvm/llvm-project/pull/78510


More information about the Mlir-commits mailing list