[Mlir-commits] [mlir] Add 'exact' flag to arith.shrui/shrsi/divsi/divui operations (PR #165923)
Jeremy Furtek
llvmlistbot at llvm.org
Fri Oct 31 14:21:08 PDT 2025
https://github.com/jfurtek created https://github.com/llvm/llvm-project/pull/165923
This PR adds support for the `exact` flag to the `arith.shrui/shrsi/divsi/divui` operations. The semantics are identical to those of the `exact` flag in the LLVM dialect and in the LLVM language reference.
This PR also modifies the mechanism for converting `arith` dialect **attributes** to corresponding **properties** in the `LLVM` dialect. (As a specific example, the integer overflow flags `nsw/nuw` are **properties** in the `LLVM` dialect, as opposed to attributes.)
Previously, attribute converter classes were required to have a specific method to support integer overflow flags:
```C++
template <typename SourceOp, typename TargetOp>
class AttrConvertPassThrough {
public:
...
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
};
```
This method was required even for converters of `arith` source operations that do not use integer overflow flags (e.g. `AttrConvertFastMathToLLVM`).
This PR changes the interface required of `arith` dialect attribute converters so that they instead provide a (possibly null) properties attribute:
```C++
template <typename SourceOp, typename TargetOp>
class AttrConvertPassThrough {
public:
...
Attribute getPropAttr() const { return {}; }
};
```
For `arith` operations with attributes that map to `LLVM` dialect **properties**, the attribute converter can create a `DictionaryAttr` containing the target properties and return that attribute from its `getPropAttr()` method. The `arith` attribute conversion framework then sets the `propertiesAttr` of an `OperationState`, and the target operation's `setPropertiesFromAttr()` method is invoked to set the properties when the target operation is created. The `AttrConvertOverflowToLLVM` class in this PR uses the new approach.
>From 810fcf73b8f15d73919ce12041ede071624d47c6 Mon Sep 17 00:00:00 2001
From: Jeremy Furtek <jfurtek at nvidia.com>
Date: Fri, 31 Oct 2025 15:22:25 -0500
Subject: [PATCH] Add 'exact' flag to arith.shrui/shrsi/divsi/divui operations
---
.../ArithCommon/AttrToLLVMConverter.h | 33 +++++++----
.../mlir/Conversion/LLVMCommon/Pattern.h | 20 +++----
.../Conversion/LLVMCommon/VectorPattern.h | 34 +++++++-----
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 55 +++++++++++++++----
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 34 ++++++++++++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 1 +
.../ComplexToLLVM/ComplexToLLVM.cpp | 3 +-
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 21 +++----
.../Conversion/LLVMCommon/VectorPattern.cpp | 21 ++++---
.../Dialect/Arith/IR/ArithCanonicalization.td | 4 +-
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 16 ++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 13 +++++
mlir/test/Dialect/Arith/ops.mlir | 24 ++++++++
13 files changed, 203 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 7ffc861331760..3f6215458f90c 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -65,11 +65,8 @@ class AttrConvertFastMathToLLVM {
convertArithFastMathAttrToLLVM(arithFMFAttr));
}
}
-
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
- LLVM::IntegerOverflowFlags getOverflowFlags() const {
- return LLVM::IntegerOverflowFlags::none;
- }
+ Attribute getPropAttr() const { return {}; }
private:
NamedAttrList convertedAttr;
@@ -82,23 +79,37 @@ template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
AttrConvertOverflowToLLVM(SourceOp srcOp) {
+ using IntegerOverflowFlagsAttr = LLVM::IntegerOverflowFlagsAttr;
+
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
- // Remove the source overflow attribute.
+ // Remove the source overflow attribute from the set that will be present
+ // in the target.
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName))) {
- overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
+ auto llvmFlag = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
+ // Create a dictionary attribute holding the overflow flags property.
+ // (In the LLVM dialect, the overflow flags are a property, not an
+ // attribute.)
+ MLIRContext *ctx = srcOp.getOperation()->getContext();
+ Builder b(ctx);
+ auto llvmFlagAttr = IntegerOverflowFlagsAttr::get(ctx, llvmFlag);
+ StringRef llvmAttrName = TargetOp::getOverflowFlagsAttrName();
+ SmallVector<NamedAttribute> attrs;
+ attrs.push_back(b.getNamedAttr(llvmAttrName, llvmFlagAttr));
+ // Set the properties attribute of the operation state so that the
+ // property can be updated when the operation is created.
+ propertiesAttr = b.getDictionaryAttr(attrs);
}
}
-
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
- LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
+ Attribute getPropAttr() const { return propertiesAttr; }
private:
NamedAttrList convertedAttr;
- LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
+ DictionaryAttr propertiesAttr;
};
template <typename SourceOp, typename TargetOp>
@@ -129,9 +140,7 @@ class AttrConverterConstrainedFPToLLVM {
}
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
- LLVM::IntegerOverflowFlags getOverflowFlags() const {
- return LLVM::IntegerOverflowFlags::none;
- }
+ Attribute getPropAttr() const { return {}; }
private:
NamedAttrList convertedAttr;
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index c292e3727f46c..f8e0ccc093f8b 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -19,16 +19,14 @@ class CallOpInterface;
namespace LLVM {
namespace detail {
-/// Handle generically setting flags as native properties on LLVM operations.
-void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
-
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
-LogicalResult oneToOneRewrite(
- Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
+ ValueRange operands,
+ ArrayRef<NamedAttribute> targetAttrs,
+ Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
/// specified name "intrinsic" and operands.
@@ -307,9 +305,9 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
- adaptor.getOperands(), op->getAttrs(),
- *this->getTypeConverter(), rewriter);
+ return LLVM::detail::oneToOneRewrite(
+ op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(),
+ /*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter);
}
};
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..5da239ee23066 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -54,25 +54,26 @@ LogicalResult handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter);
-LogicalResult vectorOneToOneRewrite(
- Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
+ ValueRange operands,
+ ArrayRef<NamedAttribute> targetAttrs,
+ Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
// Default attribute conversion class, which passes all source attributes
-// through to the target op, unmodified.
+// through to the target op, unmodified. getPropAttr() returns a null attribute,
+// so target properties are only populated from passed-through attributes with
+// matching inherent names (e.g. `isExact`); all others keep default values.
template <typename SourceOp, typename TargetOp>
class AttrConvertPassThrough {
public:
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
- LLVM::IntegerOverflowFlags getOverflowFlags() const {
- return LLVM::IntegerOverflowFlags::none;
- }
+ Attribute getPropAttr() const { return {}; }
private:
ArrayRef<NamedAttribute> srcAttrs;
@@ -80,10 +81,13 @@ class AttrConvertPassThrough {
/// Basic lowering implementation to rewrite Ops with just one result to the
/// LLVM Dialect. This supports higher-dimensional vector types.
-/// The AttrConvert template template parameter should be a template class
-/// with SourceOp and TargetOp type parameters, a constructor that takes
-/// a SourceOp instance, and a getAttrs() method that returns
-/// ArrayRef<NamedAttribute>.
+/// The AttrConvert template template parameter should:
+/// - be a template class with SourceOp and TargetOp type parameters
+/// - have a constructor that takes a SourceOp instance
+/// - have a getAttrs() method that returns ArrayRef<NamedAttribute> containing
+///   the attributes that the target operation will have
+/// - have a getPropAttr() method that returns either a null Attribute or a
+///   DictionaryAttr holding the properties to set on the target operation
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
@@ -103,8 +107,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(),
- attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
- attrConvert.getOverflowFlags());
+ attrConvert.getAttrs(), attrConvert.getPropAttr(),
+ *this->getTypeConverter(), rewriter);
}
};
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index a38cf41a3e09b..17b8486077697 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -158,6 +158,19 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
attr-dict `:` type($result) }];
}
+class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
+ Arith_BinaryOp<mnemonic, traits #
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ArithExactFlagInterface>]>,
+ Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
+ SignlessIntegerOrIndexLike:$rhs,
+ UnitAttr:$isExact)>,
+ Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ // `isExact` is optional; when set it is printed/parsed as the `exact` keyword.
+ let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
+ attr-dict `:` type($result) }];
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -482,7 +495,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
// DivUIOp
//===----------------------------------------------------------------------===//
-def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
+def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui",
+ [ConditionallySpeculatable]> {
let summary = "unsigned integer division operation";
let description = [{
Unsigned integer division. Rounds towards zero. Treats the leading bit as
@@ -493,12 +507,18 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
`tensor` values, the behavior is undefined if _any_ elements are divided by
zero.
+ If the `exact` attribute is present, the result value is poison if `lhs` is
+ not a multiple of `rhs`.
+
Example:
```mlir
// Scalar unsigned integer division.
%a = arith.divui %b, %c : i64
+ // Scalar unsigned integer division where %b is known to be a multiple of %c.
+ %a = arith.divui %b, %c exact : i64
+
// SIMD vector element-wise division.
%f = arith.divui %g, %h : vector<4xi32>
@@ -519,7 +539,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
// DivSIOp
//===----------------------------------------------------------------------===//
-def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
+def Arith_DivSIOp : Arith_IntBinaryOpWithExactFlag<"divsi",
+ [ConditionallySpeculatable]> {
let summary = "signed integer division operation";
let description = [{
Signed integer division. Rounds towards zero. Treats the leading bit as
@@ -530,12 +551,18 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
behavior is undefined if _any_ of its elements are divided by zero or has a
signed division overflow.
+ If the `exact` attribute is present, the result value is poison if `lhs` is
+ not a multiple of `rhs`.
+
Example:
```mlir
// Scalar signed integer division.
%a = arith.divsi %b, %c : i64
+ // Scalar signed integer division where %b is known to be a multiple of %c.
+ %a = arith.divsi %b, %c exact : i64
+
// SIMD vector element-wise division.
%f = arith.divsi %g, %h : vector<4xi32>
@@ -821,7 +848,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
// ShRUIOp
//===----------------------------------------------------------------------===//
-def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
+def Arith_ShRUIOp : Arith_IntBinaryOpWithExactFlag<"shrui", [Pure]> {
let summary = "unsigned integer right-shift";
let description = [{
The `shrui` operation shifts an integer value of the first operand to the right
@@ -830,12 +857,17 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
filled with zeros. If the value of the second operand is greater or equal than the
bitwidth of the first operand, then the operation returns poison.
+ If the `exact` keyword is present, the result value of shrui is a poison
+ value if any of the bits shifted out are non-zero.
+
Example:
```mlir
- %1 = arith.constant 160 : i8 // %1 is 0b10100000
+ %1 = arith.constant 160 : i8 // %1 is 0b10100000
%2 = arith.constant 3 : i8
- %3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
+ %3 = arith.constant 6 : i8
+ %4 = arith.shrui %1, %2 exact : i8 // %4 is 0b00010100
+ %5 = arith.shrui %1, %3 : i8 // %5 is 0b00000010
```
}];
let hasFolder = 1;
@@ -845,7 +877,7 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
// ShRSIOp
//===----------------------------------------------------------------------===//
-def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
+def Arith_ShRSIOp : Arith_IntBinaryOpWithExactFlag<"shrsi", [Pure]> {
let summary = "signed integer right-shift";
let description = [{
The `shrsi` operation shifts an integer value of the first operand to the right
@@ -856,14 +888,17 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
operand is greater or equal than bitwidth of the first operand, then the operation
returns poison.
+ If the `exact` keyword is present, the result value of shrsi is a poison
+ value if any of the bits shifted out are non-zero.
+
Example:
```mlir
- %1 = arith.constant 160 : i8 // %1 is 0b10100000
+ %1 = arith.constant 160 : i8 // %1 is 0b10100000
%2 = arith.constant 3 : i8
- %3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
- %4 = arith.constant 96 : i8 // %4 is 0b01100000
- %5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
+ %3 = arith.shrsi %1, %2 exact : i8 // %3 is 0b11110100
+ %4 = arith.constant 98 : i8 // %4 is 0b01100010
+ %5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100
```
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03d..bd6bd5f2d00d3 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -139,4 +139,38 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
];
}
+def ArithExactFlagInterface : OpInterface<"ArithExactFlagInterface"> {
+ let description = [{
+ Access to the `exact` flag of an operation, stored in the attribute named `isExact`.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the exact flag",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasExactFlag",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getIsExact();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the 'exact' attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getExactAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "isExact";
+ }]
+ >
+ // NOTE: both default implementations assume the op's attribute is named `isExact`.
+ ];
+}
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ba57155ab9b45..dc548e0798a84 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -259,6 +259,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
adaptor.getOperands(), op->getAttrs(),
+ /*propAttr=*/Attribute{},
*getTypeConverter(), rewriter);
}
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 86d02e6c6209f..6a0c21185983e 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
- op->getAttrs(), *getTypeConverter(), rewriter);
+ op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
+ rewriter);
}
};
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 48a03198fd465..f28a6ccb42455 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods
//===----------------------------------------------------------------------===//
-void LLVM::detail::setNativeProperties(Operation *op,
- IntegerOverflowFlags overflowFlags) {
- if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
- iface.setOverflowFlags(overflowFlags);
-}
-
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags) {
+ ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
SmallVector<Type> resultTypes;
@@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite(
}
// Create the operation through state since we don't know its C++ type.
- Operation *newOp =
- rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
- resultTypes, targetAttrs);
-
- setNativeProperties(newOp, overflowFlags);
+ OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
+ resultTypes, targetAttrs);
+ state.propertiesAttr = propertiesAttr;
+ Operation *newOp = rewriter.create(state);
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags) {
+ ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
@@ -116,15 +116,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
- return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
- rewriter, overflowFlags);
-
- auto callback = [op, targetOp, targetAttrs, overflowFlags,
+ return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr,
+ typeConverter, rewriter);
+ auto callback = [op, targetOp, targetAttrs, propertiesAttr,
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
- Operation *newOp =
- rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
- operands, llvm1DVectorTy, targetAttrs);
- LLVM::detail::setNativeProperties(newOp, overflowFlags);
+ OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp),
+ operands, llvm1DVectorTy, targetAttrs);
+ state.propertiesAttr = propertiesAttr;
+ Operation *newOp = rewriter.create(state);
return newOp->getResult(0);
};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9fe3506..e256915933a71 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -389,8 +389,8 @@ def TruncIExtUIToExtUI :
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
- (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
- (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
+ (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow),
+ (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index ba12ff29ebef9..93d4d6ad2547d 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -738,6 +738,22 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
// -----
+// CHECK-LABEL: @ops_supporting_exact
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @ops_supporting_exact(%arg0: i32, %arg1: i32) {
+// CHECK: = llvm.ashr exact %[[ARG0]], %[[ARG1]] : i32
+ %0 = arith.shrsi %arg0, %arg1 exact : i32
+// CHECK: = llvm.lshr exact %[[ARG0]], %[[ARG1]] : i32
+ %1 = arith.shrui %arg0, %arg1 exact : i32
+// CHECK: = llvm.sdiv exact %[[ARG0]], %[[ARG1]] : i32
+ %2 = arith.divsi %arg0, %arg1 exact : i32
+// CHECK: = llvm.udiv exact %[[ARG0]], %[[ARG1]] : i32
+ %3 = arith.divui %arg0, %arg1 exact : i32
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @memref_bitcast
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 2fe0995c9d4df..3ad1530248809 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2958,6 +2958,19 @@ func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 {
return %hi : i32
}
+// CHECK-LABEL: @truncIShrSIExactToTrunciShrUIExact
+// CHECK-SAME: (%[[ARG:.+]]: i64)
+// CHECK-NEXT: %[[AMT:.+]] = arith.constant 32 : i64
+// CHECK-NEXT: %[[LSHR:.+]] = arith.shrui %[[ARG]], %[[AMT]] exact : i64
+// CHECK-NEXT: %[[NARROW:.+]] = arith.trunci %[[LSHR]] : i64 to i32
+// CHECK-NEXT: return %[[NARROW]] : i32
+func.func @truncIShrSIExactToTrunciShrUIExact(%a: i64) -> i32 {
+ %shamt = arith.constant 32 : i64
+ %shifted = arith.shrsi %a, %shamt exact : i64
+ %upper = arith.trunci %shifted : i64 to i32
+ return %upper : i32
+}
+
// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1
// CHECK: arith.shrsi
func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 1e656e84da836..58eadfda17060 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -151,6 +151,12 @@ func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_divui_exact
+// CHECK: arith.divui %{{.*}}, %{{.*}} exact : i64
+func.func @test_divui_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.divui %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_divui_tensor
func.func @test_divui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.divui %arg0, %arg1 : tensor<8x8xi64>
@@ -175,6 +181,12 @@ func.func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_divsi_exact
+// CHECK: arith.divsi %{{.*}}, %{{.*}} exact : i64
+func.func @test_divsi_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.divsi %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_divsi_tensor
func.func @test_divsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.divsi %arg0, %arg1 : tensor<8x8xi64>
@@ -391,6 +403,12 @@ func.func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_shrui_exact
+// CHECK: arith.shrui %{{.*}}, %{{.*}} exact : i64
+func.func @test_shrui_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.shrui %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_shrui_tensor
func.func @test_shrui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.shrui %arg0, %arg1 : tensor<8x8xi64>
@@ -415,6 +433,12 @@ func.func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_shrsi_exact
+// CHECK: arith.shrsi %{{.*}}, %{{.*}} exact : i64
+func.func @test_shrsi_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.shrsi %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_shrsi_tensor
func.func @test_shrsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.shrsi %arg0, %arg1 : tensor<8x8xi64>
More information about the Mlir-commits
mailing list