[Mlir-commits] [mlir] d9edc1a - [mlir][spirv] Add math.fma lowering to spirv
Thomas Raoux
llvmlistbot at llvm.org
Wed Jan 19 10:57:50 PST 2022
Author: Thomas Raoux
Date: 2022-01-19T10:57:05-08:00
New Revision: d9edc1a585d7b4a203ee29136260282bc9c65c95
URL: https://github.com/llvm/llvm-project/commit/d9edc1a585d7b4a203ee29136260282bc9c65c95
DIFF: https://github.com/llvm/llvm-project/commit/d9edc1a585d7b4a203ee29136260282bc9c65c95.diff
LOG: [mlir][spirv] Add math.fma lowering to spirv
Differential Revision: https://reviews.llvm.org/D117704
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/lib/Conversion/SPIRVCommon/Pattern.h
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 4fd2b6df9fa5f..b3972a6b3b0bc 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -790,25 +790,25 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
- spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
- spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
- spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
- spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
- spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
+ spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
+ spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
+ spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
+ spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
+ spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
+ spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
RemSIOpGLSLPattern, RemSIOpOCLPattern,
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
- spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
- spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
- spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
- spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
- spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
- spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
- spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
- spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
+ spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
+ spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
+ spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
+ spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
+ spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
+ spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
+ spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
+ spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
+ spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 7e95d33e78dd8..ec8402af03009 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -64,35 +64,36 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// GLSL patterns
- patterns.add<
- Log1pOpPattern<spirv::GLSLLogOp>,
- spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
- spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
- spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
- spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
- spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
- spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
- spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
- spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
- spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
- spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
- spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<Log1pOpPattern<spirv::GLSLLogOp>,
+ spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
+ spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
+ spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
+ spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
+ spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
+ spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
+ spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
+ spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
+ spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
+ spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
+ spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
+ spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
+ typeConverter, patterns.getContext());
// OpenCL patterns
patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
- spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
- spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
- spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
- spirv::UnaryAndBinaryOpPattern<math::ErfOp, spirv::OCLErfOp>,
- spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
- spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
- spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
- spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
- spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
- spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
- spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
- spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
+ spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
+ spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
+ spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
+ spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
+ spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
+ spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
+ spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
+ spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
+ spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
+ spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
+ spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
+ spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
index 39338995c297c..c70009fcc23e2 100644
--- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h
+++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
@@ -15,16 +15,17 @@
namespace mlir {
namespace spirv {
-/// Converts unary and binary standard operations to SPIR-V operations.
+/// Converts elementwise unary, binary and ternary standard operations to SPIR-V
+/// operations.
template <typename Op, typename SPIRVOp>
-class UnaryAndBinaryOpPattern final : public OpConversionPattern<Op> {
+class ElementwiseOpPattern final : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(adaptor.getOperands().size() <= 2);
+ assert(adaptor.getOperands().size() <= 3);
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index fea7c7ca61697..8d39aa9e598d7 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -230,12 +230,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<
// Unary and binary patterns
- spirv::UnaryAndBinaryOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
- spirv::UnaryAndBinaryOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
+ spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
+ spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
+ spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
+ spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
+ spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
+ spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
CondBranchOpPattern>(typeConverter, context);
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 8cae1ca7d94ef..f0e0b7e63fdce 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -68,4 +68,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
return
}
+ // CHECK-LABEL: @float32_ternary_scalar
+func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
+ // CHECK: spv.GLSL.Fma %{{.*}}: f32
+ %0 = math.fma %a, %b, %c : f32
+ return
+}
+
+// CHECK-LABEL: @float32_ternary_vector
+func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
+ %c: vector<4xf32>) {
+ // CHECK: spv.GLSL.Fma %{{.*}}: vector<4xf32>
+ %0 = math.fma %a, %b, %c : vector<4xf32>
+ return
+}
+
} // end module
More information about the Mlir-commits mailing list