[Mlir-commits] [mlir] b9e642a - [mlir][spirv] Add path for math.round to spirv for OCL and GLSL

Robert Suderman llvmlistbot at llvm.org
Thu Jul 7 12:27:10 PDT 2022


Author: Robert Suderman
Date: 2022-07-07T19:20:20Z
New Revision: b9e642afd152c2819d91bfaea028a81ff4e9454e

URL: https://github.com/llvm/llvm-project/commit/b9e642afd152c2819d91bfaea028a81ff4e9454e
DIFF: https://github.com/llvm/llvm-project/commit/b9e642afd152c2819d91bfaea028a81ff4e9454e.diff

LOG: [mlir][spirv] Add path for math.round to spirv for OCL and GLSL

OpenCL's round function matches `math.round`, so we can lower directly to
the op; this includes adding the op definition to the SPIR-V OCL ops.
GLSL's Round leaves the rounding direction of halfway cases to the
implementation, so we emit custom code that rounds them away from zero.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D129236

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
index c1d1fd480d4f5..025c4628bb092 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -110,36 +110,6 @@ class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
 }
 
 
-// -----
-
-def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
-  let summary = [{
-    Compute the correctly rounded floating-point representation of the sum
-    of c with the infinitely precise product of a and b. Rounding of
-    intermediate products shall not occur. Edge case results are per the
-    IEEE 754-2008 standard.
-  }];
-
-  let description = [{
-    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
-    floating-point values.
-
-    All of the operands, including the Result Type operand, must be of the
-    same type.
-
-    <!-- End of AutoGen section -->
-
-    ```
-    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
-               float-scalar-vector-type
-    ```mlir
-
-    ```
-    %0 = spv.OCL.fma %a, %b, %c : f32
-    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
-    ```
-  }];
-}
 
 // -----
 
@@ -331,6 +301,37 @@ def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> {
 
 // -----
 
+def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
+  let summary = [{
+    Compute the correctly rounded floating-point representation of the sum
+    of c with the infinitely precise product of a and b. Rounding of
+    intermediate products shall not occur. Edge case results are per the
+    IEEE 754-2008 standard.
+  }];
+
+  let description = [{
+    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
+               float-scalar-vector-type
+    ```
+    #### Example:
+    ```mlir
+    %0 = spv.OCL.fma %a, %b, %c : f32
+    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
   let summary = "Compute the natural logarithm of x.";
 
@@ -392,6 +393,38 @@ def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> {
 
 // -----
 
+def SPV_OCLRoundOp : SPV_OCLUnaryArithmeticOp<"round", 55, SPV_Float> {
+  let summary = [{
+    Return the integral value nearest to x rounding halfway cases away from
+    zero, regardless of the current rounding direction.
+  }];
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    round-op ::= ssa-id `=` `spv.OCL.round` ssa-use `:`
+                 float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.round %0 : f32
+    %3 = spv.OCL.round %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
   let summary = "Compute inverse square root of x.";
 

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index ea367b1faa201..dead01932c57e 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
 
@@ -233,6 +234,43 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
   }
 };
 
+/// Converts math.round to GLSL SPIR-V ops. GLSL's Round leaves the direction
+/// of halfway cases to the implementation, so round half away from zero
+/// explicitly: floor(|x|) + (|x| - floor(|x|) >= 0.5 ? 1 : 0) with x's sign.
+struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = roundOp.getLoc();
+    // Build on the adaptor's (type-converted) operand, not the original one.
+    Value operand = adaptor.getOperand();
+    Type ty = operand.getType();
+    Type ety = getElementTypeOrSelf(ty);
+
+    auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
+    auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
+    // Rounding threshold 0.5: splatted for vectors, a plain scalar otherwise.
+    FloatAttr halfAttr = rewriter.getFloatAttr(ety, 0.5);
+    Attribute halfValue = halfAttr;
+    if (auto vty = ty.dyn_cast<VectorType>())
+      halfValue = DenseElementsAttr::get(vty, halfAttr.getValue());
+    Value half = rewriter.create<spirv::ConstantOp>(loc, ty, halfValue);
+
+    auto abs = rewriter.create<spirv::GLSLFAbsOp>(loc, operand);
+    auto floor = rewriter.create<spirv::GLSLFloorOp>(loc, abs);
+    auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
+    auto greater =
+        rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
+    auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
+    auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
+    // Restore the sign of x; relies on a separate math.copysign lowering.
+    rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -248,7 +286,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   // GLSL patterns
   patterns
       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
-           ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern,
+           ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern, RoundOpPattern,
            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -273,6 +311,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
+               spirv::ElementwiseOpPattern<math::RoundOp, spirv::OCLRoundOp>,
                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
                spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index a3067af661f64..31d25928bc494 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -145,6 +145,38 @@ func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL: @round_scalar
+func.func @round_scalar(%x: f32) -> f32 {
+  // CHECK: %[[ZERO:.+]] = spv.Constant 0.000000e+00 : f32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+  // CHECK: %[[HALF:.+]] = spv.Constant 5.000000e-01 : f32
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
+  // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
+  // CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
+  // CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
+  // CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
+  // CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]]
+  %0 = math.round %x : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @round_vector
+func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00> : vector<4xf32>
+  // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK: %[[HALF:.+]] = spv.Constant dense<5.000000e-01> : vector<4xf32>
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
+  // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
+  // CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
+  // CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
+  // CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
+  // CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]]
+  %0 = math.round %x : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 } // end module
 
 // -----

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 279cff726ed56..e7220f7b33169 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -34,6 +34,8 @@ func.func @float32_unary_scalar(%arg0: f32) {
   %11 = math.floor %arg0 : f32
   // CHECK: spv.OCL.erf %{{.*}}: f32
   %12 = math.erf %arg0 : f32
+  // CHECK: spv.OCL.round %{{.*}}: f32
+  %13 = math.round %arg0 : f32
   return
 }
 


        


More information about the Mlir-commits mailing list