[Mlir-commits] [mlir] [mlir][LLVM] Add nsw and nuw flags to trunc (PR #115509)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 8 08:33:18 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: None (lfrenot)

<details>
<summary>Changes</summary>

This implementation is modeled on the existing overflow-flag (nsw/nuw) support for the binary operations.

If the nuw keyword is present, and any of the truncated bits are non-zero, the result is a poison value. If the nsw keyword is present, and any of the truncated bits are not the same as the top bit of the truncation result, the result is a poison value.

@zero9178, @gysit and @Dinistro, could you take a look?

---
Full diff: https://github.com/llvm/llvm-project/pull/115509.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+18-1) 
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+17) 
- (modified) mlir/test/Target/LLVMIR/Import/nsw_nuw.ll (+2) 
- (modified) mlir/test/Target/LLVMIR/nsw_nuw.mlir (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 315af2594047a5..ef81a068a36055 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -508,6 +508,23 @@ class LLVM_CastOp<string mnemonic, string instName, Type type,
       $_location, $_resultType, $arg);
   }];
 }
+class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
+                  Type resultType, list<Trait> traits = []> :
+    LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>,
+    LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> {
+  let arguments = (ins type:$arg, EnumProperty<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
+  let results = (outs resultType:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat = "$arg attr-dict `` custom<OverflowFlags>($overflowFlags) `:` type($arg) `to` type($res)";
+  string llvmInstName = instName;
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>(
+      $_location, $_resultType, $arg);
+    moduleImport.setIntegerOverflowFlags(inst, op);
+    $res = op;
+  }];
+}
+
 def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
     LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
   let hasFolder = 1;
@@ -537,7 +554,7 @@ def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
   let hasFolder = 1;
   let hasVerifier = 1;
 }
-def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
+def LLVM_TruncOp : LLVM_CastOpWithOverflowFlag<"trunc", "Trunc",
                                LLVM_ScalarOrVectorOf<AnySignlessInteger>,
                                LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
 def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 682780c5f0a7df..73776df3484273 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -325,6 +325,23 @@ func.func @casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: @casts_overflow
+// CHECK-SAME: (%[[I32:.*]]: i32, %[[I64:.*]]: i64, %[[V4I32:.*]]: vector<4xi32>, %[[V4I64:.*]]: vector<4xi64>, %[[PTR:.*]]: !llvm.ptr)
+func.func @casts_overflow(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
+            %arg3: vector<4xi64>, %arg4: !llvm.ptr) {
+// CHECK:  = llvm.trunc %[[I64]] overflow<nsw> : i64 to i56
+  %0 = llvm.trunc %arg1 overflow<nsw> : i64 to i56
+// CHECK:  = llvm.trunc %[[I64]] overflow<nuw> : i64 to i56
+  %1 = llvm.trunc %arg1 overflow<nuw> : i64 to i56
+// CHECK:  = llvm.trunc %[[I64]] overflow<nsw, nuw> : i64 to i56
+  %2 = llvm.trunc %arg1 overflow<nsw, nuw> : i64 to i56
+// CHECK:  = llvm.trunc %[[I64]] overflow<nsw, nuw> : i64 to i56
+  %3 = llvm.trunc %arg1 overflow<nuw, nsw> : i64 to i56
+// CHECK:  = llvm.trunc %[[V4I64]] overflow<nsw> : vector<4xi64> to vector<4xi56>
+  %4 = llvm.trunc %arg3 overflow<nsw> : vector<4xi64> to vector<4xi56>
+  llvm.return
+}
+
 // CHECK-LABEL: @vect
 func.func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32, %arg3: !llvm.vec<2 x ptr>) {
 // CHECK:  = llvm.extractelement {{.*}} : vector<4xf32>
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
index d08098a5e5dfe0..4af799da36dc08 100644
--- a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -10,5 +10,7 @@ define void @intflag_inst(i64 %arg1, i64 %arg2) {
   %3 = mul nsw nuw i64 %arg1, %arg2
   ; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
   %4 = shl nuw nsw i64 %arg1, %arg2
+  ; CHECK: llvm.trunc %{{.*}} overflow<nsw> : i64 to i32
+  %5 = trunc nsw i64 %arg1 to i32
   ret void
 }
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
index 6843c2ef0299c7..584aa05a04f7cf 100644
--- a/mlir/test/Target/LLVMIR/nsw_nuw.mlir
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -10,5 +10,7 @@ llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
   %2 = llvm.mul %arg0, %arg1 overflow <nsw, nuw> : i64
   // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
   %3 = llvm.shl %arg0, %arg1 overflow <nsw, nuw> : i64
+  // CHECK: %{{.*}} = trunc nuw i64 %{{.*}} to i32
+  %4 = llvm.trunc %arg1 overflow<nuw> : i64 to i32
   llvm.return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/115509


More information about the Mlir-commits mailing list