[Mlir-commits] [mlir] [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (PR #101664)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Aug 2 08:36:12 PDT 2024
================
@@ -125,24 +158,49 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
- auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
- global.getAddrSpace());
- Value address = rewriter.create<LLVM::AddressOfOp>(
- loc, ptrType, global.getSymNameAttr());
- Value memory =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
- ArrayRef<LLVM::GEPArg>{0, 0});
-
- // Build a memref descriptor pointing to the buffer to plug with the
- // existing memref infrastructure. This may use more registers than
- // otherwise necessary given that memref sizes are fixed, but we can try
- // and canonicalize that away later.
- Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
- auto type = cast<MemRefType>(attribution.getType());
- auto descr = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), type, memory);
- signatureConversion.remapInput(numProperArguments + idx, descr);
+ if (encodeWorkgroupAttributionsAsArguments) {
+ unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
+ assert(numProperArguments >= numAttributions &&
+ "Expecting attributions to be encoded as arguments already");
+
+ // Arguments encoding workgroup attributions will be in positions
+ // [numProperArguments, numProperArguments+numAttributions)
+ ArrayRef<BlockArgument> attributionArguments =
+ gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
+ numAttributions);
+ for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
+ gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
+ auto [attribution, arg] = vals;
+ auto type = cast<MemRefType>(attribution.getType());
+
+ // Arguments are of llvm.ptr type and attributions are of memref type:
+ // we need to wrap them in memref descriptors.
+ Value descr = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *getTypeConverter(), type, arg);
+
+ // And remap the arguments
+ signatureConversion.remapInput(numProperArguments + idx, descr);
+ }
+ } else {
+ for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
----------------
kuhar wrote:
```suggestion
for (auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
```
This `const` has no effect here, so it can be dropped.
https://github.com/llvm/llvm-project/pull/101664
More information about the Mlir-commits
mailing list