[Mlir-commits] [mlir] [mlir][arith] Add overflow flags to `arith.trunci` (PR #144863)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 03:00:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

LLVM has supported overflow flags on `llvm.trunc` for a while. This commit adds support for these flags to `arith.trunci`.


---
Full diff: https://github.com/llvm/llvm-project/pull/144863.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+31-7) 
- (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+8-7) 
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+2) 
- (modified) mlir/test/Dialect/Arith/ops.mlir (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index adc27ae6bdafb..993f36f556e87 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -226,7 +226,7 @@ def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
     these is required to be the same type. This type may be an integer scalar type, 
     a vector whose element type is integer, or a tensor of integers.
 
-    This op supports `nuw`/`nsw` overflow flags which stands stand for
+    This op supports `nuw`/`nsw` overflow flags which stand for
     "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
     `nsw` flags are present, and an unsigned/signed overflow occurs
     (respectively), the result is poison.
@@ -321,7 +321,7 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
     these is required to be the same type. This type may be an integer scalar type,
     a vector whose element type is integer, or a tensor of integers.
 
-    This op supports `nuw`/`nsw` overflow flags which stands stand for
+    This op supports `nuw`/`nsw` overflow flags which stand for
     "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
     `nsw` flags are present, and an unsigned/signed overflow occurs
     (respectively), the result is poison.
@@ -367,7 +367,7 @@ def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
     these is required to be the same type. This type may be an integer scalar type,
     a vector whose element type is integer, or a tensor of integers.
 
-    This op supports `nuw`/`nsw` overflow flags which stands stand for
+    This op supports `nuw`/`nsw` overflow flags which stand for
     "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
     `nsw` flags are present, and an unsigned/signed overflow occurs
     (respectively), the result is poison.
@@ -800,7 +800,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
     operand is greater or equal than the bitwidth of the first operand, then the
     operation returns poison.
 
-    This op supports `nuw`/`nsw` overflow flags which stands stand for
+    This op supports `nuw`/`nsw` overflow flags which stand for
     "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
     `nsw` flags are present, and an unsigned/signed overflow occurs
     (respectively), the result is poison.
@@ -1271,7 +1271,11 @@ def Arith_ScalingExtFOp
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
+def Arith_TruncIOp : Op<Arith_Dialect, "trunci",
+    [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+     DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]> {
   let summary = "integer truncation operation";
   let description = [{
     The integer truncation operation takes an integer input of
@@ -1279,17 +1283,37 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
     bit-width must be smaller than the input bit-width (N < M).
-    The top-most (N - M) bits of the input are discarded.
+    The top-most (M - N) bits of the input are discarded.
 
+    This op supports `nuw`/`nsw` overflow flags which stand for "No Unsigned
+    Wrap" and "No Signed Wrap", respectively. If the `nuw` flag is present,
+    and any of the truncated bits are non-zero, the result is a poison value.
+    If the `nsw` flag is present, and any of the truncated bits are not the
+    same as the top bit of the truncation result, the result is a poison value.
+
     Example:
 
     ```mlir
+      // Scalar truncation.
       %1 = arith.constant 21 : i5     // %1 is 0b10101
       %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
       %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101
 
-      %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
+      // Vector truncation.
+      %4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
+
+      // Scalar truncation with overflow flags.
+      %5 = arith.trunci %a overflow<nsw, nuw> : i32 to i16
     ```
   }];
 
+  let arguments = (ins
+      SignlessFixedWidthIntegerLike:$in,
+      DefaultValuedAttr<Arith_IntegerOverflowAttr,
+          "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags);
+  let results = (outs SignlessFixedWidthIntegerLike:$out);
+  let assemblyFormat = [{
+    $in (`overflow` `` $overflowFlags^)? attr-dict
+    `:` type($in) `to` type($out)
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ced18a48766bf..b8e5aa87244fa 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -163,7 +163,8 @@ using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
     arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
     arith::AttrConverterConstrainedFPToLLVM>;
 using TruncIOpLowering =
-    VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
+    VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using UIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a910bd4..b61612436eb78 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -378,14 +378,14 @@ def TruncationMatchesShiftAmount :
 
 // trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
 def TruncIExtSIToExtSI :
-    Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
+    Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x), $overflow),
         (Arith_ExtSIOp $x),
         [(ValueWiderThan $ext, $tr),
          (ValueWiderThan $tr, $x)]>;
 
 // trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated
 def TruncIExtUIToExtUI :
-    Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)),
+    Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x), $overflow),
         (Arith_ExtUIOp $x),
         [(ValueWiderThan $ext, $tr),
          (ValueWiderThan $tr, $x)]>;
@@ -393,8 +393,9 @@ def TruncIExtUIToExtUI :
-// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
+// trunci(shrsi(x, c)) -> trunci(shrui(x, c)). Overflow flags are dropped: the
+// bits truncated from shrui are zero, so `nsw` would be poison for negative x.
 def TruncIShrSIToTrunciShrUI :
     Pat<(Arith_TruncIOp:$tr
-          (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
-        (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
+          (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
+        (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), DefOverflow),
         [(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
 
 // trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
@@ -402,7 +403,7 @@ def TruncIShrUIMulIToMulSIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
                                 (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
-                              (ConstantLikeMatcher AnyAttr:$c0))),
+                              (ConstantLikeMatcher AnyAttr:$c0)), $overflow),
         (Arith_MulSIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
        (ValueWiderThan $mul, $x),
@@ -413,7 +414,7 @@ def TruncIShrUIMulIToMulUIExtended :
     Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
                               (Arith_MulIOp:$mul
                                 (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
-                              (ConstantLikeMatcher AnyAttr:$c0))),
+                              (ConstantLikeMatcher AnyAttr:$c0)), $overflow),
         (Arith_MulUIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
        (ValueWiderThan $mul, $x),
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index e0d974ea74041..83bdbe1f67118 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -731,6 +731,8 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
   // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} overflow<nsw, nuw> : i64 to i32
+  %4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
   return
 }
 
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..1e656e84da836 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1159,5 +1159,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
   %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
   // CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = arith.trunci %{{.*}} overflow<nsw, nuw> : i64 to i32
+  %4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/144863


More information about the Mlir-commits mailing list