[Mlir-commits] [mlir] eaba6e0 - [mlir][complex] Convert complex.abs to libm
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 7 17:56:38 PDT 2022
Author: lewuathe
Date: 2022-07-08T09:55:51+09:00
New Revision: eaba6e0b5cf596571f6c0ba5924ffba959566a3f
URL: https://github.com/llvm/llvm-project/commit/eaba6e0b5cf596571f6c0ba5924ffba959566a3f
DIFF: https://github.com/llvm/llvm-project/commit/eaba6e0b5cf596571f6c0ba5924ffba959566a3f.diff
LOG: [mlir][complex] Convert complex.abs to libm
Convert complex.abs to calls to the libm functions cabsf/cabs.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D127476
Added:
Modified:
mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
index c973489938461..802f171ba6a25 100644
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -16,14 +16,43 @@
using namespace mlir;
namespace {
+// Functor that picks the libm precision variant from an op's complex result
+// type: true for complex<f64>, false for complex<f32>, empty Optional else.
+struct ComplexTypeResolver {
+ llvm::Optional<bool> operator()(Type type) const {
+ auto complexType = type.cast<ComplexType>(); // result must be a ComplexType
+ auto elementType = complexType.getElementType();
+ if (!elementType.isa<Float32Type, Float64Type>())
+ return {}; // unsupported element type; the pattern then fails to match
+
+ return elementType.getIntOrFloatBitWidth() == 64;
+ }
+};
+
+// Like ComplexTypeResolver, but for ops with a real float result (e.g.
+// complex.abs): true for f64, false for f32, an empty Optional otherwise.
+struct FloatTypeResolver {
+ llvm::Optional<bool> operator()(Type type) const {
+ auto elementType = type.cast<FloatType>(); // note: the result type itself
+ if (!elementType.isa<Float32Type, Float64Type>())
+ return {}; // unsupported float width; the pattern then fails to match
+
+ return elementType.getIntOrFloatBitWidth() == 64;
+ }
+};
+
// Pattern to convert scalar complex operations to calls to libm functions.
// Additionally the libm function signatures are declared.
-template <typename Op>
+// TypeResolver is a functor mapping the op's result type to Optional<bool>:
+// true selects doubleFunc, false selects floatFunc, empty means no match.
+template <typename Op, typename TypeResolver = ComplexTypeResolver>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
- ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
- StringRef doubleFunc, PatternBenefit benefit)
+ ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
+ StringRef floatFunc,
+ StringRef doubleFunc,
+ PatternBenefit benefit)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
@@ -34,18 +63,16 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
};
} // namespace
-template <typename Op>
-LogicalResult
-ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
- PatternRewriter &rewriter) const {
+template <typename Op, typename TypeResolver>
+LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
+ Op op, PatternRewriter &rewriter) const {
auto module = SymbolTable::getNearestSymbolTable(op);
- auto type = op.getType().template cast<ComplexType>();
- Type elementType = type.getElementType();
- if (!elementType.isa<Float32Type, Float64Type>())
+ auto isDouble = TypeResolver()(op.getType());
+ if (!isDouble.hasValue())
return failure();
- auto name =
- elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+ auto name = isDouble.value() ? doubleFunc : floatFunc;
+
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
SymbolTable::lookupSymbolIn(module, name));
// Forward declare function if it hasn't already been
@@ -60,7 +87,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
}
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
- rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+ op->getOperands());
return success();
}
@@ -79,6 +107,8 @@ void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
"csinf", "csin", benefit);
patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
"conjf", "conj", benefit);
+ patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
+ patterns.getContext(), "cabsf", "cabs", benefit);
}
namespace {
@@ -96,7 +126,8 @@ void ConvertComplexToLibmPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
+ target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
+ complex::AbsOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
index f0cbe37f000e7..ad6e5a2d482f3 100644
--- a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
@@ -9,6 +9,7 @@
// CHECK-DAG: @ccos(complex<f64>) -> complex<f64>
// CHECK-DAG: @csin(complex<f64>) -> complex<f64>
// CHECK-DAG: @conj(complex<f64>) -> complex<f64>
+// CHECK-DAG: @cabs(complex<f64>) -> f64
// CHECK-LABEL: func @cpow_caller
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
@@ -80,4 +81,16 @@ func.func @conj_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<
%double_result = complex.conj %double : complex<f64>
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @cabs_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @cabs_caller(%float: complex<f32>, %double: complex<f64>) -> (f32, f64) { // abs yields the real element type
+ // CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
+ %float_result = complex.abs %float : complex<f32>
+ // CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
+ %double_result = complex.abs %double : complex<f64>
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
}
\ No newline at end of file
More information about the Mlir-commits
mailing list