[Mlir-commits] [mlir] [mlir][arith] Add overflow flags to `arith.trunci` (PR #144863)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 19 03:00:25 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
LLVM has supported overflow flags on `llvm.trunc` for a while. This commit adds support for these flags to `arith.trunci`.
---
Full diff: https://github.com/llvm/llvm-project/pull/144863.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+30-6)
- (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+2-1)
- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+6-6)
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+2)
- (modified) mlir/test/Dialect/Arith/ops.mlir (+2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index adc27ae6bdafb..993f36f556e87 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -226,7 +226,7 @@ def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
- This op supports `nuw`/`nsw` overflow flags which stands stand for
+ This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
@@ -321,7 +321,7 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
- This op supports `nuw`/`nsw` overflow flags which stands stand for
+ This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
@@ -367,7 +367,7 @@ def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
- This op supports `nuw`/`nsw` overflow flags which stands stand for
+ This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
@@ -800,7 +800,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
operand is greater or equal than the bitwidth of the first operand, then the
operation returns poison.
- This op supports `nuw`/`nsw` overflow flags which stands stand for
+ This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
@@ -1271,7 +1271,11 @@ def Arith_ScalingExtFOp
// TruncIOp
//===----------------------------------------------------------------------===//
-def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
+def Arith_TruncIOp : Op<Arith_Dialect, "trunci",
+ [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
+ DeclareOpInterfaceMethods<CastOpInterface>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
@@ -1279,17 +1283,37 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
bit-width must be smaller than the input bit-width (N < M).
The top-most (N - M) bits of the input are discarded.
+ This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned
+ Wrap" and "No Signed Wrap", respectively. If the nuw keyword is present,
+ and any of the truncated bits are non-zero, the result is a poison value.
+ If the nsw keyword is present, and any of the truncated bits are not the
+ same as the top bit of the truncation result, the result is a poison value.
+
Example:
```mlir
+ // Scalar truncation.
%1 = arith.constant 21 : i5 // %1 is 0b10101
%2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
%3 = arith.trunci %1 : i5 to i3 // %3 is 0b101
- %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
+ // Vector truncation.
+ %4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
+
+ // Scalar truncation with overflow flags.
+ %5 = arith.trunci %a overflow<nsw, nuw> : i32 to i16
```
}];
+ let arguments = (ins
+ SignlessFixedWidthIntegerLike:$in,
+ DefaultValuedAttr<Arith_IntegerOverflowAttr,
+ "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags);
+ let results = (outs SignlessFixedWidthIntegerLike:$out);
+ let assemblyFormat = [{
+ $in (`overflow` `` $overflowFlags^)? attr-dict
+ `:` type($in) `to` type($out)
+ }];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ced18a48766bf..b8e5aa87244fa 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -163,7 +163,8 @@ using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
arith::AttrConverterConstrainedFPToLLVM>;
using TruncIOpLowering =
- VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
+ VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
+ arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a910bd4..b61612436eb78 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -378,14 +378,14 @@ def TruncationMatchesShiftAmount :
// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
def TruncIExtSIToExtSI :
- Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
+ Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x), $overflow),
(Arith_ExtSIOp $x),
[(ValueWiderThan $ext, $tr),
(ValueWiderThan $tr, $x)]>;
// trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated
def TruncIExtUIToExtUI :
- Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)),
+ Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x), $overflow),
(Arith_ExtUIOp $x),
[(ValueWiderThan $ext, $tr),
(ValueWiderThan $tr, $x)]>;
@@ -393,8 +393,8 @@ def TruncIExtUIToExtUI :
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
- (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
- (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
+ (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
+ (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
@@ -402,7 +402,7 @@ def TruncIShrUIMulIToMulSIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
- (ConstantLikeMatcher AnyAttr:$c0))),
+ (ConstantLikeMatcher AnyAttr:$c0)), $overflow),
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
@@ -413,7 +413,7 @@ def TruncIShrUIMulIToMulUIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
- (ConstantLikeMatcher AnyAttr:$c0))),
+ (ConstantLikeMatcher AnyAttr:$c0)), $overflow),
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index e0d974ea74041..83bdbe1f67118 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -731,6 +731,8 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
+ // CHECK: %{{.*}} = llvm.trunc %{{.*}} overflow<nsw, nuw> : i64 to i32
+ %4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
return
}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..1e656e84da836 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1159,5 +1159,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
+ // CHECK: %{{.*}} = arith.trunci %{{.*}} overflow<nsw, nuw> : i64 to i32
+ %4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/144863
More information about the Mlir-commits
mailing list