[Mlir-commits] [mlir] [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op (PR #71546)

Mehdi Amini llvmlistbot at llvm.org
Thu Nov 9 06:30:21 PST 2023


================
@@ -554,6 +557,101 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
 }
 
+/// Returns a global symbol with a 0-sized array type for dynamic shared
+/// memory usage, creating it if no suitable one exists yet.
+///
+/// An existing global is reused only if it is a 0-sized array that lives in
+/// the shared memory address space and has the requested alignment. A newly
+/// created global is inserted at the start of the enclosing gpu.module and
+/// gets a name that does not clash with any symbol already in that module.
+///
+/// \param rewriter      rewriter used to create the global (its insertion
+///                      point is restored on return).
+/// \param op            the op being lowered; must be nested in an llvm.func
+///                      inside a gpu.module.
+/// \param typeConverter converts the memref element type to an LLVM type.
+/// \param memrefType    memref type whose element type is used for the array.
+/// \param alignmentBit  requested alignment in bits; converted to bytes,
+///                      which is the unit LLVM uses for global alignment.
+/// \returns the reused or newly created LLVM::GlobalOp.
+LLVM::GlobalOp
+getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+                             gpu::DynamicSharedMemoryOp op,
+                             const LLVMTypeConverter *typeConverter,
+                             MemRefType memrefType, unsigned alignmentBit) {
+  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+  assert(funcOp && "cannot find llvm.func op");
+
+  gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
+  assert(moduleOp && "cannot find gpu.module op");
+
+  // Bits -> bytes. This must not depend on the element type: dividing by the
+  // element bit width is only correct for 8-bit elements.
+  const uint64_t alignmentByte = alignmentBit / 8;
+  const unsigned sharedAddrSpace =
+      mlir::gpu::GPUMemorySpace::kSharedMemorySpace;
+
+  // Reuse an existing global only if it is a 0-sized array in the shared
+  // memory address space with the same alignment; stop at the first match.
+  LLVM::GlobalOp existingGlobalOp;
+  moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+    auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType());
+    if (arrayType && arrayType.getNumElements() == 0 &&
+        globalOp.getAddrSpace() == sharedAddrSpace &&
+        globalOp.getAlignment().value_or(0) == alignmentByte) {
+      existingGlobalOp = globalOp;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  if (existingGlobalOp)
+    return existingGlobalOp;
+
+  // Find a unique name. Check against every symbol in the module's symbol
+  // table (not just LLVM globals), since symbol names must be unique there.
+  const std::string prefix =
+      llvm::formatv("__shmem_{0}", funcOp.getSymName()).str();
+  unsigned index = 0;
+  std::string symName;
+  do {
+    symName = llvm::formatv("{0}_{1}", prefix, index++).str();
+  } while (SymbolTable::lookupSymbolIn(moduleOp, symName));
+
+  // Generate the new global at the start of the gpu.module body; the guard
+  // restores the rewriter's previous insertion point when we return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(&moduleOp.front());
+
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+      typeConverter->convertType(memrefType.getElementType()), 0);
+
+  return rewriter.create<LLVM::GlobalOp>(
+      funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
+      sharedAddrSpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MemRefType memrefType = op.getResultMemref().getType();
+  Type elementType = typeConverter->convertType(memrefType.getElementType());
+  assert(memrefType && "memref is not valid");
----------------
joker-eph wrote:

You're asserting on Memref type after dereferencing it: this does not seem useful?

https://github.com/llvm/llvm-project/pull/71546


More information about the Mlir-commits mailing list