[Mlir-commits] [mlir] ce6ef99 - [mlir] Remove convertible identity restriction for memref.atomic_rmw to LLVM (#72262)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 14 19:45:18 PST 2023


Author: Max191
Date: 2023-11-14T22:45:14-05:00
New Revision: ce6ef990fee750ed32af180c4c548c5e6677f27d

URL: https://github.com/llvm/llvm-project/commit/ce6ef990fee750ed32af180c4c548c5e6677f27d
DIFF: https://github.com/llvm/llvm-project/commit/ce6ef990fee750ed32af180c4c548c5e6677f27d.diff

LOG: [mlir] Remove convertible identity restriction for memref.atomic_rmw to LLVM (#72262)

memref.atomic_rmw previously failed to convert for memref types that have an offset, because such types do not have identity layout maps. That restriction is overly conservative, so this change relaxes it to require only that the memref type be strided.

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4ae6e865f2a49e0..2bfca303b5fd489 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/MathExtras.h"
@@ -1562,12 +1563,14 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
   LogicalResult
   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(match(atomicOp)))
-      return failure();
     auto maybeKind = matchSimpleAtomicOp(atomicOp);
     if (!maybeKind)
       return failure();
     auto memRefType = atomicOp.getMemRefType();
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    if (failed(getStridesAndOffset(memRefType, strides, offset)))
+      return failure();
     auto dataPtr =
         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                              adaptor.getIndices(), rewriter);

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 3b3a51d609be972..37999d6fc14ad19 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -400,6 +400,24 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
 
 // -----
 
+func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>, %ival : i32, %i : index) {
+  memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32, strided<[1], offset: 5>>) -> i32
+  return
+}
+// CHECK-LABEL:  func @atomic_rmw_with_offset
+// CHECK-SAME:   %[[ARG0:.+]]: memref<10xi32, strided<[1], offset: 5>>
+// CHECK-SAME:   %[[ARG1:.+]]: i32
+// CHECK-SAME:   %[[ARG2:.+]]: index
+// CHECK:        %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:        %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
+// CHECK:        %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+// CHECK:        %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
+// CHECK:        %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK:        %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK:        llvm.atomicrmw _and %[[PTR]], %[[ARG1]] acq_rel
+
+// -----
+
 // CHECK-LABEL: func @generic_atomic_rmw
 func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
   %x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {


        


More information about the Mlir-commits mailing list