[Mlir-commits] [mlir] 6329c73 - [mlir][spirv] Add support for fast math mode

Lei Zhang llvmlistbot at llvm.org
Fri Sep 9 13:27:16 PDT 2022


Author: Lei Zhang
Date: 2022-09-09T16:27:07-04:00
New Revision: 6329c7387f6b5cad2e2f80136932a297974728c1

URL: https://github.com/llvm/llvm-project/commit/6329c7387f6b5cad2e2f80136932a297974728c1
DIFF: https://github.com/llvm/llvm-project/commit/6329c7387f6b5cad2e2f80136932a297974728c1.diff

LOG: [mlir][spirv] Add support for fast math mode

This commit introduces a new option to SPIRVConversionOptions
to allow enabling fast math mode. When it is enabled, various
patterns assume that floating point values are never NaN or
infinity and omit the guards that would otherwise check for them.
This is particularly useful for CodeGen targeting the WebGPU
environment, where fast math is assumed.

Along the way, fixed the conversion of arith.minf/maxf to
handle NaN operands properly when lowering to the Shader
(spv.GL.FMin/FMax) ops.

Part of https://github.com/llvm/llvm-project/issues/57584.

Reviewed By: ThomasRaoux, hanchung

Differential Revision: https://reviews.llvm.org/D133599

Added: 
    mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71821d7074b00..6163c6ae5d0cb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -122,7 +122,11 @@ def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv"> {
     Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate non-32-bit scalar types with 32-bit ones if "
-           "missing native support">
+           "missing native support">,
+    Option<"enableFastMath", "enable-fast-math",
+           "bool", /*default=*/"false",
+           "Enable fast math mode (assuming no NaN and infinity for floating "
+           "point values) when performing conversion">
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 81e701fbd2d19..ff1cdacd4d521 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -26,6 +26,9 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 
 struct SPIRVConversionOptions {
+  /// The number of bits to store a boolean value.
+  unsigned boolNumBits{8};
+
   /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
   /// no native support.
   ///
@@ -45,9 +48,10 @@ struct SPIRVConversionOptions {
   /// Use 64-bit integers to convert index types.
   bool use64bitIndex{false};
 
-  /// The number of bits to store a boolean value. It is eight bits by
-  /// default.
-  unsigned boolNumBits{8};
+  /// Whether to enable fast math mode during conversion. If true, various
+  /// patterns would assume no NaN/infinity numbers as inputs, and thus there
+  /// will be no special guards emitted to check and handle such cases.
+  bool enableFastMathMode{false};
 };
 
 /// Type conversion from builtin types to SPIR-V types for shader interface.

diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 2af5fc9318fd2..18c3abe7a7793 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -219,6 +219,16 @@ class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.maxf/minf to spv.GL.FMax/FMin.
+template <typename Op, typename SPIRVOp>
+class MinMaxFOpPattern final : public OpConversionPattern<Op> {
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -839,13 +849,25 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
     return failure();
 
   Location loc = op.getLoc();
+  auto *converter = getTypeConverter<SPIRVTypeConverter>();
 
-  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
-  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+  Value replace;
+  if (converter->getOptions().enableFastMathMode) {
+    if (op.getPredicate() == arith::CmpFPredicate::ORD) {
+      // Ordered comparison checks if neither operand is NaN.
+      replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
+    } else {
+      // Unordered comparison checks if either operand is NaN.
+      replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
+    }
+  } else {
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
 
-  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
-  if (op.getPredicate() == arith::CmpFPredicate::ORD)
-    replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+    replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+    if (op.getPredicate() == arith::CmpFPredicate::ORD)
+      replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+  }
 
   rewriter.replaceOp(op, replace);
   return success();
@@ -889,6 +911,45 @@ SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MinMaxFOpPattern
+//===----------------------------------------------------------------------===//
+
+template <typename Op, typename SPIRVOp>
+LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
+    Op op, typename Op::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+  auto dstType = converter->convertType(op.getType());
+  if (!dstType)
+    return failure();
+
+  // arith.maxf/minf:
+  //   "if one of the arguments is NaN, then the result is also NaN."
+  // spv.GL.FMax/FMin:
+  //   "which operand is the result is undefined if one of the operands
+  //   is a NaN."
+
+  Location loc = op.getLoc();
+  Value spirvOp = rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+  if (converter->getOptions().enableFastMathMode) {
+    rewriter.replaceOp(op, spirvOp);
+    return success();
+  }
+
+  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+                                                   adaptor.getLhs(), spirvOp);
+  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+                                                   adaptor.getRhs(), select1);
+
+  rewriter.replaceOp(op, select2);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern Population
 //===----------------------------------------------------------------------===//
@@ -932,10 +993,10 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     CmpFOpNanNonePattern, CmpFOpPattern,
     AddICarryOpPattern, SelectOpPattern,
 
-    spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
+    MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
+    MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
-    spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
     spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>
   >(typeConverter, patterns.getContext());
@@ -961,6 +1022,7 @@ struct ConvertArithmeticToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+    options.enableFastMathMode = this->enableFastMath;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 4ed91ca31abb6..9430183d37d0b 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -1093,13 +1093,35 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   %3 = arith.divf %lhs, %rhs: f32
   // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32
   %4 = arith.remf %lhs, %rhs: f32
-  // CHECK: spv.GL.FMax %{{.*}}, %{{.*}}: f32
-  %5 = arith.maxf %lhs, %rhs: f32
-  // CHECK: spv.GL.FMin %{{.*}}, %{{.*}}: f32
-  %6 = arith.minf %lhs, %rhs: f32
   return
 }
 
+// CHECK-LABEL: @float32_minf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spv.GL.FMin %arg0, %arg1 : f32
+  // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+  // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+  // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MIN]]
+  // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]]
+  %0 = arith.minf %arg0, %arg1 : f32
+  // CHECK: return %[[SELECT2]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spv.GL.FMax %arg0, %arg1 : vector<2xf32>
+  // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : vector<2xf32>
+  // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : vector<2xf32>
+  // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MAX]]
+  // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]]
+  %0 = arith.maxf %arg0, %arg1 : vector<2xf32>
+  // CHECK: return %[[SELECT2]]
+  return %0: vector<2xf32>
+}
+
 // Check int vector types.
 // CHECK-LABEL: @int_vector234
 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir
new file mode 100644
index 0000000000000..f6a57660a28e3
--- /dev/null
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-spirv=enable-fast-math -verify-diagnostics %s | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @cmpf_ordered
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
+  // CHECK: %[[T:.+]] = spv.Constant true
+  %0 = arith.cmpf ord, %arg0, %arg1 : f32
+  // CHECK: return %[[T]]
+  return %0: i1
+}
+
+// CHECK-LABEL: @cmpf_unordered
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @cmpf_unordered(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xi1> {
+  // CHECK: %[[F:.+]] = spv.Constant dense<false>
+  %0 = arith.cmpf uno, %arg0, %arg1 : vector<4xf32>
+  // CHECK: return %[[F]]
+  return %0: vector<4xi1>
+}
+
+} // end module
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @minf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[F:.+]] = spv.GL.FMin %[[LHS]], %[[RHS]]
+  %0 = arith.minf %arg0, %arg1 : f32
+  // CHECK: return %[[F]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @maxf
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK: %[[F:.+]] = spv.GL.FMax %[[LHS]], %[[RHS]]
+  %0 = arith.maxf %arg0, %arg1 : vector<4xf32>
+  // CHECK: return %[[F]]
+  return %0: vector<4xf32>
+}
+
+} // end module


        


More information about the Mlir-commits mailing list