[Mlir-commits] [mlir] e27197f - [mlir][spirv] Define spv.IsNan/spv.IsInf and add lowerings

Lei Zhang llvmlistbot at llvm.org
Fri Jan 22 10:09:41 PST 2021


Author: Lei Zhang
Date: 2021-01-22T13:09:33-05:00
New Revision: e27197f3605450c372ddc71922d0e9982b30e115

URL: https://github.com/llvm/llvm-project/commit/e27197f3605450c372ddc71922d0e9982b30e115
DIFF: https://github.com/llvm/llvm-project/commit/e27197f3605450c372ddc71922d0e9982b30e115.diff

LOG: [mlir][spirv] Define spv.IsNan/spv.IsInf and add lowerings

spv.Ordered/spv.Unordered require the OpenCL Kernel capability.
For the Vulkan Shader capability, we should use spv.IsNan to check
whether a number is NaN.

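Add a new pattern for converting `std.cmpf ord|uno` to spv.IsNan
and bump the pattern converting to spv.Ordered/spv.Unordered
to a higher benefit. The SPIR-V target environment properly selects
between these two patterns: without the Kernel capability, the
spv.Ordered/spv.Unordered result is illegal, so the lower-benefit
spv.IsNan pattern is used instead.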

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D95237

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
    mlir/test/Target/SPIRV/logical-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c369304cf18b..347b65a7739e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3216,6 +3216,8 @@ def SPV_OC_OpFRem                      : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod                      : I32EnumAttrCase<"OpFMod", 141>;
 def SPV_OC_OpMatrixTimesScalar         : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpMatrixTimesMatrix         : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPV_OC_OpIsNan                     : I32EnumAttrCase<"OpIsNan", 156>;
+def SPV_OC_OpIsInf                     : I32EnumAttrCase<"OpIsInf", 157>;
 def SPV_OC_OpOrdered                   : I32EnumAttrCase<"OpOrdered", 162>;
 def SPV_OC_OpUnordered                 : I32EnumAttrCase<"OpUnordered", 163>;
 def SPV_OC_OpLogicalEqual              : I32EnumAttrCase<"OpLogicalEqual", 164>;
@@ -3332,15 +3334,15 @@ def SPV_OpcodeAttr :
       SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
       SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
       SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
+      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 0516e70f87c4..019b63f3a582 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -41,6 +41,11 @@ class SPV_LogicalUnaryOp<string mnemonic, Type operandType,
                                        SameOperandsAndResultShape])> {
   let parser = [{ return ::parseLogicalUnaryOp(parser, result); }];
   let printer = [{ return ::printLogicalOp(getOperation(), p); }];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$value),
+    [{::buildLogicalUnaryOp($_builder, $_state, value);}]>
+  ];
 }
 
 // -----
@@ -507,6 +512,70 @@ def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual",
 
 // -----
 
+def SPV_IsInfOp : SPV_LogicalUnaryOp<"IsInf", SPV_Float, []> {
+  let summary = "Result is true if x is an IEEE Inf, otherwise result is false";
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+    x must be a scalar or vector of floating-point type.  It must have the
+    same number of components as Result Type.
+
+     Results are computed per component.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    isinf-op ::= ssa-id `=` `spv.IsInf` ssa-use
+                            `:` float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.IsInf %0: f32
+    %3 = spv.IsInf %1: vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IsNanOp : SPV_LogicalUnaryOp<"IsNan", SPV_Float, []> {
+  let summary = [{
+    Result is true if x is an IEEE NaN, otherwise result is false.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+    x must be a scalar or vector of floating-point type.  It must have the
+    same number of components as Result Type.
+
+     Results are computed per component.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    isnan-op ::= ssa-id `=` `spv.IsNan` ssa-use
+                            `:` float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.IsNan %0: f32
+    %3 = spv.IsNan %1: vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd",
                                            SPV_Bool,
                                            [Commutative,

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 95bb0eca4496..041495e2b7cb 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -386,6 +386,28 @@ class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts floating point NaN check to SPIR-V ops. This pattern requires
+/// Kernel capability.
+class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> {
+public:
+  using OpConversionPattern<CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts floating point NaN check to SPIR-V ops. This pattern does not
+/// require additional capability.
+class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> {
+public:
+  using OpConversionPattern<CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
 class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
 public:
@@ -730,7 +752,6 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
     DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
     DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
     DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
-    DISPATCH(CmpFPredicate::ORD, spirv::OrderedOp);
     // Unordered.
     DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
     DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
@@ -738,7 +759,6 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
     DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
     DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
     DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
-    DISPATCH(CmpFPredicate::UNO, spirv::UnorderedOp);
 
 #undef DISPATCH
 
@@ -748,6 +768,47 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
   return failure();
 }
 
+LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
+    CmpFOp cmpFOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  CmpFOpAdaptor cmpFOpOperands(operands);
+
+  if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
+    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(),
+                                                  cmpFOpOperands.rhs());
+    return success();
+  }
+
+  if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
+    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(
+        cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
+    return success();
+  }
+
+  return failure();
+}
+
+LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
+    CmpFOp cmpFOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
+      cmpFOp.getPredicate() != CmpFPredicate::UNO)
+    return failure();
+
+  CmpFOpAdaptor cmpFOpOperands(operands);
+  Location loc = cmpFOp.getLoc();
+
+  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs());
+  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs());
+
+  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+  if (cmpFOp.getPredicate() == CmpFPredicate::ORD)
+    replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+
+  rewriter.replaceOp(cmpFOp, replace);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
@@ -1102,7 +1163,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       SignedRemIOpPattern, XOrOpPattern,
 
       // Comparison patterns
-      BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,
+      BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
 
       // Constant patterns
       ConstantCompositeOpPattern, ConstantScalarOpPattern,
@@ -1124,5 +1185,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
       TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
                                                           context);
+
+  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
+  // capability is available.
+  patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
+                                          /*benefit=*/2);
 }
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3d99696d6882..4506447b0503 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -900,6 +900,16 @@ static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
   state.addOperands({lhs, rhs});
 }
 
+static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
+                                Value value) {
+  Type boolType = builder.getI1Type();
+  if (auto vecType = value.getType().dyn_cast<VectorType>())
+    boolType = VectorType::get(vecType.getShape(), boolType);
+  state.addTypes(boolType);
+
+  state.addOperands(value);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.AccessChainOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index a33db1dd42cf..8ae93c2e4b9b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -301,6 +301,7 @@ func @cmpf(%arg0 : f32, %arg1 : f32) {
 
 // -----
 
+// With Kernel capability, we can convert NaN check to spv.Ordered/spv.Unordered.
 module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}>
 } {
@@ -318,6 +319,31 @@ func @cmpf(%arg0 : f32, %arg1 : f32) {
 
 // -----
 
+// Without Kernel capability, we need to convert NaN check to spv.IsNan.
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: @cmpf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func @cmpf(%arg0 : f32, %arg1 : f32) {
+  // CHECK:      %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+  // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+  // CHECK-NEXT: %[[OR:.+]] = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1
+  // CHECK-NEXT: %{{.+}} = spv.LogicalNot %[[OR]] : i1
+  %0 = cmpf ord, %arg0, %arg1 : f32
+
+  // CHECK-NEXT: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+  // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+  // CHECK-NEXT: %{{.+}} = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1
+  %1 = cmpf uno, %arg0, %arg1 : f32
+  return
+}
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std.cmpi
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index baf8b45d7eaf..b2c34b85f194 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -32,6 +32,40 @@ func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.IsInf
+//===----------------------------------------------------------------------===//
+
+func @isinf_scalar(%arg0: f32) -> i1 {
+  // CHECK: spv.IsInf {{.*}} : f32
+  %0 = spv.IsInf %arg0 : f32
+  return %0 : i1
+}
+
+func @isinf_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+  // CHECK: spv.IsInf {{.*}} : vector<2xf32>
+  %0 = spv.IsInf %arg0 : vector<2xf32>
+  return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.IsNan
+//===----------------------------------------------------------------------===//
+
+func @isnan_scalar(%arg0: f32) -> i1 {
+  // CHECK: spv.IsNan {{.*}} : f32
+  %0 = spv.IsNan %arg0 : f32
+  return %0 : i1
+}
+
+func @isnan_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+  // CHECK: spv.IsNan {{.*}} : vector<2xf32>
+  %0 = spv.IsNan %arg0 : vector<2xf32>
+  return %0 : vector<2xi1>
+}
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalAnd
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index 000cf49d733a..bd92074de39f 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -80,6 +80,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %13 = spv.Ordered %arg0, %arg1 : f32
     // CHECK: spv.Unordered
     %14 = spv.Unordered %arg0, %arg1 : f32
+    // CHECK: spv.IsNan
+    %15 = spv.IsNan %arg0 : f32
+    // CHECK: spv.IsInf
+    %16 = spv.IsInf %arg1 : f32
     spv.Return
   }
 }


        


More information about the Mlir-commits mailing list