[Mlir-commits] [mlir] [mlir][Affine] Add nuw/nsw to lowering of affine ops. (PR #121535)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 2 17:50:57 PST 2025


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/121535

Since index operations have no set bitwidth, signed/unsigned wrapping behavior is ill-defined for them. As a corollary, it is always safe to add the nsw/nuw flags when lowering affine ops.

Also add a folder to fold `div(s|u)i (muli (a, v), v) -> a` when the `muli` carries the matching no-wrap flag (`nsw` for `divsi`, `nuw` for `divui`).

>From eae38f077cd7d908907ac5d7289ee41f30ee9d32 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mravisha at amd.com>
Date: Thu, 2 Jan 2025 16:37:20 -0600
Subject: [PATCH] [mlir][Affine] Add nuw/nsw to lowering of affine ops.

Since index operations have no set bitwidth, signed/unsigned wrapping
behavior is ill-defined for them. As a corollary, it is always safe to
add the nsw/nuw flags when lowering affine ops.

Also add a folder to fold `div(s|u)i (muli (a, v), v) -> a` when the
`muli` carries the matching no-wrap flag (`nsw` for `divsi`, `nuw` for
`divui`).

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 14 ++-
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 38 ++++++--
 .../lower-affine-to-vector.mlir               |  8 +-
 .../AffineToStandard/lower-affine.mlir        | 96 +++++++++----------
 .../expand-then-convert-to-llvm.mlir          | 28 +++---
 mlir/test/Dialect/Arith/canonicalize.mlir     | 62 ++++++++++++
 .../lower-to-llvm-e2e-with-target-tag.mlir    |  4 +-
 ...lvm-e2e-with-top-level-named-sequence.mlir |  4 +-
 .../Linalg/vectorization-with-patterns.mlir   |  8 +-
 .../vectorize-tensor-extract-masked.mlir      | 10 +-
 .../Linalg/vectorize-tensor-extract.mlir      |  8 +-
 11 files changed, 185 insertions(+), 95 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d3ead20fb5cd3..8328b6430e182b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -56,7 +56,9 @@ class AffineApplyExpander
     auto rhs = visit(expr.getRHS());
     if (!lhs || !rhs)
       return nullptr;
-    auto op = builder.create<OpTy>(loc, lhs, rhs);
+    auto op = builder.create<OpTy>(loc, lhs, rhs,
+                                   arith::IntegerOverflowFlags::nsw |
+                                       arith::IntegerOverflowFlags::nuw);
     return op.getResult();
   }
 
@@ -93,8 +95,9 @@ class AffineApplyExpander
     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
     Value isRemainderNegative = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
-    Value correctedRemainder =
-        builder.create<arith::AddIOp>(loc, remainder, rhs);
+    Value correctedRemainder = builder.create<arith::AddIOp>(
+        loc, remainder, rhs,
+        arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw);
     Value result = builder.create<arith::SelectOp>(
         loc, isRemainderNegative, correctedRemainder, remainder);
     return result;
@@ -178,8 +181,9 @@ class AffineApplyExpander
     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
     Value negatedQuotient =
         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
-    Value incrementedQuotient =
-        builder.create<arith::AddIOp>(loc, quotient, oneCst);
+    Value incrementedQuotient = builder.create<arith::AddIOp>(
+        loc, quotient, oneCst,
+        arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw);
     Value result = builder.create<arith::SelectOp>(
         loc, nonPositive, negatedQuotient, incrementedQuotient);
     return result;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..c43f182f40947b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -596,6 +596,18 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
                                                  return a.udiv(b);
                                                });
 
+  // divui (muli (a, v), v) -> a
+  if (auto muliOp = getLhs().getDefiningOp<arith::MulIOp>()) {
+    if (muliOp.hasNoUnsignedWrap()) {
+      if (getRhs() == muliOp.getRhs()) {
+        return muliOp.getLhs();
+      }
+      if (getRhs() == muliOp.getLhs()) {
+        return muliOp.getRhs();
+      }
+    }
+  }
+
   return div0 ? Attribute() : result;
 }
 
@@ -632,6 +644,18 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
         return a.sdiv_ov(b, overflowOrDiv0);
       });
 
+  // divsi (muli (a, v), v) -> a
+  if (auto muliOp = getLhs().getDefiningOp<arith::MulIOp>()) {
+    if (muliOp.hasNoSignedWrap()) {
+      if (getRhs() == muliOp.getRhs()) {
+        return muliOp.getLhs();
+      }
+      if (getRhs() == muliOp.getLhs()) {
+        return muliOp.getRhs();
+      }
+    }
+  }
+
   return overflowOrDiv0 ? Attribute() : result;
 }
 
@@ -2341,12 +2365,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
 
   // Constant-fold constant operands over non-splat constant condition.
   // select %cst_vec, %cst0, %cst1 => %cst2
-  if (auto cond =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
-    if (auto lhs =
-            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
-      if (auto rhs =
-              llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+  if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+          adaptor.getCondition())) {
+    if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+            adaptor.getTrueValue())) {
+      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+              adaptor.getFalseValue())) {
         SmallVector<Attribute> results;
         results.reserve(static_cast<size_t>(cond.getNumElements()));
         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2614,7 +2638,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
-   case AtomicRMWKind::maxnumf:
+  case AtomicRMWKind::maxnumf:
     return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minnumf:
     return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
index 58580a194df0c7..e9915ea550a5d8 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -9,7 +9,7 @@ func.func @affine_vector_load(%arg0 : index) {
 // CHECK:       %[[buf:.*]] = memref.alloc
 // CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32>
   return
 }
@@ -26,10 +26,10 @@ func.func @affine_vector_store(%arg0 : index) {
 // CHECK:       %[[buf:.*]] = memref.alloc
 // CHECK:       %[[val:.*]] = arith.constant dense
 // CHECK:       %[[c_1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:  %[[a:.*]] = arith.muli %arg0, %[[c_1]] : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT:  %[[a:.*]] = arith.muli %arg0, %[[c_1]] overflow<nsw, nuw> : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[a]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
+// CHECK-NEXT:  %[[c:.*]] = arith.addi %[[b]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32>
   return
 }
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 00d7b6b8d65f67..e21d20995a7d01 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -156,9 +156,9 @@ func.func private @get_idx() -> (index)
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     call @body(%[[v0:.*]]) : (index) -> ()
@@ -177,9 +177,9 @@ func.func @if_only() {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     call @body(%[[v0:.*]]) : (index) -> ()
@@ -202,14 +202,14 @@ func.func @if_else() {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     %[[c0_0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:     %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:     %[[v4:.*]] = arith.addi %[[v0]], %[[cm10]] : index
+// CHECK-NEXT:     %[[v4:.*]] = arith.addi %[[v0]], %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:     %[[v5:.*]] = arith.cmpi sge, %[[v4]], %[[c0_0]] : index
 // CHECK-NEXT:     if %[[v5]] {
 // CHECK-NEXT:       call @body(%[[v0:.*]]) : (index) -> ()
@@ -217,7 +217,7 @@ func.func @if_else() {
 // CHECK-NEXT:   } else {
 // CHECK-NEXT:     %[[c0_0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:     %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:     %{{.*}} = arith.addi %[[v0]], %[[cm10]] : index
+// CHECK-NEXT:     %{{.*}} = arith.addi %[[v0]], %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:     %{{.*}} = arith.cmpi sge, %{{.*}}, %[[c0_0]] : index
 // CHECK-NEXT:     if %{{.*}} {
 // CHECK-NEXT:       call @mid(%[[v0:.*]]) : (index) -> ()
@@ -245,7 +245,7 @@ func.func @nested_ifs() {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.addi %[[v0]], %[[cm10]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.addi %[[v0]], %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v2:.*]] = arith.cmpi sge, %[[v1]], %[[c0]] : index
 // CHECK-NEXT:   %[[v3:.*]] = scf.if %[[v2]] -> (i64) {
 // CHECK-NEXT:     scf.yield %[[c0_i64]] : i64
@@ -272,25 +272,25 @@ func.func @if_with_yield() -> (i64) {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %{{.*}} : index
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] : index
+// CHECK-NEXT:   %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v4:.*]] = arith.cmpi sge, %[[v3]], %[[c0]] : index
 // CHECK-NEXT:   %[[cm1_0:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v5:.*]] = arith.addi %{{.*}}, %[[cm1_0]] : index
+// CHECK-NEXT:   %[[v5:.*]] = arith.addi %{{.*}}, %[[cm1_0]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v6:.*]] = arith.cmpi sge, %[[v5]], %[[c0]] : index
 // CHECK-NEXT:   %[[v7:.*]] = arith.andi %[[v4]], %[[v6]] : i1
 // CHECK-NEXT:   %[[cm1_1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v8:.*]] = arith.addi %{{.*}}, %[[cm1_1]] : index
+// CHECK-NEXT:   %[[v8:.*]] = arith.addi %{{.*}}, %[[cm1_1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v9:.*]] = arith.cmpi sge, %[[v8]], %[[c0]] : index
 // CHECK-NEXT:   %[[v10:.*]] = arith.andi %[[v7]], %[[v9]] : i1
 // CHECK-NEXT:   %[[cm1_2:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v11:.*]] = arith.addi %{{.*}}, %[[cm1_2]] : index
+// CHECK-NEXT:   %[[v11:.*]] = arith.addi %{{.*}}, %[[cm1_2]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v12:.*]] = arith.cmpi sge, %[[v11]], %[[c0]] : index
 // CHECK-NEXT:   %[[v13:.*]] = arith.andi %[[v10]], %[[v12]] : i1
 // CHECK-NEXT:   %[[cm42:.*]] = arith.constant -42 : index
-// CHECK-NEXT:   %[[v14:.*]] = arith.addi %{{.*}}, %[[cm42]] : index
+// CHECK-NEXT:   %[[v14:.*]] = arith.addi %{{.*}}, %[[cm42]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v15:.*]] = arith.cmpi eq, %[[v14]], %[[c0]] : index
 // CHECK-NEXT:   %[[v16:.*]] = arith.andi %[[v13]], %[[v15]] : i1
 // CHECK-NEXT:   if %[[v16]] {
@@ -316,9 +316,9 @@ func.func @if_for() {
   %i = call @get_idx() : () -> (index)
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     %[[c0:.*]]{{.*}} = arith.constant 0 : index
@@ -327,7 +327,7 @@ func.func @if_for() {
 // CHECK-NEXT:     for %{{.*}} = %[[c0:.*]]{{.*}} to %[[c42:.*]]{{.*}} step %[[c1:.*]]{{.*}} {
 // CHECK-NEXT:       %[[c0_:.*]]{{.*}} = arith.constant 0 : index
 // CHECK-NEXT:       %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:       %[[v4:.*]] = arith.addi %{{.*}}, %[[cm10]] : index
+// CHECK-NEXT:       %[[v4:.*]] = arith.addi %{{.*}}, %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:       %[[v5:.*]] = arith.cmpi sge, %[[v4]], %[[c0_:.*]]{{.*}} : index
 // CHECK-NEXT:       if %[[v5]] {
 // CHECK-NEXT:         call @body2(%[[v0]], %{{.*}}) : (index, index) -> ()
@@ -371,11 +371,11 @@ func.func @if_for() {
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
 // CHECK-NEXT:     %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw, nuw> : index
+// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} overflow<nsw, nuw> : index
 // CHECK-NEXT:     %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
 // CHECK-NEXT:     %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:     %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
 // CHECK-NEXT:     %[[c1_0:.*]] = arith.constant 1 : index
 // CHECK-NEXT:     for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
@@ -442,29 +442,29 @@ func.func @affine_applies(%arg0 : index) {
   %102 = arith.constant 102 : index
   %copy = affine.apply #map2(%zero)
 
-// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] : index
+// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] overflow<nsw, nuw> : index
   %one = affine.apply #map3(%symbZero)[%zero]
 
 // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] : index
-// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] : index
-// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
+// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] : index
-// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
+// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] : index
-// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
+// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] : index
-// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
+// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index
-// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
+// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] overflow<nsw, nuw> : index
   %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
   return
 }
@@ -498,7 +498,7 @@ func.func @affine_apply_mod(%arg0 : index) -> (index) {
 // CHECK-NEXT: %[[v0:.*]] = arith.remsi %{{.*}}, %[[c42]] : index
 // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT: %[[v1:.*]] = arith.cmpi slt, %[[v0]], %[[c0]] : index
-// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %[[c42]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %[[c42]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[v3:.*]] = arith.select %[[v1]], %[[v2]], %[[v0]] : index
   %0 = affine.apply #map_mod (%arg0)
   return %0 : index
@@ -509,7 +509,7 @@ func.func @affine_apply_mod_dynamic_divisor(%arg0 : index, %arg1 : index) -> (in
 // CHECK-NEXT: %[[v0:.*]] = arith.remsi %{{.*}}, %arg1 : index
 // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT: %[[v1:.*]] = arith.cmpi slt, %[[v0]], %[[c0]] : index
-// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %arg1 : index
+// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %arg1 overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[v3:.*]] = arith.select %[[v1]], %[[v2]], %[[v0]] : index
   %0 = affine.apply #map_mod_dynamic_divisor (%arg0)[%arg1]
   return %0 : index
@@ -567,7 +567,7 @@ func.func @affine_apply_ceildiv(%arg0 : index) -> (index) {
 // CHECK-NEXT:  %[[v3:.*]] = arith.select %[[v0]], %[[v1]], %[[v2]] : index
 // CHECK-NEXT:  %[[v4:.*]] = arith.divsi %[[v3]], %[[c42]] : index
 // CHECK-NEXT:  %[[v5:.*]] = arith.subi %[[c0]], %[[v4]] : index
-// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] : index
+// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[v7:.*]] = arith.select %[[v0]], %[[v5]], %[[v6]] : index
   %0 = affine.apply #map_ceildiv (%arg0)
   return %0 : index
@@ -583,7 +583,7 @@ func.func @affine_apply_ceildiv_dynamic_divisor(%arg0 : index, %arg1 : index) ->
 // CHECK-NEXT:  %[[v3:.*]] = arith.select %[[v0]], %[[v1]], %[[v2]] : index
 // CHECK-NEXT:  %[[v4:.*]] = arith.divsi %[[v3]], %arg1 : index
 // CHECK-NEXT:  %[[v5:.*]] = arith.subi %[[c0]], %[[v4]] : index
-// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] : index
+// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[v7:.*]] = arith.select %[[v0]], %[[v5]], %[[v6]] : index
   %0 = affine.apply #map_ceildiv_dynamic_divisor (%arg0)[%arg1]
   return %0 : index
@@ -597,7 +597,7 @@ func.func @affine_load(%arg0 : index) {
   }
 // CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %{{.*}} = memref.load %[[v0:.*]][%[[b]]] : memref<10xf32>
   return
 }
@@ -610,10 +610,10 @@ func.func @affine_store(%arg0 : index) {
     affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
   }
 // CHECK:       %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:  %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT:  %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw, nuw> : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[a]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
+// CHECK-NEXT:  %[[c:.*]] = arith.addi %[[b]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  store %{{.*}}, %{{.*}}[%[[c]]] : memref<10xf32>
   return
 }
@@ -635,7 +635,7 @@ func.func @affine_prefetch(%arg0 : index) {
   }
 // CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  memref.prefetch %[[v0:.*]][%[[b]]], read, locality<3>, data : memref<10xf32>
   return
 }
@@ -652,9 +652,9 @@ func.func @affine_dma_start(%arg0 : index) {
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
   }
 // CHECK:       %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[a:.*]] = arith.addi %{{.*}}, %[[c7]] : index
+// CHECK-NEXT:  %[[a:.*]] = arith.addi %{{.*}}, %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[c11:.*]] = arith.constant 11 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[c11]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[c11]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  dma_start %{{.*}}[%[[a]]], %{{.*}}[%[[b]]], %{{.*}}, %{{.*}}[%{{.*}}] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
   return
 }
@@ -666,9 +666,9 @@ func.func @affine_dma_wait(%arg0 : index) {
   affine.for %i0 = 0 to 10 {
     affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
   }
-// CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %arg0 : index
+// CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %arg0 overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[c17:.*]] = arith.constant 17 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c17]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c17]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  dma_wait %{{.*}}[%[[b]]], %{{.*}} : memref<1xi32>
   return
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index a78db9733b7eef..c89c8a5365f44b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -59,10 +59,10 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
   // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
-  // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
   // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -95,10 +95,10 @@ func.func @subview_non_zero_addrspace(%0 : memref<64x4xf32, strided<[4, 1], offs
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]]  : i64
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
   // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
-  // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]]  : i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
   // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
@@ -131,10 +131,10 @@ func.func @subview_const_size(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>,
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]]  : i64
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
   // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
-  // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]]  : i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
   // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -168,8 +168,8 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]]  : i64
-  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]]  : i64
+  // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw, nuw> : i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
   // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -234,12 +234,12 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]]  : i64
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
   // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
-  // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]]  : i64
+  // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow<nsw, nuw> : i64
   // CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index)  : i64
-  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]]  : i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
   // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -301,7 +301,7 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
   // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Compute and insert offset from 2 + dynamic value.
   // CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(2 : index) : i64
-  // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] : i64
+  // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] overflow<nsw, nuw> : i64
   // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF0]] : i64 to index
   // CHECK: %[[OFF0:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -425,7 +425,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]]  : i64
+// CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] overflow<nsw, nuw> : i64
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
 // CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -547,7 +547,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 // CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]]  : i64
+// CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] overflow<nsw, nuw> : i64
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
 // CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a186a0c6ceca0..bc8df8222abb3c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3230,3 +3230,65 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>,
     }
   }
 #-}
+
+func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+  %1 = arith.divui %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_0(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG1]]
+
+func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+  %1 = arith.divui %0, %arg1 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG0]]
+
+func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+  %1 = arith.divsi %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_0(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG1]]
+
+func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+  %1 = arith.divsi %0, %arg1 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG0]]
+
+// Do not fold divui(mul(a, v), v) -> a without nuw attribute.
+func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 : index
+  %1 = arith.divui %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divui_of_muli
+//       CHECK:   %[[T0:.+]] = arith.muli
+//       CHECK:   %[[T1:.+]] = arith.divui %[[T0]],
+//       CHECK:   return %[[T1]]
+
+// Do not fold divsi(mul(a, v), v) -> a without nsw attribute.
+func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 : index
+  %1 = arith.divsi %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divsi_of_muli
+//       CHECK:   %[[T0:.+]] = arith.muli
+//       CHECK:   %[[T1:.+]] = arith.divsi %[[T0]],
+//       CHECK:   return %[[T1]]
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
index f6d3387d99b3c3..55f59b70754abf 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
@@ -28,8 +28,8 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
   // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>
 
   // CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
-  // CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+  // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw, nuw> : i64
+  // CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] overflow<nsw, nuw> : i64
   // CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 
   // Base address and algined address.
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
index a74553cc2268ef..389dcc5c41e6c9 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
@@ -27,8 +27,8 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
   // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>
 
   // CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
-  // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
-  // CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+  // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw, nuw> : i64
+  // CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] overflow<nsw, nuw> : i64
   // CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 
   // Base address and algined address.
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index b688a677500c22..df9deb06086fb4 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -376,10 +376,10 @@ func.func @vectorize_affine_apply(%arg0: tensor<5xf32>, %arg3: index) -> tensor<
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<5xi32>
 // CHECK:   %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<5xindex>
-// CHECK:   %[[ADDI_1:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<5xindex>
-// CHECK:   %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[CST_0]] : vector<5xindex>
-// CHECK:   %[[ADDI_3:.*]] = arith.addi %[[ADDI_1]], %[[ADDI_2]] : vector<5xindex>
-// CHECK:   %[[ADDI_4:.*]] = arith.addi %[[ADDI_3]], %[[CST]] : vector<5xindex>
+// CHECK:   %[[ADDI_1:.*]] = arith.addi %[[BCAST]], %[[CST]] overflow<nsw, nuw> : vector<5xindex>
+// CHECK:   %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[CST_0]] overflow<nsw, nuw> : vector<5xindex>
+// CHECK:   %[[ADDI_3:.*]] = arith.addi %[[ADDI_1]], %[[ADDI_2]] overflow<nsw, nuw> : vector<5xindex>
+// CHECK:   %[[ADDI_4:.*]] = arith.addi %[[ADDI_3]], %[[CST]] overflow<nsw, nuw> : vector<5xindex>
 // CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI_4]] : vector<5xindex> to vector<5xi32>
 // CHECK:   vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<5xi32>, tensor<5xi32>
 
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index d0d3b58a057041..51dbc82001ff21 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -36,7 +36,7 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
 /// Caluclate the index vector
 // CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
-// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] overflow<nsw, nuw> : vector<4xindex>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
 
 /// Extract the starting point from the index vector
@@ -96,7 +96,7 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
 /// Caluclate the index vector
 // CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
-// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] overflow<nsw, nuw> : vector<[4]xindex>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
 
 /// Extract the starting point from the index vector
@@ -156,7 +156,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
 /// Caluclate the index vector
 // CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
-// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] overflow<nsw, nuw> : vector<4xindex>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
 
 /// Extract the starting point from the index vector
@@ -214,7 +214,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
 /// Caluclate the index vector
 // CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
-// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] overflow<nsw, nuw> : vector<[4]xindex>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
 
 /// Extract the starting point from the index vector
@@ -304,7 +304,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_gather(%
 // CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
 // CHECK:           %[[VAL_12:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] overflow<nsw, nuw> : vector<4xindex>
 // CHECK:           %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
 // CHECK:           %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
 // CHECK:           %[[VAL_17:.*]] = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0e..b989b0f6c9feb3 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -112,7 +112,7 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 79 : index
 // CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
+// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] overflow<nsw, nuw> : vector<4xindex>
 // CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
 // CHECK:           %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
@@ -147,7 +147,7 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
 // CHECK-SAME:      %[[INPUT_2:.*]]: tensor<257x24xf32>,
 // CHECK-SAME:      %[[INPUT_3:.*]]: index, %[[INPUT_4:.*]]: index, %[[INPUT_5:.*]]: index,
 // CHECK:           %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[SCALAR:.*]] = arith.addi %[[INPUT_3]], %[[INPUT_5]] : index
+// CHECK:           %[[SCALAR:.*]] = arith.addi %[[INPUT_3]], %[[INPUT_5]] overflow<nsw, nuw> : index
 // First `vector.transfer_read` from the generic Op - loop invariant scalar load.
 // CHECK:           vector.transfer_read %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[SCALAR]]]
 // CHECK-SAME:      tensor<1x20xi32>, vector<i32>
@@ -294,7 +294,7 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
 // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] overflow<nsw, nuw> : index
 // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
 // CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
 // CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
@@ -440,7 +440,7 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
 // CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
+// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] overflow<nsw, nuw> : vector<4xindex>
 // CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
 // CHECK:           %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
 // CHECK:           %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>



More information about the Mlir-commits mailing list