[Mlir-commits] [mlir] [mlir][Affine] Add nuw/nsw to lowering of affine ops. (PR #121535)
llvmlistbot at llvm.org
Fri Jan 3 17:04:48 PST 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/121535
From 7b6a4e1b2bf176489321d6813a9561e90b1e54e0 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mravisha at amd.com>
Date: Thu, 2 Jan 2025 16:37:20 -0600
Subject: [PATCH] [mlir][Affine] Add nsw to `muli` ops generated by affine
lowering.
Since index operations have no fixed bitwidth, signed/unsigned wrapping
behavior is ill-defined for them. The corollary is that it is always safe
to add nsw/nuw to the operations generated when lowering affine ops.
Signed-off-by: MaheshRavishankar <mravisha at amd.com>
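To make the effect concrete, here is a minimal before/after sketch of the
lowering (the map, function, and SSA names are hypothetical; the CHECK-line
updates in the tests below exercise the same pattern):

  // Input: an affine.apply whose map contains a multiplication.
  #map = affine_map<(d0) -> (d0 * 3 + 5)>
  func.func @example(%i: index) -> index {
    %0 = affine.apply #map(%i)
    return %0 : index
  }

  // After `mlir-opt --lower-affine`, the generated muli now carries nsw:
  //   %c3 = arith.constant 3 : index
  //   %m  = arith.muli %i, %c3 overflow<nsw> : index
  //   %c5 = arith.constant 5 : index
  //   %r  = arith.addi %m, %c5 : index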
---
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 9 ++-
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 38 +++++++++---
.../lower-affine-to-vector.mlir | 2 +-
.../AffineToStandard/lower-affine.mlir | 26 ++++----
.../expand-then-convert-to-llvm.mlir | 26 ++++----
mlir/test/Dialect/Arith/canonicalize.mlir | 62 +++++++++++++++++++
.../lower-to-llvm-e2e-with-target-tag.mlir | 2 +-
...lvm-e2e-with-top-level-named-sequence.mlir | 2 +-
8 files changed, 128 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d3ead20fb5cd3..9e3257a62b12fb 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -51,12 +51,14 @@ class AffineApplyExpander
loc(loc) {}
template <typename OpTy>
- Value buildBinaryExpr(AffineBinaryOpExpr expr) {
+ Value buildBinaryExpr(AffineBinaryOpExpr expr,
+ arith::IntegerOverflowFlags overflowFlags =
+ arith::IntegerOverflowFlags::none) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return nullptr;
- auto op = builder.create<OpTy>(loc, lhs, rhs);
+ auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
return op.getResult();
}
@@ -65,7 +67,8 @@ class AffineApplyExpander
}
Value visitMulExpr(AffineBinaryOpExpr expr) {
- return buildBinaryExpr<arith::MulIOp>(expr);
+ return buildBinaryExpr<arith::MulIOp>(expr,
+ arith::IntegerOverflowFlags::nsw);
}
/// Euclidean modulo operation: negative RHS is not allowed.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e016a6e16e59ff..ceaf9a9a4ff804 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -616,6 +616,18 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
return a.udiv(b);
});
+ // divui (muli (a, v), v) -> a
+ if (auto muliOp = getLhs().getDefiningOp<arith::MulIOp>()) {
+ if (muliOp.hasNoUnsignedWrap()) {
+ if (getRhs() == muliOp.getRhs()) {
+ return muliOp.getLhs();
+ }
+ if (getRhs() == muliOp.getLhs()) {
+ return muliOp.getRhs();
+ }
+ }
+ }
+
return div0 ? Attribute() : result;
}
@@ -656,6 +668,18 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
return a.sdiv_ov(b, overflowOrDiv0);
});
+ // divsi (muli (a, v), v) -> a
+ if (auto muliOp = getLhs().getDefiningOp<arith::MulIOp>()) {
+ if (muliOp.hasNoSignedWrap()) {
+ if (getRhs() == muliOp.getRhs()) {
+ return muliOp.getLhs();
+ }
+ if (getRhs() == muliOp.getLhs()) {
+ return muliOp.getRhs();
+ }
+ }
+ }
+
return overflowOrDiv0 ? Attribute() : result;
}
@@ -2365,12 +2389,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
- if (auto cond =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
- if (auto lhs =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
- if (auto rhs =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+ if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getCondition())) {
+ if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getTrueValue())) {
+ if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2638,7 +2662,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxnumf:
+ case AtomicRMWKind::maxnumf:
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minnumf:
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
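Why the two folds above are sound: the nuw (resp. nsw) flag asserts that the
multiplication did not wrap, so dividing the product by either factor exactly
recovers the other factor. A minimal sketch of the pattern the folder now
matches, using a hypothetical function (the canonicalize.mlir tests added
below cover the same cases):

  func.func @divui_of_muli(%a: index, %v: index) -> index {
    // nuw asserts %a * %v does not wrap, so the unsigned division
    // is exact and the folder rewrites %q to %a.
    %prod = arith.muli %a, %v overflow<nuw> : index
    %q = arith.divui %prod, %v : index
    return %q : index
  }

The signed fold is analogous but requires nsw on the muli and uses
arith.divsi; without the matching flag the pattern is left untouched, as the
negative tests below check.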
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
index 58580a194df0c7..f2e0306073f27b 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -26,7 +26,7 @@ func.func @affine_vector_store(%arg0 : index) {
// CHECK: %[[buf:.*]] = memref.alloc
// CHECK: %[[val:.*]] = arith.constant dense
// CHECK: %[[c_1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[a:.*]] = arith.muli %arg0, %[[c_1]] : index
+// CHECK-NEXT: %[[a:.*]] = arith.muli %arg0, %[[c_1]] overflow<nsw> : index
// CHECK-NEXT: %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
// CHECK-NEXT: %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 00d7b6b8d65f67..550ea71882e144 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -156,7 +156,7 @@ func.func private @get_idx() -> (index)
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -177,7 +177,7 @@ func.func @if_only() {
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -202,7 +202,7 @@ func.func @if_else() {
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -272,7 +272,7 @@ func.func @if_with_yield() -> (i64) {
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %{{.*}} : index
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] : index
@@ -316,7 +316,7 @@ func.func @if_for() {
%i = call @get_idx() : () -> (index)
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -371,7 +371,7 @@ func.func @if_for() {
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
// CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index
@@ -448,22 +448,22 @@ func.func @affine_applies(%arg0 : index) {
%one = affine.apply #map3(%symbZero)[%zero]
// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index
// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
// CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] : index
+// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index
// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] : index
+// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] : index
+// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index
// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
// CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] : index
+// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index
// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index
+// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
%four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
return
@@ -610,7 +610,7 @@ func.func @affine_store(%arg0 : index) {
affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
}
// CHECK: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
// CHECK-NEXT: %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index a78db9733b7eef..1fe4217cde9827 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -59,7 +59,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
@@ -95,10 +95,10 @@ func.func @subview_non_zero_addrspace(%0 : memref<64x4xf32, strided<[4, 1], offs
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
@@ -131,10 +131,10 @@ func.func @subview_const_size(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>,
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -168,8 +168,8 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] : i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] : i64
+ // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw> : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -234,12 +234,12 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
- // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] : i64
+ // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index) : i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -301,7 +301,7 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Compute and insert offset from 2 + dynamic value.
// CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(2 : index) : i64
- // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] : i64
+ // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF0]] : i64 to index
// CHECK: %[[OFF0:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -425,7 +425,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
+// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] overflow<nsw> : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -547,7 +547,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
+// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] overflow<nsw> : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 522711b08f289d..30f71f52010e30 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3294,3 +3294,65 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>,
}
}
#-}
+
+func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+ %1 = arith.divui %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_0(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG1]]
+
+func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+ %1 = arith.divui %0, %arg1 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_1(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG0]]
+
+func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+ %1 = arith.divsi %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_0(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG1]]
+
+func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+ %1 = arith.divsi %0, %arg1 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_1(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG0]]
+
+// Do not fold divui(muli(a, v), v) -> a when the muli lacks the nuw attribute.
+func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.divui %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divui_of_muli
+// CHECK: %[[T0:.+]] = arith.muli
+// CHECK: %[[T1:.+]] = arith.divui %[[T0]],
+// CHECK: return %[[T1]]
+
+// Do not fold divsi(muli(a, v), v) -> a when the muli lacks the nsw attribute.
+func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.divsi %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divsi_of_muli
+// CHECK: %[[T0:.+]] = arith.muli
+// CHECK: %[[T1:.+]] = arith.divsi %[[T0]],
+// CHECK: return %[[T1]]
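For reference, canonicalization tests like the ones above are driven by a RUN
line at the top of the test file; a representative invocation (the exact flags
used in canonicalize.mlir may differ) is:

  // RUN: mlir-opt %s --canonicalize | FileCheck %s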
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
index f6d3387d99b3c3..2785b508861228 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
@@ -28,7 +28,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>
// CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
index a74553cc2268ef..c1f30c7eaf6430 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
@@ -27,7 +27,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>
// CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>