[Mlir-commits] [mlir] 58cdb8b - [mlir][StandardToSPIRV] Add support for lowering unary ops
Lei Zhang
llvmlistbot at llvm.org
Tue Mar 24 06:17:13 PDT 2020
Author: Hanhan Wang
Date: 2020-03-24T09:16:10-04:00
New Revision: 58cdb8bff067a521dd68d6b699b13da74188a68b
URL: https://github.com/llvm/llvm-project/commit/58cdb8bff067a521dd68d6b699b13da74188a68b
DIFF: https://github.com/llvm/llvm-project/commit/58cdb8bff067a521dd68d6b699b13da74188a68b.diff
LOG: [mlir][StandardToSPIRV] Add support for lowering unary ops
Differential Revision: https://reviews.llvm.org/D76661
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 69ef69d1de65..ea8812cebdc4 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -107,16 +107,16 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
namespace {
-/// Converts binary standard operations to SPIR-V operations.
+/// Converts unary and binary standard operations to SPIR-V operations.
template <typename StdOp, typename SPIRVOp>
-class BinaryOpPattern final : public SPIRVOpLowering<StdOp> {
+class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
public:
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- assert(operands.size() == 2);
+ assert(operands.size() <= 2);
auto dstType = this->typeConverter.convertType(operation.getType());
if (!dstType)
return failure();
@@ -572,21 +572,31 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<
- BinaryOpPattern<AddFOp, spirv::FAddOp>,
- BinaryOpPattern<AddIOp, spirv::IAddOp>,
- BinaryOpPattern<DivFOp, spirv::FDivOp>,
- BinaryOpPattern<MulFOp, spirv::FMulOp>,
- BinaryOpPattern<MulIOp, spirv::IMulOp>,
- BinaryOpPattern<RemFOp, spirv::FRemOp>,
- BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
- BinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
- BinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
- BinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
- BinaryOpPattern<SubFOp, spirv::FSubOp>,
- BinaryOpPattern<SubIOp, spirv::ISubOp>,
- BinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
- BinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
- BinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
+ UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
+ UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
+ UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
+ UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
+ UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
+ UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
+ UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
+ UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
+ UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
+ UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
+ UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
+ UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
+ UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
+ UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
+ UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
+ UnaryAndBinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
+ UnaryAndBinaryOpPattern<SignedShiftRightOp,
+ spirv::ShiftRightArithmeticOp>,
+ UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
+ UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
+ UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
+ UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
+ UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
+ UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
+ UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index cb5873a1baf0..91219acc0bd5 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -31,9 +31,33 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
return
}
-// Check float operation conversions.
-// CHECK-LABEL: @float32_scalar
-func @float32_scalar(%lhs: f32, %rhs: f32) {
+// Check float unary operation conversions.
+// CHECK-LABEL: @float32_unary_scalar
+func @float32_unary_scalar(%arg0: f32) {
+ // CHECK: spv.GLSL.FAbs %{{.*}}: f32
+ %0 = absf %arg0 : f32
+ // CHECK: spv.GLSL.Ceil %{{.*}}: f32
+ %1 = ceilf %arg0 : f32
+ // CHECK: spv.GLSL.Cos %{{.*}}: f32
+ %2 = cos %arg0 : f32
+ // CHECK: spv.GLSL.Exp %{{.*}}: f32
+ %3 = exp %arg0 : f32
+ // CHECK: spv.GLSL.Log %{{.*}}: f32
+ %4 = log %arg0 : f32
+ // CHECK: spv.FNegate %{{.*}}: f32
+ %5 = negf %arg0 : f32
+ // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
+ %6 = rsqrt %arg0 : f32
+ // CHECK: spv.GLSL.Sqrt %{{.*}}: f32
+ %7 = sqrt %arg0 : f32
+ // CHECK: spv.GLSL.Tanh %{{.*}}: f32
+ %8 = tanh %arg0 : f32
+ return
+}
+
+// Check float binary operation conversions.
+// CHECK-LABEL: @float32_binary_scalar
+func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
// CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
%0 = addf %lhs, %rhs: f32
// CHECK: spv.FSub %{{.*}}, %{{.*}}: f32
More information about the Mlir-commits mailing list