[Mlir-commits] [mlir] [mlir][AMDGPU] Fix raw buffer ptr ops lowering (PR #122293)

Ivan Butygin llvmlistbot at llvm.org
Fri Jan 10 11:34:10 PST 2025


================
@@ -167,28 +197,36 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     MemRefDescriptor memrefDescriptor(memref);
 
-    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+    Value ptr = memrefDescriptor.bufferPtr(
+        rewriter, loc, *this->getTypeConverter(), memrefType);
     // The stride value is always 0 for raw buffers. This also disables
     // swizling.
     Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmI16, rewriter.getI16IntegerAttr(0));
+        loc, i16, rewriter.getI16IntegerAttr(0));
+    // Get the number of elements.
     Value numRecords;
-    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
-      numRecords = createI32Constant(
-          rewriter, loc,
-          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+    if (memrefType.hasStaticShape() &&
+        !llvm::any_of(strides, ShapedType::isDynamic)) {
+      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+      ArrayRef<int64_t> shape = memrefType.getShape();
+      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+        size = std::max(shape[i] * strides[i], size);
+      size = size * elementByteWidth;
+      assert(size < std::numeric_limits<uint32_t>::max() &&
+             "the memref buffer is too large");
+      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
     } else {
       Value maxIndex;
       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
-        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
-        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
-        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
-                                                               maxThisDim)
-                            : maxThisDim;
+        Value maxThisDim = rewriter.create<LLVM::MulOp>(
+            loc, memrefDescriptor.size(rewriter, loc, i),
+            memrefDescriptor.stride(rewriter, loc, i));
----------------
Hardcode84 wrote:

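The order in which `.stride` and `.size` are evaluated is unspecified, because C++ does not define the evaluation order of function arguments. Since each call creates new ops, MLIR built with different compilers can generate different IR. Please hoist both calls into separate local variables before creating the `MulOp`, so the emission order is deterministic.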

https://github.com/llvm/llvm-project/pull/122293


More information about the Mlir-commits mailing list