[Mlir-commits] [mlir] d9edc1a - [mlir][spirv] Add math.fma lowering to spirv

Thomas Raoux llvmlistbot at llvm.org
Wed Jan 19 10:57:50 PST 2022


Author: Thomas Raoux
Date: 2022-01-19T10:57:05-08:00
New Revision: d9edc1a585d7b4a203ee29136260282bc9c65c95

URL: https://github.com/llvm/llvm-project/commit/d9edc1a585d7b4a203ee29136260282bc9c65c95
DIFF: https://github.com/llvm/llvm-project/commit/d9edc1a585d7b4a203ee29136260282bc9c65c95.diff

LOG: [mlir][spirv] Add math.fma lowering to spirv

Differential Revision: https://reviews.llvm.org/D117704

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/lib/Conversion/SPIRVCommon/Pattern.h
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 4fd2b6df9fa5f..b3972a6b3b0bc 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -790,25 +790,25 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
-    spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
+    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
+    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
+    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
+    spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
+    spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
+    spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
     RemSIOpGLSLPattern, RemSIOpOCLPattern,
     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
-    spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
-    spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
+    spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
+    spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
+    spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
+    spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
+    spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
+    spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
+    spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
+    spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
+    spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 7e95d33e78dd8..ec8402af03009 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -64,35 +64,36 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns) {
 
   // GLSL patterns
-  patterns.add<
-      Log1pOpPattern<spirv::GLSLLogOp>,
-      spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
-      spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
-      spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
-      spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
-      spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
-      spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
-      spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
-      spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
-      spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
-      spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
-      spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<Log1pOpPattern<spirv::GLSLLogOp>,
+           spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
+           spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
+           spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
+           spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
+           spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
+           spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
+           spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
+           spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
+           spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
+           spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
+           spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
+           spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
+          typeConverter, patterns.getContext());
 
   // OpenCL patterns
   patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
-               spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
-               spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
-               spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
-               spirv::UnaryAndBinaryOpPattern<math::ErfOp, spirv::OCLErfOp>,
-               spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
-               spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
-               spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
-               spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
-               spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
-               spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
-               spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
-               spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
+               spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
+               spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
+               spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
+               spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
+               spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
+               spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
+               spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
+               spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
+               spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
+               spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
+               spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
+               spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
       typeConverter, patterns.getContext());
 }
 

diff  --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
index 39338995c297c..c70009fcc23e2 100644
--- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h
+++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
@@ -15,16 +15,17 @@
 namespace mlir {
 namespace spirv {
 
-/// Converts unary and binary standard operations to SPIR-V operations.
+/// Converts elementwise unary, binary and ternary standard operations to SPIR-V
+/// operations.
 template <typename Op, typename SPIRVOp>
-class UnaryAndBinaryOpPattern final : public OpConversionPattern<Op> {
+class ElementwiseOpPattern final : public OpConversionPattern<Op> {
 public:
   using OpConversionPattern<Op>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(adaptor.getOperands().size() <= 2);
+    assert(adaptor.getOperands().size() <= 3);
     auto dstType = this->getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return failure();

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index fea7c7ca61697..8d39aa9e598d7 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -230,12 +230,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
   patterns.add<
       // Unary and binary patterns
-      spirv::UnaryAndBinaryOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
-      spirv::UnaryAndBinaryOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
-      spirv::UnaryAndBinaryOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
-      spirv::UnaryAndBinaryOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
-      spirv::UnaryAndBinaryOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
-      spirv::UnaryAndBinaryOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
+      spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
+      spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
+      spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
+      spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
+      spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
+      spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
 
       ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
       CondBranchOpPattern>(typeConverter, context);

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 8cae1ca7d94ef..f0e0b7e63fdce 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -68,4 +68,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
   return
 }
 
+  // CHECK-LABEL: @float32_ternary_scalar
+func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
+  // CHECK: spv.GLSL.Fma %{{.*}}: f32
+  %0 = math.fma %a, %b, %c : f32
+  return
+}
+
+// CHECK-LABEL: @float32_ternary_vector
+func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
+                            %c: vector<4xf32>) {
+  // CHECK: spv.GLSL.Fma %{{.*}}: vector<4xf32>
+  %0 = math.fma %a, %b, %c : vector<4xf32>
+  return
+}
+
 } // end module


        


More information about the Mlir-commits mailing list