[flang-commits] [flang] bc955ca - [flang] Support arith::FastMathFlagsAttr for fir::CallOp.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Nov 9 15:34:20 PST 2022


Author: Slava Zakharin
Date: 2022-11-09T15:31:09-08:00
New Revision: bc955cae35d2a9ef018221566c55a23cea341a90

URL: https://github.com/llvm/llvm-project/commit/bc955cae35d2a9ef018221566c55a23cea341a90
DIFF: https://github.com/llvm/llvm-project/commit/bc955cae35d2a9ef018221566c55a23cea341a90.diff

LOG: [flang] Support arith::FastMathFlagsAttr for fir::CallOp.

The main purpose of this patch is to propagate the fastmath attribute
to SimplifyIntrinsicsPass, so that the code it generates inline can
inherit the call operation's attributes. Even though I added translation
of fastmath from fir::CallOp to LLVM::CallOp, no fastmath attributes
appear in the generated LLVM IR; it looks like the translation to
LLVM IR drops them. This will need additional commits.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D137602

Added: 
    flang/test/Fir/fir-fast-math.fir

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRDialect.td
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/Dialect/FIROps.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.td b/flang/include/flang/Optimizer/Dialect/FIRDialect.td
index 2be44368bfe92..40501176a4683 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRDialect.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.td
@@ -26,6 +26,11 @@ def fir_Dialect : Dialect {
   let cppNamespace = "::fir";
   let useDefaultTypePrinterParser = 0;
   let useDefaultAttributePrinterParser = 0;
+  let dependentDialects = [
+    // Arith dialect provides FastMathFlagsAttr
+    // supported by some FIR operations.
+    "arith::ArithDialect"
+  ];
 }
 
 #endif // FORTRAN_DIALECT_FIR_DIALECT

diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 69d0a47e7bdec..1669ac1e60cf4 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -14,6 +14,8 @@
 #ifndef FORTRAN_DIALECT_FIR_OPS
 #define FORTRAN_DIALECT_FIR_OPS
 
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
 include "flang/Optimizer/Dialect/FIRAttr.td"
@@ -2266,7 +2268,8 @@ def fir_IterWhileOp : region_Op<"iterate_while",
 // Procedure call operations
 //===----------------------------------------------------------------------===//
 
-def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
+def fir_CallOp : fir_Op<"call",
+    [CallOpInterface, DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
   let summary = "call a procedure";
 
   let description = [{
@@ -2283,7 +2286,9 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
 
   let arguments = (ins
     OptionalAttr<SymbolRefAttr>:$callee,
-    Variadic<AnyType>:$args
+    Variadic<AnyType>:$args,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
   );
   let results = (outs Variadic<AnyType>);
 

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index be49a0bf509be..81b55422087de 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Support/InternalNames.h"
 #include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Semantics/runtime-type-info.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
@@ -699,8 +700,11 @@ struct CallOpConversion : public FIROpConversion<fir::CallOp> {
     llvm::SmallVector<mlir::Type> resultTys;
     for (auto r : call.getResults())
       resultTys.push_back(convertType(r.getType()));
+    // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
+    mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
+        attrConvert(call);
     rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        call, resultTys, adaptor.getOperands(), call->getAttrs());
+        call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
     return mlir::success();
   }
 };

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 86628b792068b..08c041127bda6 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -655,8 +655,18 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
   else
     p << getOperand(0);
   p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          {fir::CallOp::getCalleeAttrNameStr()});
+
+  // Print 'fastmath<...>' (if it has non-default value) before
+  // any other attributes.
+  mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
+  if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) {
+    p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic();
+    p.printStrippedAttrOrType(fmfAttr);
+  }
+
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      {fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()});
   auto resultTypes{getResultTypes()};
   llvm::SmallVector<mlir::Type> argTypes(
       llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
@@ -678,8 +688,18 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
       return mlir::failure();
 
   mlir::Type type;
-  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
-      parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
+  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
+    return mlir::failure();
+
+  // Parse 'fastmath<...>', if present.
+  mlir::arith::FastMathFlagsAttr fmfAttr;
+  llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
+  if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName)))
+    if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{},
+                                                fmfAttrName, attrs))
+      return mlir::failure();
+
+  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
       parser.parseType(type))
     return mlir::failure();
 

diff  --git a/flang/test/Fir/fir-fast-math.fir b/flang/test/Fir/fir-fast-math.fir
new file mode 100644
index 0000000000000..b9ebe7248eed9
--- /dev/null
+++ b/flang/test/Fir/fir-fast-math.fir
@@ -0,0 +1,20 @@
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// CHECK-LABEL: @test_callop
+func.func @test_callop(%arg0 : f32) {
+  // CHECK: fir.call @callee() : () -> ()
+  fir.call @callee() fastmath<none> : () -> ()
+  // CHECK: fir.call @callee() : () -> ()
+  fir.call @callee() {fastmath = #arith.fastmath<none>} : () -> ()
+  // CHECK: fir.call @callee() fastmath<ninf,contract> : () -> ()
+  fir.call @callee() fastmath<ninf,contract> : () -> ()
+  // CHECK: fir.call @callee() fastmath<nnan,afn> : () -> ()
+  fir.call @callee() {fastmath = #arith.fastmath<nnan,afn>} : () -> ()
+  // CHECK: fir.call @callee() fastmath<fast> : () -> ()
+  fir.call @callee() fastmath<fast> : () -> ()
+  // CHECK: fir.call @callee() fastmath<fast> : () -> ()
+  fir.call @callee() {fastmath = #arith.fastmath<fast>} : () -> ()
+  return
+}
+
+func.func private @callee()


        


More information about the flang-commits mailing list