[Mlir-commits] [mlir] [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (PR #101664)

Victor Perez llvmlistbot at llvm.org
Fri Aug 2 05:49:12 PDT 2024


================
@@ -125,24 +158,49 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
     unsigned numProperArguments = gpuFuncOp.getNumArguments();
 
-    for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
-      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                                global.getAddrSpace());
-      Value address = rewriter.create<LLVM::AddressOfOp>(
-          loc, ptrType, global.getSymNameAttr());
-      Value memory =
-          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
-                                       ArrayRef<LLVM::GEPArg>{0, 0});
-
-      // Build a memref descriptor pointing to the buffer to plug with the
-      // existing memref infrastructure. This may use more registers than
-      // otherwise necessary given that memref sizes are fixed, but we can try
-      // and canonicalize that away later.
-      Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
-      auto type = cast<MemRefType>(attribution.getType());
-      auto descr = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), type, memory);
-      signatureConversion.remapInput(numProperArguments + idx, descr);
+    if (encodeWorkgroupAttributionsAsArguments) {
+      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
+      assert(numProperArguments >= numAttributions &&
+             "Expecting attributions to be encoded as arguments already");
+
+      // Arguments encoding workgroup attributions will be in positions
+      // [numProperArguments, numProperArguments+numAttributions)
+      ArrayRef<BlockArgument> attributionArguments =
+          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
+                                         numAttributions);
+      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
+               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
+        auto [attribution, arg] = vals;
+        auto type = cast<MemRefType>(attribution.getType());
+
+        // Arguments are of llvm.ptr type and attributions are of memref type:
+        // we need to wrap them in memref descriptors.
+        Value descr = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *getTypeConverter(), type, arg);
+
+        // And remap the arguments
+        signatureConversion.remapInput(numProperArguments + idx, descr);
+      }
+    } else {
+      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
+        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                                  global.getAddrSpace());
+        Value address = rewriter.create<LLVM::AddressOfOp>(
+            loc, ptrType, global.getSymNameAttr());
+        Value memory =
+            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
+                                         address, ArrayRef<LLVM::GEPArg>{0, 0});
+
+        // Build a memref descriptor pointing to the buffer to plug with the
+        // existing memref infrastructure. This may use more registers than
+        // otherwise necessary given that memref sizes are fixed, but we can try
+        // and canonicalize that away later.
+        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
+        auto type = cast<MemRefType>(attribution.getType());
+        auto descr = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *getTypeConverter(), type, memory);
+        signatureConversion.remapInput(numProperArguments + idx, descr);
+      }
----------------
victor-eds wrote:

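Original code: the `else` branch is the pre-existing loop, unchanged apart from re-indentation and line wrapping.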

https://github.com/llvm/llvm-project/pull/101664


More information about the Mlir-commits mailing list