[Mlir-commits] [mlir] 45d8759 - Emit nuw and nsw for mul and add when lowering to llvm.getelementptr (#140966)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 22 13:40:29 PDT 2025


Author: Peiyong Lin
Date: 2025-05-22T15:40:26-05:00
New Revision: 45d8759cbed0f216786729718608a8be72a505c6

URL: https://github.com/llvm/llvm-project/commit/45d8759cbed0f216786729718608a8be72a505c6
DIFF: https://github.com/llvm/llvm-project/commit/45d8759cbed0f216786729718608a8be72a505c6.diff

LOG: Emit nuw and nsw for mul and add when lowering to llvm.getelementptr (#140966)

Now that the GEP no-wrap flags are known when lowering to
llvm.getelementptr, we can also emit nuw and nsw on the generated
llvm.mul and llvm.add: nuw when the GEP carries the nuw (no unsigned
wrap) flag, and nsw when it carries the nusw flag.

fixes: iree-org/iree#20483

Signed-off-by: Lin, Peiyong <linpyong at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 8da850678878d..48fbcbcdbbde9 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -73,6 +73,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   Value base =
       memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
 
+  LLVM::IntegerOverflowFlags intOverflowFlags =
+      LLVM::IntegerOverflowFlags::none;
+  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
+    intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
+  }
+  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
+    intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
+  }
+
   Type indexType = getIndexType();
   Value index;
   for (int i = 0, e = indices.size(); i < e; ++i) {
@@ -82,10 +91,12 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
           ShapedType::isDynamic(strides[i])
               ? memRefDescriptor.stride(rewriter, loc, i)
               : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
-      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride,
+                                               intOverflowFlags);
     }
-    index =
-        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+    index = index ? rewriter.create<LLVM::AddOp>(loc, index, increment,
+                                                 intOverflowFlags)
+                  : increment;
   }
 
   Type elementPtrType = memRefDescriptor.getElementPtrType();

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 9ca8bcd1491bc..543fdf5c26f5e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -175,8 +175,8 @@ func.func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
 //   CHECK-DAG:  %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 //  CHECK-NEXT:  llvm.load %[[addr]] : !llvm.ptr -> f32
   %0 = memref.load %mixed[%i, %j] : memref<42x?xf32>
@@ -192,8 +192,8 @@ func.func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
 //   CHECK-DAG:  %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 //  CHECK-NEXT:  llvm.load %[[addr]] : !llvm.ptr -> f32
   %0 = memref.load %dynamic[%i, %j] : memref<?x?xf32>
@@ -230,8 +230,8 @@ func.func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %va
 //   CHECK-DAG:  %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 //  CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
   memref.store %val, %dynamic[%i, %j] : memref<?x?xf32>
@@ -247,8 +247,8 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val :
 //   CHECK-DAG:  %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+//  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 //  CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
   memref.store %val, %mixed[%i, %j] : memref<42x?xf32>

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index b03ac2c20112b..040a27e160557 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -138,8 +138,8 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 // CHECK-DAG:  %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
 // CHECK:  %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
-// CHECK:  %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
-// CHECK:  %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
+// CHECK:  %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK:  %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
 // CHECK:  %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK:  llvm.load %[[addr]] : !llvm.ptr -> f32
   %0 = memref.load %static[%i, %j] : memref<10x42xf32>
@@ -166,8 +166,8 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
 // CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
-// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
-// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
+// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
 // CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
 


        


More information about the Mlir-commits mailing list