[Mlir-commits] [mlir] [mlir][arith] Add overflow flags support to arith ops (PR #77211)
Ivan Butygin
llvmlistbot at llvm.org
Mon Jan 8 14:58:41 PST 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/77211
>From 0c1f04dc1ae0667563171147a89859d1231c5b04 Mon Sep 17 00:00:00 2001
From: Yi Wu <yi.wu2 at arm.com>
Date: Fri, 5 Jan 2024 13:57:51 +0000
Subject: [PATCH 1/5] [mlir][arith] Add overflow flags support to arith ops
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`
TableGen canonicalization patterns are updated to always drop flags; proper support, with tests, will be added later.
The LLVM IR translation is updated as part of this commit because it is currently written in a way that would otherwise crash when new attributes are added to arith ops.
---
.../ArithCommon/AttrToLLVMConverter.h | 35 ++++++++
.../mlir/Dialect/Arith/IR/ArithBase.td | 23 ++++++
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 22 ++++-
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 57 +++++++++++++
.../ArithCommon/AttrToLLVMConverter.cpp | 29 ++++++-
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 12 ++-
.../Dialect/Arith/IR/ArithCanonicalization.td | 80 ++++++++++---------
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5 ++
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 13 +++
mlir/test/Dialect/Arith/ops.mlir | 11 +++
10 files changed, 239 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index eea16b4da6a690..dbd0726fe16d1a 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -26,6 +26,14 @@ convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
+// Map arithmetic overflow enum values to LLVMIR enum values.
+LLVM::IntegerOverflowFlags
+convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
+
+// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
+LLVM::IntegerOverflowFlagsAttr
+convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
+
// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
@@ -49,6 +57,33 @@ class AttrConvertFastMathToLLVM {
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+private:
+ NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the overflow
+// attribute from the source operation attributes, and replacing it with an
+// equivalent LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertOverflowToLLVM {
+public:
+ AttrConvertOverflowToLLVM(SourceOp srcOp) {
+ // Copy the source attributes.
+ convertedAttr = NamedAttrList{srcOp->getAttrs()};
+ // Get the name of the arith fastmath attribute.
+ llvm::StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
+ // Remove the source fastmath attribute.
+ auto arithAttr = dyn_cast_or_null<arith::IntegerOverflowFlagsAttr>(
+ convertedAttr.erase(arithAttrName));
+ if (arithAttr) {
+ llvm::StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
+ convertedAttr.set(targetAttrName,
+ convertArithOveflowAttrToLLVM(arithAttr));
+ }
+ }
+
+ ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
private:
NamedAttrList convertedAttr;
};
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 1e4061392b22d4..3fb7f948b0a45a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
let assemblyFormat = "`<` $value `>`";
}
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+//===----------------------------------------------------------------------===//
+
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+ "IntegerOverflowFlags",
+ "Integer overflow arith flags",
+ [IOFnone, IOFnsw, IOFnuw]> {
+ let separator = ", ";
+ let cppNamespace = "::mlir::arith";
+ let genSpecializedAttr = 0;
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def Arith_IntegerOverflowAttr :
+ EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6d133d69dd0faf..880718bca9e7ec 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -137,6 +137,22 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
let results = (outs BoolLikeOfAnyRank:$result);
}
+class Arith_IntArithmeticOpWithOverflowFlag<string mnemonic, list<Trait> traits = []> :
+ Arith_BinaryOp<mnemonic, traits #
+ [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
+ Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+ DefaultValuedAttr<
+ Arith_IntegerOverflowAttr, "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
+ Results<(outs SignlessIntegerLike:$result)> {
+
+ let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
+ attr-dict `:` type($result) }];
+
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -192,7 +208,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// AddIOp
//===----------------------------------------------------------------------===//
-def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
Performs N-bit addition on the operands. The operands are interpreted as
@@ -278,7 +294,7 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
// SubIOp
//===----------------------------------------------------------------------===//
-def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
+def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
let summary = [{
Integer subtraction operation.
}];
@@ -302,7 +318,7 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
// MulIOp
//===----------------------------------------------------------------------===//
-def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]> {
let summary = [{
Integer multiplication operation.
}];
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index acaecf6f409dcf..e248422f84db84 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
];
}
+def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
+ let description = [{
+ Access to op integer overflow flags.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+ /*returnType=*/ "IntegerOverflowFlagsAttr",
+ /*methodName=*/ "getOverflowAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getOverflowFlagsAttr();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNoUnsignedWrap",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNoSignedWrap",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getIntegerOverflowAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "overflowFlags";
+ }]
+ >
+ ];
+}
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index 8c5d76f9f2d72e..7ba12de122bb4d 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -22,9 +22,9 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
- for (auto fmfMap : flags) {
- if (bitEnumContainsAny(arithFMF, fmfMap.first))
- llvmFMF = llvmFMF | fmfMap.second;
+ for (auto [arithFlag, llvmFlag] : flags) {
+ if (bitEnumContainsAny(arithFMF, arithFlag))
+ llvmFMF = llvmFMF | llvmFlag;
}
return llvmFMF;
}
@@ -36,3 +36,26 @@ mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
return LLVM::FastmathFlagsAttr::get(
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
}
+
+// Map arithmetic overflow enum values to LLVMIR enum values.
+LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
+ arith::IntegerOverflowFlags arithFlags) {
+ LLVM::IntegerOverflowFlags llvmFlags{};
+ const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
+ flags[] = {
+ {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
+ {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
+ for (auto [arithFlag, llvmFlag] : flags) {
+ if (bitEnumContainsAny(arithFlags, arithFlag))
+ llvmFlags = llvmFlags | llvmFlag;
+ }
+ return llvmFlags;
+}
+
+// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
+LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
+ arith::IntegerOverflowFlagsAttr flagsAttr) {
+ arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
+ return LLVM::IntegerOverflowFlagsAttr::get(
+ flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
+}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5e4213cc4e874a..cf46e0d3ac46ac 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -35,7 +35,9 @@ namespace {
using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
arith::AttrConvertFastMathToLLVM>;
-using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
+using AddIOpLowering =
+ VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
+ arith::AttrConvertOverflowToLLVM>;
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
@@ -78,7 +80,9 @@ using MinUIOpLowering =
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
arith::AttrConvertFastMathToLLVM>;
-using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
+using MulIOpLowering =
+ VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
+ arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
arith::AttrConvertFastMathToLLVM>;
@@ -102,7 +106,9 @@ using SIToFPOpLowering =
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
arith::AttrConvertFastMathToLLVM>;
-using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
+using SubIOpLowering =
+ VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
+ arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
using TruncIOpLowering =
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ef951647ccd146..19f0c0aac31713 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,8 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
//===----------------------------------------------------------------------===//
@@ -36,23 +38,23 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
// addi(addi(x, c0), c1) -> addi(x, c0 + c1)
def AddIAddConstant :
Pat<(Arith_AddIOp:$res
- (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+ (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), (DefOverflow))>;
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
Pat<(Arith_AddIOp:$res
- (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+ (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), (DefOverflow))>;
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
Pat<(Arith_AddIOp:$res
- (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+ (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
def IsScalarOrSplatNegativeOne :
Constraint<And<[
@@ -63,24 +65,24 @@ def IsScalarOrSplatNegativeOne :
def AddIMulNegativeOneRhs :
Pat<(Arith_AddIOp
$x,
- (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
- (Arith_SubIOp $x, $y),
+ (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
+ (Arith_SubIOp $x, $y, (DefOverflow)),
[(IsScalarOrSplatNegativeOne $c0)]>;
// addi(muli(x, -1), y) -> subi(y, x)
def AddIMulNegativeOneLhs :
Pat<(Arith_AddIOp
- (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
- $y),
- (Arith_SubIOp $y, $x),
+ (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
+ $y, $ovf2),
+ (Arith_SubIOp $y, $x, (DefOverflow)),
[(IsScalarOrSplatNegativeOne $c0)]>;
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
def MulIMulIConstant :
Pat<(Arith_MulIOp:$res
- (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+ (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)), (DefOverflow))>;
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
@@ -90,7 +92,7 @@ def MulIMulIConstant :
// uses. Since the 'overflow' result is unused, any replacement value will do.
def AddUIExtendedToAddI:
Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
- [(Arith_AddIOp $x, $y), (replaceWithValue $x)],
+ [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
//===----------------------------------------------------------------------===//
@@ -100,49 +102,49 @@ def AddUIExtendedToAddI:
// subi(addi(x, c0), c1) -> addi(x, c0 - c1)
def SubIRHSAddConstant :
Pat<(Arith_SubIOp:$res
- (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
+ (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), (DefOverflow))>;
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
- (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
- (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
+ (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+ (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x, (DefOverflow))>;
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
- (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+ (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), (DefOverflow))>;
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
- (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
- (ConstantLikeMatcher APIntAttr:$c1)),
- (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
+ (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+ (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+ (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
- (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
- (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+ (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+ (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
- (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x)),
- (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+ (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
+ (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), (DefOverflow))>;
// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
- Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
- (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+ Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
+ (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
//===----------------------------------------------------------------------===//
// MulSIExtendedOp
@@ -152,7 +154,7 @@ def SubISubILHSRHSLHS :
// Since the `high` result it not used, any replacement value will do.
def MulSIExtendedToMulI :
Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
- [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+ [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
@@ -179,7 +181,7 @@ def MulSIExtendedRHSOne :
// Since the `high` result it not used, any replacement value will do.
def MulUIExtendedToMulI :
Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
- [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+ [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
//===----------------------------------------------------------------------===//
@@ -403,7 +405,7 @@ def TruncIShrSIToTrunciShrUI :
def TruncIShrUIMulIToMulSIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
- (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
+ (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
(ConstantLikeMatcher AnyAttr:$c0))),
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
@@ -414,7 +416,7 @@ def TruncIShrUIMulIToMulSIExtended :
def TruncIShrUIMulIToMulUIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
- (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
+ (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
(ConstantLikeMatcher AnyAttr:$c0))),
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa77..2d124ce4980fa4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,6 +61,11 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
+ return IntegerOverflowFlagsAttr::get(builder.getContext(),
+ IntegerOverflowFlags::none);
+}
+
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index e16dbb5661058c..8937b24e0d174d 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -575,3 +575,16 @@ func.func @ops_supporting_fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
%7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
return
}
+
+// -----
+
+// CHECK-LABEL: @ops_supporting_overflow
+func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
+ %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+ // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
+ %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+ // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+ %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+ return
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 6e10e540d1d178..8ae3273f32c6b0 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1138,3 +1138,14 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}
+
+// CHECK-LABEL: @intflags_func
+func.func @intflags_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} overflow<nsw> : i64
+ %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+ // CHECK: %{{.*}} = arith.subi %{{.*}}, %{{.*}} overflow<nuw> : i64
+ %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+ // CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+ %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+ return
+}
>From 82eeda06eaf5bfdf3d5efc314d7faf102619291e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 6 Jan 2024 23:53:14 +0100
Subject: [PATCH 2/5] fix python test
---
mlir/test/python/ir/diagnostic_handler.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index 2f4300d2c55dfa..d516cda8198976 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -113,7 +113,7 @@ def testDiagnosticNonEmptyNotes():
def callback(d):
# CHECK: DIAGNOSTIC:
# CHECK: message='arith.addi' op requires one result
- # CHECK: notes=['see current operation: "arith.addi"() : () -> ()']
+ # CHECK: notes=['see current operation: "arith.addi"() {{.*}} : () -> ()']
print(f"DIAGNOSTIC:")
print(f" message={d.message}")
print(f" notes={list(map(str, d.notes))}")
>From 3d09907fe711ece18cf2b6535ec4f0f429ac458a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 7 Jan 2024 23:20:06 +0100
Subject: [PATCH 3/5] Update docs and comments
---
.../ArithCommon/AttrToLLVMConverter.h | 6 +-
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 79 +++++++++++++++----
2 files changed, 66 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index dbd0726fe16d1a..5dff716724a475 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -63,16 +63,16 @@ class AttrConvertFastMathToLLVM {
// Attribute converter that populates a NamedAttrList by removing the overflow
// attribute from the source operation attributes, and replacing it with an
-// equivalent LLVM fastmath attribute.
+// equivalent LLVM overflow attribute.
template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
AttrConvertOverflowToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
- // Get the name of the arith fastmath attribute.
+ // Get the name of the arith overflow attribute.
llvm::StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
- // Remove the source fastmath attribute.
+ // Remove the source overflow attribute.
auto arithAttr = dyn_cast_or_null<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName));
if (arithAttr) {
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 880718bca9e7ec..0f340fbf748f10 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -219,8 +219,12 @@ def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]>
The `addi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
- a vector whose element type is integer, or a tensor of integers. It has no
- standard attributes.
+ a vector whose element type is integer, or a tensor of integers.
+
+ This op supports nuw/nsw flags which stands stand for "No Unsigned Wrap" and
+ "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
+ result value is undefined if unsigned and/or signed overflow, respectively,
+ occurs.
Example:
@@ -228,6 +232,9 @@ def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]>
// Scalar addition.
%a = arith.addi %b, %c : i64
+ // Scalar addition with overflow flags.
+ %a = arith.addi %b, %c overflow<nsw, nuw> : i64
+
// SIMD vector element-wise addition, e.g. for Intel SSE.
%f = arith.addi %g, %h : vector<4xi32>
@@ -299,16 +306,36 @@ def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
Integer subtraction operation.
}];
let description = [{
- Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
- bitvectors. The result is represented by a bitvector containing the mathematical
- value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
- integers use a two's complement representation, this operation is applicable on
+ Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
+ bitvectors. The result is represented by a bitvector containing the mathematical
+ value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
+ integers use a two's complement representation, this operation is applicable on
both signed and unsigned integer operands.
The `subi` operation takes two operands and returns one result, each of
- these is required to be the same type. This type may be an integer scalar type,
- a vector whose element type is integer, or a tensor of integers. It has no
- standard attributes.
+ these is required to be the same type. This type may be an integer scalar type,
+ a vector whose element type is integer, or a tensor of integers.
+
+ This op supports nuw/nsw flags which stands stand for "No Unsigned Wrap" and
+ "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
+ result value is undefined if unsigned and/or signed overflow, respectively,
+ occurs.
+
+ Example:
+
+ ```mlir
+ // Scalar subtraction.
+ %a = arith.subi %b, %c : i64
+
+ // Scalar subtraction with overflow flags.
+ %a = arith.subi %b, %c overflow<nsw, nuw> : i64
+
+ // SIMD vector element-wise subtraction, e.g. for Intel SSE.
+ %f = arith.subi %g, %h : vector<4xi32>
+
+ // Tensor element-wise subtraction.
+ %x = arith.subi %y, %z : tensor<4x?xi8>
+ ```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
@@ -323,16 +350,36 @@ def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]>
Integer multiplication operation.
}];
let description = [{
- Performs N-bit multiplication on the operands. The operands are interpreted as
- unsigned bitvectors. The result is represented by a bitvector containing the
- mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
- Because `arith` integers use a two's complement representation, this operation is
+ Performs N-bit multiplication on the operands. The operands are interpreted as
+ unsigned bitvectors. The result is represented by a bitvector containing the
+ mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
+ Because `arith` integers use a two's complement representation, this operation is
applicable on both signed and unsigned integer operands.
The `muli` operation takes two operands and returns one result, each of
- these is required to be the same type. This type may be an integer scalar type,
- a vector whose element type is integer, or a tensor of integers. It has no
- standard attributes.
+ these is required to be the same type. This type may be an integer scalar type,
+ a vector whose element type is integer, or a tensor of integers.
+
+ This op supports nuw/nsw flags which stands stand for "No Unsigned Wrap" and
+ "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
+ result value is undefined if unsigned and/or signed overflow, respectively,
+ occurs.
+
+ Example:
+
+ ```mlir
+ // Scalar multiplication.
+ %a = arith.muli %b, %c : i64
+
+ // Scalar multiplication with overflow flags.
+ %a = arith.muli %b, %c overflow<nsw, nuw> : i64
+
+ // SIMD vector element-wise multiplication, e.g. for Intel SSE.
+ %f = arith.muli %g, %h : vector<4xi32>
+
+ // Tensor element-wise multiplication.
+ %x = arith.muli %y, %z : tensor<4x?xi8>
+ ```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
>From 4c6f3ac77e2d32fcec63775365565c194045144e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 8 Jan 2024 14:28:50 +0100
Subject: [PATCH 4/5] Update docs, Arith_IntBinaryOpWithOverflowFlags, cleanup
hasFolder/hasCanonicalizer
---
.../ArithCommon/AttrToLLVMConverter.h | 22 +++++-----
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 41 +++++++++----------
.../ArithCommon/AttrToLLVMConverter.cpp | 4 --
3 files changed, 31 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 5dff716724a475..0296ec969d0b32 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -18,19 +18,21 @@
namespace mlir {
namespace arith {
-// Map arithmetic fastmath enum values to LLVMIR enum values.
+/// Maps arithmetic fastmath enum values to LLVM enum values.
LLVM::FastmathFlags
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
+/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
+/// attribute.
LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
-// Map arithmetic overflow enum values to LLVMIR enum values.
+/// Maps arithmetic overflow enum values to LLVM enum values.
LLVM::IntegerOverflowFlags
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
-// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
+/// Creates an LLVM overflow attribute from a given arithmetic overflow
+/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
@@ -44,12 +46,12 @@ class AttrConvertFastMathToLLVM {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
- llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+ StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
// Remove the source fastmath attribute.
- auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
+ auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
convertedAttr.erase(arithFMFAttrName));
if (arithFMFAttr) {
- llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
+ StringRef targetAttrName = TargetOp::getFastmathAttrName();
convertedAttr.set(targetAttrName,
convertArithFastMathAttrToLLVM(arithFMFAttr));
}
@@ -71,12 +73,12 @@ class AttrConvertOverflowToLLVM {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
- llvm::StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
+ StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
- auto arithAttr = dyn_cast_or_null<arith::IntegerOverflowFlagsAttr>(
+ auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName));
if (arithAttr) {
- llvm::StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
+ StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
convertedAttr.set(targetAttrName,
convertArithOveflowAttrToLLVM(arithAttr));
}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0f340fbf748f10..90b4d5475dba2c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -137,7 +137,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
let results = (outs BoolLikeOfAnyRank:$result);
}
-class Arith_IntArithmeticOpWithOverflowFlag<string mnemonic, list<Trait> traits = []> :
+class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
@@ -148,9 +148,6 @@ class Arith_IntArithmeticOpWithOverflowFlag<string mnemonic, list<Trait> traits
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
-
- let hasFolder = 1;
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -208,7 +205,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// AddIOp
//===----------------------------------------------------------------------===//
-def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
Performs N-bit addition on the operands. The operands are interpreted as
@@ -221,10 +218,10 @@ def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]>
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
- This op supports nuw/nsw flags which stands stand for "No Unsigned Wrap" and
- "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
- result value is undefined if unsigned and/or signed overflow, respectively,
- occurs.
+ This op supports `nuw`/`nsw` overflow flags which stand for
+ "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+ `nsw` flags are present, and an unsigned/signed overflow occurs
+ (respectively), the result is poison.
Example:
@@ -235,7 +232,7 @@ def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]>
// Scalar addition with overflow flags.
%a = arith.addi %b, %c overflow<nsw, nuw> : i64
- // SIMD vector element-wise addition, e.g. for Intel SSE.
+ // SIMD vector element-wise addition.
%f = arith.addi %g, %h : vector<4xi32>
// Tensor element-wise addition.
@@ -301,7 +298,7 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
// SubIOp
//===----------------------------------------------------------------------===//
-def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
+def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
let summary = [{
Integer subtraction operation.
}];
@@ -316,10 +313,10 @@ def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
- This op supports nuw/nsw flags which stands stand for "No Unsigned Wrap" and
- "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
- result value is undefined if unsigned and/or signed overflow, respectively,
- occurs.
+ This op supports `nuw`/`nsw` overflow flags which stand for
+ "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+ `nsw` flags are present, and an unsigned/signed overflow occurs
+ (respectively), the result is poison.
Example:
@@ -330,7 +327,7 @@ def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
// Scalar subtraction with overflow flags.
%a = arith.subi %b, %c overflow<nsw, nuw> : i64
- // SIMD vector element-wise subtraction, e.g. for Intel SSE.
+ // SIMD vector element-wise subtraction.
%f = arith.subi %g, %h : vector<4xi32>
// Tensor element-wise subtraction.
@@ -345,7 +342,7 @@ def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
// MulIOp
//===----------------------------------------------------------------------===//
-def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
let summary = [{
Integer multiplication operation.
}];
@@ -360,10 +357,10 @@ def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]>
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
- This op supports nuw/nsw flags which stands stand for "No Unsigned Wrap" and
- "No Signed Wrap", respectively. If the nuw and/or nsw flags are present, the
- result value is undefined if unsigned and/or signed overflow, respectively,
- occurs.
+ This op supports `nuw`/`nsw` overflow flags which stand for
+ "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+ `nsw` flags are present, and an unsigned/signed overflow occurs
+ (respectively), the result is poison.
Example:
@@ -374,7 +371,7 @@ def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]>
// Scalar multiplication with overflow flags.
%a = arith.muli %b, %c overflow<nsw, nuw> : i64
- // SIMD vector element-wise multiplication, e.g. for Intel SSE.
+ // SIMD vector element-wise multiplication.
%f = arith.muli %g, %h : vector<4xi32>
// Tensor element-wise multiplication.
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index 7ba12de122bb4d..3e9aef87b9ef1f 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -10,7 +10,6 @@
using namespace mlir;
-// Map arithmetic fastmath enum values to LLVMIR enum values.
LLVM::FastmathFlags
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
LLVM::FastmathFlags llvmFMF{};
@@ -29,7 +28,6 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
return llvmFMF;
}
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
LLVM::FastmathFlagsAttr
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
arith::FastMathFlags arithFMF = fmfAttr.getValue();
@@ -37,7 +35,6 @@ mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
}
-// Map arithmetic overflow enum values to LLVMIR enum values.
LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
arith::IntegerOverflowFlags arithFlags) {
LLVM::IntegerOverflowFlags llvmFlags{};
@@ -52,7 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
return llvmFlags;
}
-// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
arith::IntegerOverflowFlagsAttr flagsAttr) {
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
>From c36ae9dc627e2446da4fdffca9beadd135459cbf Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 8 Jan 2024 23:57:19 +0100
Subject: [PATCH 5/5] Canonicalizations comment
---
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 19f0c0aac31713..27a00dd0dc5d48 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,9 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+// TODO: Canonicalizations currently don't take integer overflow flags into
+// account and always reset them to default (wraparound), which is safe but can
+// inhibit later optimizations.
def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
More information about the Mlir-commits
mailing list