[flang-commits] [flang] [mlir] [mlir][llvm] Port `overflowFlags` to a native operation property (RELAND) (PR #89410)

Jeff Niu via flang-commits flang-commits at lists.llvm.org
Fri Apr 19 09:22:55 PDT 2024


https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/89410

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties.

This is a reland of #89312 with fixes for flang.

From 005d8b74130f1a75a009f74a52ade0c9f0a57b4e Mon Sep 17 00:00:00 2001
From: Jeff Niu <jeff at modular.com>
Date: Thu, 18 Apr 2024 15:35:39 -0700
Subject: [PATCH] [mlir][llvm] Port `overflowFlags` to a native operation
 property (#89312)

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.
---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 10 +--
 .../ArithCommon/AttrToLLVMConverter.h         | 22 +++---
 .../mlir/Conversion/LLVMCommon/Pattern.h      | 14 ++--
 .../Conversion/LLVMCommon/VectorPattern.h     | 16 ++--
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 76 +++++++------------
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   | 23 ++++--
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  3 +-
 .../ArithCommon/AttrToLLVMConverter.cpp       |  7 --
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    | 19 +++--
 .../Conversion/LLVMCommon/VectorPattern.cpp   | 26 +++----
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 76 ++++++++++++++++++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  7 +-
 12 files changed, 183 insertions(+), 116 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d909bda89cdeb4..921eac2f8f4b60 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion
     const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
     TypePair baseBoxTyPair =
         baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
-    mlir::LLVM::IntegerOverflowFlagsAttr nsw =
-        mlir::LLVM::IntegerOverflowFlagsAttr::get(
-            rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
+    mlir::LLVM::IntegerOverflowFlags nsw =
+        mlir::LLVM::IntegerOverflowFlags::nsw;
 
     // For each dimension of the array, generate the offset calculation.
     for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
@@ -2396,9 +2395,8 @@ struct CoordinateOpConversion
     auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
     mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
     mlir::Type byteTy = ::getI8Type(coor.getContext());
-    mlir::LLVM::IntegerOverflowFlagsAttr nsw =
-        mlir::LLVM::IntegerOverflowFlagsAttr::get(
-            rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
+    mlir::LLVM::IntegerOverflowFlags nsw =
+        mlir::LLVM::IntegerOverflowFlags::nsw;
 
     for (unsigned i = 1, last = operands.size(); i < last; ++i) {
       if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 0891e2ba7be760..7ffc8613317603 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 LLVM::IntegerOverflowFlags
 convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
 
-/// Creates an LLVM overflow attribute from a given arithmetic overflow
-/// attribute.
-LLVM::IntegerOverflowFlagsAttr
-convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
-
 /// Creates an LLVM rounding mode enum value from a given arithmetic rounding
 /// mode enum value.
 LLVM::RoundingMode
@@ -72,6 +67,9 @@ class AttrConvertFastMathToLLVM {
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   NamedAttrList convertedAttr;
@@ -89,19 +87,18 @@ class AttrConvertOverflowToLLVM {
     // Get the name of the arith overflow attribute.
     StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
     // Remove the source overflow attribute.
-    auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
-        convertedAttr.erase(arithAttrName));
-    if (arithAttr) {
-      StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
-      convertedAttr.set(targetAttrName,
-                        convertArithOverflowAttrToLLVM(arithAttr));
+    if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
+            convertedAttr.erase(arithAttrName))) {
+      overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
     }
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
 
 private:
   NamedAttrList convertedAttr;
+  LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
 };
 
 template <typename SourceOp, typename TargetOp>
@@ -132,6 +129,9 @@ class AttrConverterConstrainedFPToLLVM {
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   NamedAttrList convertedAttr;
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f362167ee93249..f3bf5b66398e09 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -18,13 +19,16 @@ class CallOpInterface;
 
 namespace LLVM {
 namespace detail {
+/// Handle generically setting flags as native properties on LLVM operations.
+void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
+
 /// Replaces the given operation "op" with a new operation of type "targetOp"
 /// and given operands.
-LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
-                              ValueRange operands,
-                              ArrayRef<NamedAttribute> targetAttrs,
-                              const LLVMTypeConverter &typeConverter,
-                              ConversionPatternRewriter &rewriter);
+LogicalResult oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 
 } // namespace detail
 } // namespace LLVM
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 279175b6128fc7..964281592cc65e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter);
 
-LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
-                                    ValueRange operands,
-                                    ArrayRef<NamedAttribute> targetAttrs,
-                                    const LLVMTypeConverter &typeConverter,
-                                    ConversionPatternRewriter &rewriter);
+LogicalResult vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 } // namespace detail
 } // namespace LLVM
 
@@ -70,6 +70,9 @@ class AttrConvertPassThrough {
   AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
 
   ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   ArrayRef<NamedAttribute> srcAttrs;
@@ -100,7 +103,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 
     return LLVM::detail::vectorOneToOneRewrite(
         op, TargetOp::getOperationName(), adaptor.getOperands(),
-        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
+        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
+        attrConvert.getOverflowFlags());
   }
 };
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index cee752aeb269b7..7085f81e203a1e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
 def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
   let description = [{
-    Access to op integer overflow flags.
+    This interface defines an LLVM operation with integer overflow flags and
+    provides a uniform API for accessing them.
   }];
 
   let cppNamespace = "::mlir::LLVM";
 
   let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
-      /*returnType=*/  "IntegerOverflowFlagsAttr",
-      /*methodName=*/  "getOverflowAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getOverflowFlagsAttr();
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
-      /*returnType=*/  "bool",
-      /*methodName=*/  "hasNoUnsignedWrap",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
-      /*returnType=*/  "bool",
-      /*methodName=*/  "hasNoSignedWrap",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
-                         for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getIntegerOverflowAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        return "overflowFlags";
-      }]
-      >
+    InterfaceMethod<[{
+      Get the integer overflow flags for the operation.
+    }], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
+      return $_op.getProperties().overflowFlags;
+    }]>,
+    InterfaceMethod<[{
+      Set the integer overflow flags for the operation.
+    }], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
+      $_op.getProperties().overflowFlags = flags;
+    }]>,
+    InterfaceMethod<[{
+      Returns whether the operation has the No Unsigned Wrap keyword.
+    }], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
+      return bitEnumContainsAll($_op.getOverflowFlags(),
+                                IntegerOverflowFlags::nuw);
+    }]>,
+    InterfaceMethod<[{
+      Returns whether the operation has the No Signed Wrap keyword.
+    }], "bool", "hasNoSignedWrap", (ins), [{}], [{
+      return bitEnumContainsAll($_op.getOverflowFlags(),
+                                IntegerOverflowFlags::nsw);
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the overflow flags property.
+    }], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
+      return "overflowFlags";
+    }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f8f9264b3889be..eedae4b9bb7c8e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
                                    list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
     !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
-  dag iofArg = (
-    ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
+  dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
   let arguments = !con(commonArgs, iofArg);
+
+  let builders = [
+    OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs, 
+                   "IntegerOverflowFlags":$overflowFlags), [{
+      build($_builder, $_state, type, lhs, rhs);
+      $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+    }]>,
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs, 
+                   "IntegerOverflowFlags":$overflowFlags), [{
+      build($_builder, $_state, lhs, rhs);
+      $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+    }]>
+  ];
+
   string mlirBuilder = [{
     auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
-    moduleImport.setIntegerOverflowFlagsAttr(inst, op);
+    moduleImport.setIntegerOverflowFlags(inst, op);
     $res = op;
   }];
   let assemblyFormat = [{
-    $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
-    custom<LLVMOpAttrs>(attr-dict) `:` type($res)
+    $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
+    `` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
   }];
   string llvmBuilder =
     "$res = builder.Create" # instName #
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 1a188b1d042854..04d098d38155bf 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -183,8 +183,7 @@ class ModuleImport {
   /// Sets the integer overflow flags (nsw/nuw) attribute for the imported
   /// operation `op` given the original instruction `inst`. Asserts if the
   /// operation does not implement the integer overflow flag interface.
-  void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
-                                   Operation *op) const;
+  void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
 
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index f12eba98480d33..cf60a048f782c6 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
   return llvmFlags;
 }
 
-LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
-    arith::IntegerOverflowFlagsAttr flagsAttr) {
-  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
-  return LLVM::IntegerOverflowFlagsAttr::get(
-      flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
-}
-
 LLVM::RoundingMode
 mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
   switch (roundingMode) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 83c31a204efc7e..1886dfa870961a 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 // Detail methods
 //===----------------------------------------------------------------------===//
 
+void LLVM::detail::setNativeProperties(Operation *op,
+                                       IntegerOverflowFlags overflowFlags) {
+  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
+    iface.setOverflowFlags(overflowFlags);
+}
+
 /// Replaces the given operation "op" with a new operation of type "targetOp"
 /// and given operands.
-LogicalResult
-LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
-                              ValueRange operands,
-                              ArrayRef<NamedAttribute> targetAttrs,
-                              const LLVMTypeConverter &typeConverter,
-                              ConversionPatternRewriter &rewriter) {
+LogicalResult LLVM::detail::oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags) {
   unsigned numResults = op->getNumResults();
 
   SmallVector<Type> resultTypes;
@@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
                       resultTypes, targetAttrs);
 
+  setNativeProperties(newOp, overflowFlags);
+
   // If the operation produced 0 or 1 result, return them immediately.
   if (numResults == 0)
     return rewriter.eraseOp(op), success();
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 544bcc71aca1b5..626135c10a3e96 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
-LogicalResult
-LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
-                                    ValueRange operands,
-                                    ArrayRef<NamedAttribute> targetAttrs,
-                                    const LLVMTypeConverter &typeConverter,
-                                    ConversionPatternRewriter &rewriter) {
+LogicalResult LLVM::detail::vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags) {
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
@@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
   auto llvmNDVectorTy = operands[0].getType();
   if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
     return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
-                           rewriter);
+                           rewriter, overflowFlags);
 
-  auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
-                                                         ValueRange operands) {
-    return rewriter
-        .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                llvm1DVectorTy, targetAttrs)
-        ->getResult(0);
+  auto callback = [op, targetOp, targetAttrs, overflowFlags,
+                   &rewriter](Type llvm1DVectorTy, ValueRange operands) {
+    Operation *newOp =
+        rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
+                        operands, llvm1DVectorTy, targetAttrs);
+    LLVM::detail::setNativeProperties(newOp, overflowFlags);
+    return newOp->getResult(0);
   };
 
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1db506e286b3c0..78ff24dae68b4c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -47,6 +47,74 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Property Helpers
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+
+namespace mlir {
+static Attribute convertToAttribute(MLIRContext *ctx,
+                                    IntegerOverflowFlags flags) {
+  return IntegerOverflowFlagsAttr::get(ctx, flags);
+}
+
+static LogicalResult
+convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
+                     function_ref<InFlightDiagnostic()> emitError) {
+  auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
+  if (!flagsAttr) {
+    return emitError() << "expected 'overflowFlags' attribute to be an "
+                          "IntegerOverflowFlagsAttr, but got "
+                       << attr;
+  }
+  flags = flagsAttr.getValue();
+  return success();
+}
+} // namespace mlir
+
+static ParseResult parseOverflowFlags(AsmParser &p,
+                                      IntegerOverflowFlags &flags) {
+  if (failed(p.parseOptionalKeyword("overflow"))) {
+    flags = IntegerOverflowFlags::none;
+    return success();
+  }
+  if (p.parseLess())
+    return failure();
+  do {
+    StringRef kw;
+    SMLoc loc = p.getCurrentLocation();
+    if (p.parseKeyword(&kw))
+      return failure();
+    std::optional<IntegerOverflowFlags> flag =
+        symbolizeIntegerOverflowFlags(kw);
+    if (!flag)
+      return p.emitError(loc,
+                         "invalid overflow flag: expected nsw, nuw, or none");
+    flags = flags | *flag;
+  } while (succeeded(p.parseOptionalComma()));
+  return p.parseGreater();
+}
+
+static void printOverflowFlags(AsmPrinter &p, Operation *op,
+                               IntegerOverflowFlags flags) {
+  if (flags == IntegerOverflowFlags::none)
+    return;
+  p << " overflow<";
+  SmallVector<StringRef, 2> strs;
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
+    strs.push_back("nsw");
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
+    strs.push_back("nuw");
+  llvm::interleaveComma(strs, p);
+  p << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute Helpers
+//===----------------------------------------------------------------------===//
+
 static constexpr const char kElemTypeAttrName[] = "elem_type";
 
 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
@@ -70,12 +138,12 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
   auto filteredAttrs = processFMFAttr(attrs.getValue());
-  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
+  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
     printer.printOptionalAttrDict(
-        filteredAttrs,
-        /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
-  else
+        filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
+  } else {
     printer.printOptionalAttrDict(filteredAttrs);
+  }
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 9d41fa0da37590..191b84acd56fae 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -627,8 +627,8 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
   }
 }
 
-void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
-                                               Operation *op) const {
+void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
+                                           Operation *op) const {
   auto iface = cast<IntegerOverflowFlagsInterface>(op);
 
   IntegerOverflowFlags value = {};
@@ -636,8 +636,7 @@ void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
   value =
       bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
 
-  auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value);
-  iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
+  iface.setOverflowFlags(value);
 }
 
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,



More information about the flang-commits mailing list