[llvm-branch-commits] [mlir] e27197f - [mlir][spirv] Define spv.IsNan/spv.IsInf and add lowerings
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 22 10:14:12 PST 2021
Author: Lei Zhang
Date: 2021-01-22T13:09:33-05:00
New Revision: e27197f3605450c372ddc71922d0e9982b30e115
URL: https://github.com/llvm/llvm-project/commit/e27197f3605450c372ddc71922d0e9982b30e115
DIFF: https://github.com/llvm/llvm-project/commit/e27197f3605450c372ddc71922d0e9982b30e115.diff
LOG: [mlir][spirv] Define spv.IsNan/spv.IsInf and add lowerings
spv.Ordered/spv.Unordered require the OpenCL Kernel capability.
For the Vulkan Shader capability, we should use spv.IsNan to check
whether a number is NaN.
Add a new pattern for converting `std.cmpf ord|uno` to spv.IsNan
and bump the pattern converting to spv.Ordered/spv.Unordered
to a higher benefit. The SPIR-V target environment will properly
select between these two patterns.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D95237
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
mlir/test/Target/SPIRV/logical-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c369304cf18b..347b65a7739e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3216,6 +3216,8 @@ def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
+def SPV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>;
def SPV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>;
def SPV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>;
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
@@ -3332,15 +3334,15 @@ def SPV_OpcodeAttr :
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
- SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
- SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
- SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
- SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
- SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
- SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
- SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
- SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
- SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+ SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
+ SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+ SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+ SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+ SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+ SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+ SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+ SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+ SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 0516e70f87c4..019b63f3a582 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -41,6 +41,11 @@ class SPV_LogicalUnaryOp<string mnemonic, Type operandType,
SameOperandsAndResultShape])> {
let parser = [{ return ::parseLogicalUnaryOp(parser, result); }];
let printer = [{ return ::printLogicalOp(getOperation(), p); }];
+
+ let builders = [
+ OpBuilderDAG<(ins "Value":$value),
+ [{::buildLogicalUnaryOp($_builder, $_state, value);}]>
+ ];
}
// -----
@@ -507,6 +512,70 @@ def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual",
// -----
+def SPV_IsInfOp : SPV_LogicalUnaryOp<"IsInf", SPV_Float, []> {
+ let summary = "Result is true if x is an IEEE Inf, otherwise result is false";
+
+ let description = [{
+ Result Type must be a scalar or vector of Boolean type.
+
+ x must be a scalar or vector of floating-point type. It must have the
+ same number of components as Result Type.
+
+ Results are computed per component.
+
+ <!-- End of AutoGen section -->
+
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ isinf-op ::= ssa-id `=` `spv.IsInf` ssa-use
+ `:` float-scalar-vector-type
+ ```
+
+ #### Example:
+
+ ```mlir
+ %2 = spv.IsInf %0 : f32
+ %3 = spv.IsInf %1 : vector<4xf32>
+ ```
+ }];
+}
+
+// -----
+
+def SPV_IsNanOp : SPV_LogicalUnaryOp<"IsNan", SPV_Float, []> {
+ let summary = [{
+ Result is true if x is an IEEE NaN, otherwise result is false.
+ }];
+
+ let description = [{
+ Result Type must be a scalar or vector of Boolean type.
+
+ x must be a scalar or vector of floating-point type. It must have the
+ same number of components as Result Type.
+
+ Results are computed per component.
+
+ <!-- End of AutoGen section -->
+
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ isnan-op ::= ssa-id `=` `spv.IsNan` ssa-use
+ `:` float-scalar-vector-type
+ ```
+
+ #### Example:
+
+ ```mlir
+ %2 = spv.IsNan %0 : f32
+ %3 = spv.IsNan %1 : vector<4xf32>
+ ```
+ }];
+}
+
+// -----
+
def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd",
SPV_Bool,
[Commutative,
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 95bb0eca4496..041495e2b7cb 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -386,6 +386,28 @@ class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts the NaN checks `cmpf ord` / `cmpf uno` to spv.Ordered /
+/// spv.Unordered. This pattern requires Kernel capability.
+class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> {
+public:
+ using OpConversionPattern<CmpFOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts the NaN checks `cmpf ord` / `cmpf uno` to spv.IsNan combined with
+/// spv.LogicalOr (and spv.LogicalNot for `ord`); needs no extra capability.
+class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> {
+public:
+ using OpConversionPattern<CmpFOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts integer compare operation on i1 type operands to SPIR-V ops.
class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
@@ -730,7 +752,6 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
- DISPATCH(CmpFPredicate::ORD, spirv::OrderedOp);
// Unordered.
DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
@@ -738,7 +759,6 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
- DISPATCH(CmpFPredicate::UNO, spirv::UnorderedOp);
#undef DISPATCH
@@ -748,6 +768,47 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
return failure();
}
+LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
+ CmpFOp cmpFOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ CmpFOpAdaptor cmpFOpOperands(operands);
+
+ if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
+ rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(),
+ cmpFOpOperands.rhs());
+ return success();
+ }
+
+ if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
+ rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(
+ cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
+ return success();
+ }
+
+ return failure();
+}
+
+LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
+ CmpFOp cmpFOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
+ cmpFOp.getPredicate() != CmpFPredicate::UNO)
+ return failure();
+
+ CmpFOpAdaptor cmpFOpOperands(operands);
+ Location loc = cmpFOp.getLoc();
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs());
+
+ Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+ if (cmpFOp.getPredicate() == CmpFPredicate::ORD) // ord(x, y) == !uno(x, y)
+ replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+
+ rewriter.replaceOp(cmpFOp, replace);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
@@ -1102,7 +1163,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
SignedRemIOpPattern, XOrOpPattern,
// Comparison patterns
- BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,
+ BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
// Constant patterns
ConstantCompositeOpPattern, ConstantScalarOpPattern,
@@ -1124,5 +1185,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
context);
+
+ // Give CmpFOpNanKernelPattern a higher benefit so it is tried first and wins
+ // when Kernel capability is available; else CmpFOpNanNonePattern is used.
+ patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
+ /*benefit=*/2);
}
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3d99696d6882..4506447b0503 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -900,6 +900,16 @@ static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
state.addOperands({lhs, rhs});
}
+static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
+ Value value) {
+ Type boolType = builder.getI1Type();
+ if (auto vecType = value.getType().dyn_cast<VectorType>())
+ boolType = VectorType::get(vecType.getShape(), boolType);
+ state.addTypes(boolType); // i1, or vector of i1 with the operand's shape.
+
+ state.addOperands(value);
+}
+
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index a33db1dd42cf..8ae93c2e4b9b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -301,6 +301,7 @@ func @cmpf(%arg0 : f32, %arg1 : f32) {
// -----
+// With Kernel capability, we can convert NaN check to spv.Ordered/spv.Unordered.
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}>
} {
@@ -318,6 +319,31 @@ func @cmpf(%arg0 : f32, %arg1 : f32) {
// -----
+// Without Kernel capability, NaN checks are lowered to spv.IsNan + logical ops.
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: @cmpf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func @cmpf(%arg0 : f32, %arg1 : f32) {
+ // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+ // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+ // CHECK-NEXT: %[[OR:.+]] = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1
+ // CHECK-NEXT: %{{.+}} = spv.LogicalNot %[[OR]] : i1
+ %0 = cmpf ord, %arg0, %arg1 : f32
+
+ // CHECK-NEXT: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+ // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+ // CHECK-NEXT: %{{.+}} = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1
+ %1 = cmpf uno, %arg0, %arg1 : f32
+ return
+}
+
+} // end module
+
+// -----
+
//===----------------------------------------------------------------------===//
// std.cmpi
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index baf8b45d7eaf..b2c34b85f194 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -32,6 +32,42 @@ func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi
// -----
+//===----------------------------------------------------------------------===//
+// spv.IsInf
+//===----------------------------------------------------------------------===//
+
+func @isinf_scalar(%arg0: f32) -> i1 {
+ // CHECK: spv.IsInf {{.*}} : f32
+ %0 = spv.IsInf %arg0 : f32
+ return %0 : i1
+}
+
+func @isinf_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+ // CHECK: spv.IsInf {{.*}} : vector<2xf32>
+ %0 = spv.IsInf %arg0 : vector<2xf32>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.IsNan
+//===----------------------------------------------------------------------===//
+
+func @isnan_scalar(%arg0: f32) -> i1 {
+ // CHECK: spv.IsNan {{.*}} : f32
+ %0 = spv.IsNan %arg0 : f32
+ return %0 : i1
+}
+
+func @isnan_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+ // CHECK: spv.IsNan {{.*}} : vector<2xf32>
+ %0 = spv.IsNan %arg0 : vector<2xf32>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.LogicalAnd
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index 000cf49d733a..bd92074de39f 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -80,6 +80,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%13 = spv.Ordered %arg0, %arg1 : f32
// CHECK: spv.Unordered
%14 = spv.Unordered %arg0, %arg1 : f32
+ // CHECK: spv.IsNan
+ %15 = spv.IsNan %arg0 : f32
+ // CHECK: spv.IsInf
+ %16 = spv.IsInf %arg1 : f32
spv.Return
}
}
More information about the llvm-branch-commits
mailing list