[Mlir-commits] [mlir] 45d8759 - Emit nuw and nsw for mul and add when lowering to llvm.getelementptr (#140966)
llvmlistbot at llvm.org
Thu May 22 13:40:29 PDT 2025
Author: Peiyong Lin
Date: 2025-05-22T15:40:26-05:00
New Revision: 45d8759cbed0f216786729718608a8be72a505c6
URL: https://github.com/llvm/llvm-project/commit/45d8759cbed0f216786729718608a8be72a505c6
DIFF: https://github.com/llvm/llvm-project/commit/45d8759cbed0f216786729718608a8be72a505c6.diff
LOG: Emit nuw and nsw for mul and add when lowering to llvm.getelementptr (#140966)
Now that the GEP no-wrap flags are known when lowering to
llvm.getelementptr, we can also emit nuw and nsw on the generated
llvm.mul and llvm.add when the no-unsigned-wrap and no-signed-wrap
guarantees hold, respectively.
fixes: iree-org/iree#20483
Signed-off-by: Lin, Peiyong <linpyong at gmail.com>
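As a rough illustration of the effect (a minimal sketch with hypothetical SSA
names, not copied from the tests below), lowering a 2-D memref.load whose GEP
carries the inbounds|nuw flags now also tags the index arithmetic with
overflow flags, along these lines:

  // Hypothetical input: %v = memref.load %m[%i, %j] : memref<?x?xf32>
  // Sketch of the lowered index computation after this change:
  %off  = llvm.mul %i, %stride0 overflow<nsw, nuw> : i64
  %idx  = llvm.add %off, %j overflow<nsw, nuw> : i64
  %addr = llvm.getelementptr inbounds|nuw %ptr[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %v    = llvm.load %addr : !llvm.ptr -> f32

Previously the llvm.mul and llvm.add were emitted without any overflow flags,
as the updated CHECK lines in the tests below show.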
Added:
Modified:
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 8da850678878d..48fbcbcdbbde9 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -73,6 +73,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Value base =
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
+ LLVM::IntegerOverflowFlags intOverflowFlags =
+ LLVM::IntegerOverflowFlags::none;
+ if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
+ intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
+ }
+ if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
+ intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
+ }
+
Type indexType = getIndexType();
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
@@ -82,10 +91,12 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(rewriter, loc, i)
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
- increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+ increment = rewriter.create<LLVM::MulOp>(loc, increment, stride,
+ intOverflowFlags);
}
- index =
- index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ index = index ? rewriter.create<LLVM::AddOp>(loc, index, increment,
+ intOverflowFlags)
+ : increment;
}
Type elementPtrType = memRefDescriptor.getElementPtrType();
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 9ca8bcd1491bc..543fdf5c26f5e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -175,8 +175,8 @@ func.func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
%0 = memref.load %mixed[%i, %j] : memref<42x?xf32>
@@ -192,8 +192,8 @@ func.func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
%0 = memref.load %dynamic[%i, %j] : memref<?x?xf32>
@@ -230,8 +230,8 @@ func.func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %va
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
memref.store %val, %dynamic[%i, %j] : memref<?x?xf32>
@@ -247,8 +247,8 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val :
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
memref.store %val, %mixed[%i, %j] : memref<42x?xf32>
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index b03ac2c20112b..040a27e160557 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -138,8 +138,8 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
-// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
-// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
+// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.load %[[addr]] : !llvm.ptr -> f32
%0 = memref.load %static[%i, %j] : memref<10x42xf32>
@@ -166,8 +166,8 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
// CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
-// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
-// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
+// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr