[Mlir-commits] [mlir] eaba6e0 - [mlir][complex] Convert complex.abs to libm

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 7 17:56:38 PDT 2022


Author: lewuathe
Date: 2022-07-08T09:55:51+09:00
New Revision: eaba6e0b5cf596571f6c0ba5924ffba959566a3f

URL: https://github.com/llvm/llvm-project/commit/eaba6e0b5cf596571f6c0ba5924ffba959566a3f
DIFF: https://github.com/llvm/llvm-project/commit/eaba6e0b5cf596571f6c0ba5924ffba959566a3f.diff

LOG: [mlir][complex] Convert complex.abs to libm

Convert complex.abs to calls to the libm functions cabsf (f32) and cabs (f64).

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D127476

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
    mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
index c973489938461..802f171ba6a25 100644
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -16,14 +16,43 @@
 using namespace mlir;
 
 namespace {
+// Functor that classifies a complex result type by element type: returns true
+// for f64, false for f32, and no value for any other element type.
+struct ComplexTypeResolver {
+  llvm::Optional<bool> operator()(Type type) const {
+    auto complexType = type.cast<ComplexType>();
+    auto elementType = complexType.getElementType();
+    if (!elementType.isa<Float32Type, Float64Type>())
+      return {};
+
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
+// Functor that classifies a float result type: returns true for f64, false for
+// f32, and no value for any other float type.
+struct FloatTypeResolver {
+  llvm::Optional<bool> operator()(Type type) const {
+    auto elementType = type.cast<FloatType>();
+    if (!elementType.isa<Float32Type, Float64Type>())
+      return {};
+
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
 // Pattern to convert scalar complex operations to calls to libm functions.
 // Additionally the libm function signatures are declared.
-template <typename Op>
+// TypeResolver is a functor that inspects the op's result type and reports
+// whether the double (true) or float (false) libm function should be called.
+template <typename Op, typename TypeResolver = ComplexTypeResolver>
 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 public:
   using OpRewritePattern<Op>::OpRewritePattern;
-  ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
-                         StringRef doubleFunc, PatternBenefit benefit)
+  ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
+                                       StringRef floatFunc,
+                                       StringRef doubleFunc,
+                                       PatternBenefit benefit)
       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
         doubleFunc(doubleFunc){};
 
@@ -34,18 +63,16 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 };
 } // namespace
 
-template <typename Op>
-LogicalResult
-ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
-                                        PatternRewriter &rewriter) const {
+template <typename Op, typename TypeResolver>
+LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
+    Op op, PatternRewriter &rewriter) const {
   auto module = SymbolTable::getNearestSymbolTable(op);
-  auto type = op.getType().template cast<ComplexType>();
-  Type elementType = type.getElementType();
-  if (!elementType.isa<Float32Type, Float64Type>())
+  auto isDouble = TypeResolver()(op.getType());
+  if (!isDouble.hasValue())
     return failure();
 
-  auto name =
-      elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+  auto name = isDouble.value() ? doubleFunc : floatFunc;
+
   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
       SymbolTable::lookupSymbolIn(module, name));
   // Forward declare function if it hasn't already been
@@ -60,7 +87,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   }
   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
 
-  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+                                            op->getOperands());
 
   return success();
 }
@@ -79,6 +107,8 @@ void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
                                                    "csinf", "csin", benefit);
   patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
                                                     "conjf", "conj", benefit);
+  patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
+      patterns.getContext(), "cabsf", "cabs", benefit);
 }
 
 namespace {
@@ -96,7 +126,8 @@ void ConvertComplexToLibmPass::runOnOperation() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
-  target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
+  target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
+                      complex::AbsOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
index f0cbe37f000e7..ad6e5a2d482f3 100644
--- a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
@@ -9,6 +9,7 @@
 // CHECK-DAG: @ccos(complex<f64>) -> complex<f64>
 // CHECK-DAG: @csin(complex<f64>) -> complex<f64>
 // CHECK-DAG: @conj(complex<f64>) -> complex<f64>
+// CHECK-DAG: @cabs(complex<f64>) -> f64
 
 // CHECK-LABEL: func @cpow_caller
 // CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
@@ -80,4 +81,16 @@ func.func @conj_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<
   %double_result = complex.conj %double : complex<f64>
   // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
   return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @cabs_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @cabs_caller(%float: complex<f32>, %double: complex<f64>) -> (f32, f64)  {
+  // CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
+  %float_result = complex.abs %float : complex<f32>
+  // CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
+  %double_result = complex.abs %double : complex<f64>
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
 }
\ No newline at end of file


        


More information about the Mlir-commits mailing list