[Mlir-commits] [mlir] [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op (PR #71546)

Guray Ozen llvmlistbot at llvm.org
Mon Nov 13 02:24:39 PST 2023


================
@@ -554,6 +558,105 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
 }
 
+/// Returns the `LLVM::GlobalOp` symbol used for dynamic shared memory,
+/// creating it if no suitable global exists yet.
+///
+/// The `LLVM::GlobalOp`s in the first block of `moduleOp`'s first region are
+/// scanned. A global is reused when it has a 0-sized LLVM array type, lives in
+/// `addressSpace`, and has the requested alignment (a global without an
+/// explicit alignment is treated as having alignment 0). If none matches, a new
+/// internal-linkage, non-constant, uninitialized global is created before the
+/// first operation of the module, with a name derived from the
+/// `__dynamic_shmem_` prefix that does not collide with any scanned global.
+///
+/// \param rewriter      Rewriter used to create the global; its insertion
+///                      point is restored on return.
+/// \param moduleOp      Module-like op that holds (or will hold) the global.
+/// \param op            Op being lowered; only its location is used.
+/// \param typeConverter Converts the memref element type to an LLVM type.
+/// \param memrefType    Memref type whose element type defines the array
+///                      element type of the global.
+/// \param alignmentBit  Requested alignment, expressed in bits.
+/// \param addressSpace  Address space used both to match an existing global
+///                      and to create a new one.
+/// \returns The reused or newly created global.
+template <typename ModuleTy>
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+    ConversionPatternRewriter &rewriter, ModuleTy moduleOp,
+    gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
+    MemRefType memrefType, unsigned alignmentBit, unsigned addressSpace) {
+
+  // NOTE(review): dividing by the element bit width yields a byte count only
+  // when the element type is 8 bits wide (the caller appears to pass a
+  // memref<0xi8>-style type) -- confirm before reusing with other types.
+  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+
+  // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
+  // LLVM::GlobalOp is suitable for shared memory, return it.
+  llvm::StringSet<> existingGlobalNames;
+  for (auto globalOp :
+       moduleOp->getRegion(0).front().template getOps<LLVM::GlobalOp>()) {
+    existingGlobalNames.insert(globalOp.getSymName());
+    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+      if (globalOp.getAddrSpace() == addressSpace &&
+          arrayType.getNumElements() == 0 &&
+          globalOp.getAlignment().value_or(0) == alignmentByte) {
+        return globalOp;
+      }
+    }
+  }
+
+  // Step 2. Find a unique symbol name
+  unsigned uniquingCounter = 0;
+  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
+      "__dynamic_shmem_",
+      [&](StringRef candidate) {
+        return existingGlobalNames.contains(candidate);
+      },
+      uniquingCounter);
+
+  // Step 3. Generate a global op
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(&moduleOp.front());
+
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+      typeConverter->convertType(memrefType.getElementType()), 0);
+
+  // Create the global in the same address space that Step 1 matches against;
+  // otherwise a later call would never find this global and would emit a
+  // duplicate symbol.
+  return rewriter.create<LLVM::GlobalOp>(
+      op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
+      addressSpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MemRefType memrefType = op.getResultMemref().getType();
+  Type elementType = typeConverter->convertType(memrefType.getElementType());
+
+  // Step 1: Generate a memref<0xi8> type
+  MemRefLayoutAttrInterface layout = {};
+  auto memrefType0sz =
+      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+  // Step 2: Generate a global symbol or existing for the dynamic shared
+  // memory with memref<0xi8> type
+  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+  LLVM::GlobalOp shmemOp = {};
+  if (gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) {
+    shmemOp =
+        getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
+                                     memrefType0sz, alignmentBit, addressSpace);
+  } else if (ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>()) {
+    shmemOp =
+        getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
+                                     memrefType0sz, alignmentBit, addressSpace);
+  }
----------------
grypp wrote:

I now find it the way you suggested:
```
Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
```

https://github.com/llvm/llvm-project/pull/71546


More information about the Mlir-commits mailing list