[Mlir-commits] [mlir] 7d23d1e - [mlir][spirv] Lower arith max/min ops to OpenCL ones
Lei Zhang
llvmlistbot at llvm.org
Mon Sep 19 10:34:29 PDT 2022
Author: Stanley Winata
Date: 2022-09-19T13:34:09-04:00
New Revision: 7d23d1e640dcde0e90a42353102198b95e20e5f4
URL: https://github.com/llvm/llvm-project/commit/7d23d1e640dcde0e90a42353102198b95e20e5f4
DIFF: https://github.com/llvm/llvm-project/commit/7d23d1e640dcde0e90a42353102198b95e20e5f4.diff
LOG: [mlir][spirv] Lower arith max/min ops to OpenCL ones
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D132881
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
mlir/test/Target/SPIRV/ocl-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 524c375247860..d95ed47bf486e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -332,6 +332,67 @@ def SPV_CLFmaOp : SPV_CLTernaryArithmeticOp<"fma", 26, SPV_Float> {
// -----
+// arith.maxf is lowered to this op by ArithmeticToSPIRV. Because fmax returns
+// the non-NaN operand, that lowering adds IsNan/Select fix-ups so the result
+// still propagates NaN as arith.maxf requires.
+def SPV_CLFMaxOp : SPV_CLBinaryArithmeticOp<"fmax", 27, SPV_Float> {
+ let summary = "Return maximum of two floating-point operands";
+
+ let description = [{
+ Returns y if x < y, otherwise it returns x. If one argument is a NaN,
+ Fmax returns the other argument. If both arguments are NaNs, Fmax returns a NaN.
+
+ Result Type, x and y must be floating-point or vector(2,3,4,8,16)
+ of floating-point values.
+
+ All of the operands, including the Result Type operand,
+ must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ fmax-op ::= ssa-id `=` `spv.CL.fmax` ssa-use `:`
+ float-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.CL.fmax %0, %1 : f32
+ %3 = spv.CL.fmax %0, %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
+// arith.minf is lowered to this op by ArithmeticToSPIRV. Because fmin returns
+// the non-NaN operand, that lowering adds IsNan/Select fix-ups so the result
+// still propagates NaN as arith.minf requires.
+def SPV_CLFMinOp : SPV_CLBinaryArithmeticOp<"fmin", 28, SPV_Float> {
+ let summary = "Return minimum of two floating-point operands";
+
+ let description = [{
+ Returns y if y < x, otherwise it returns x. If one argument is a NaN, Fmin returns the other argument.
+ If both arguments are NaNs, Fmin returns a NaN.
+
+ Result Type, x and y must be floating-point or vector(2,3,4,8,16) of floating-point values.
+
+ All of the operands, including the Result Type operand, must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ fmin-op ::= ssa-id `=` `spv.CL.fmin` ssa-use `:`
+ float-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.CL.fmin %0, %1 : f32
+ %3 = spv.CL.fmin %0, %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
def SPV_CLLogOp : SPV_CLUnaryArithmeticOp<"log", 37, SPV_Float> {
let summary = "Compute the natural logarithm of x.";
@@ -573,4 +634,110 @@ def SPV_CLSAbsOp : SPV_CLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> {
}];
}
+// -----
+
+// arith.maxsi is lowered to this op by ArithmeticToSPIRV.
+def SPV_CLSMaxOp : SPV_CLBinaryArithmeticOp<"s_max", 156, SPV_Integer> {
+ let summary = "Return maximum of two signed integer operands";
+
+ let description = [{
+ Returns y if x < y, otherwise it returns x, where x and y are treated as signed integers.
+
+ Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+ All of the operands, including the Result Type operand, must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ integer-scalar-vector-type ::= integer-type |
+ `vector<` integer-literal `x` integer-type `>`
+ smax-op ::= ssa-id `=` `spv.CL.s_max` ssa-use `:`
+ integer-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.CL.s_max %0, %1 : i32
+ %3 = spv.CL.s_max %0, %1 : vector<3xi16>
+ ```
+ }];
+}
+
+// -----
+
+// arith.maxui is lowered to this op by ArithmeticToSPIRV.
+def SPV_CLUMaxOp : SPV_CLBinaryArithmeticOp<"u_max", 157, SPV_Integer> {
+ let summary = "Return maximum of two unsigned integer operands";
+
+ let description = [{
+ Returns y if x < y, otherwise it returns x, where x and y are treated as unsigned integers.
+
+ Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+ All of the operands, including the Result Type operand, must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ integer-scalar-vector-type ::= integer-type |
+ `vector<` integer-literal `x` integer-type `>`
+ umax-op ::= ssa-id `=` `spv.CL.u_max` ssa-use `:`
+ integer-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.CL.u_max %0, %1 : i32
+ %3 = spv.CL.u_max %0, %1 : vector<3xi16>
+ ```
+ }];
+}
+
+// -----
+
+// arith.minsi is lowered to this op by ArithmeticToSPIRV.
+def SPV_CLSMinOp : SPV_CLBinaryArithmeticOp<"s_min", 158, SPV_Integer> {
+ let summary = "Return minimum of two signed integer operands";
+
+ let description = [{
+ Returns y if y < x, otherwise it returns x, where x and y are treated as signed integers.
+
+ Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+ All of the operands, including the Result Type operand, must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ integer-scalar-vector-type ::= integer-type |
+ `vector<` integer-literal `x` integer-type `>`
+ smin-op ::= ssa-id `=` `spv.CL.s_min` ssa-use `:`
+ integer-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.CL.s_min %0, %1 : i32
+ %3 = spv.CL.s_min %0, %1 : vector<3xi16>
+ ```
+ }];
+}
+
+// -----
+
+// arith.minui is lowered to this op by ArithmeticToSPIRV.
+def SPV_CLUMinOp : SPV_CLBinaryArithmeticOp<"u_min", 159, SPV_Integer> {
+ let summary = "Return minimum of two unsigned integer operands";
+
+ let description = [{
+ Returns y if y < x, otherwise it returns x, where x and y are treated as unsigned integers.
+
+ Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+ All of the operands, including the Result Type operand, must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ integer-scalar-vector-type ::= integer-type |
+ `vector<` integer-literal `x` integer-type `>`
+ umin-op ::= ssa-id `=` `spv.CL.u_min` ssa-use `:`
+ integer-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.CL.u_min %0, %1 : i32
+ %3 = spv.CL.u_min %0, %1 : vector<3xi16>
+ ```
+ }];
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_CL_OPS
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 18c3abe7a7793..c6c3f2b366ea8 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -219,7 +219,7 @@ class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts arith.maxf to spv.GL.FMax.
+/// Converts arith.maxf/minf to spv.GL.FMax/FMin or spv.CL.fmax/fmin.
template <typename Op, typename SPIRVOp>
class MinMaxFOpPattern final : public OpConversionPattern<Op> {
public:
@@ -926,9 +926,11 @@ LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
// arith.maxf/minf:
// "if one of the arguments is NaN, then the result is also NaN."
- // spv.GL.FMax/FMin:
+ // spv.GL.FMax/FMin:
// "which operand is the result is undefined if one of the operands
// is a NaN."
+ // spv.CL.fmax/fmin:
+ // "If one argument is a NaN, [fmax/fmin] returns the other argument."
Location loc = op.getLoc();
Value spirvOp = rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
@@ -998,7 +1000,14 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
- spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>
+ spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,
+
+ MinMaxFOpPattern<arith::MaxFOp, spirv::CLFMaxOp>,
+ MinMaxFOpPattern<arith::MinFOp, spirv::CLFMinOp>,
+ spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
+ spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
+ spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
+ spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp>
>(typeConverter, patterns.getContext());
// clang-format on
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 9430183d37d0b..8b4de3da02083 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -970,12 +970,80 @@ func.func @sitofp(%arg0 : i64) -> f64 {
// -----
-// Check OpenCL lowering of arith.remsi
+// Check various lowerings for OpenCL.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Int16, Kernel], []>, #spv.resource_limits<>>
} {
+// Check integer operation conversions.
+// Under this Kernel target env, plain integer arithmetic is expected to lower
+// to core SPIR-V ops, while signed/unsigned max/min lower to the OpenCL
+// extended instructions.
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%lhs: i32, %rhs: i32) {
+ // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32
+ %0 = arith.addi %lhs, %rhs: i32
+ // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32
+ %1 = arith.subi %lhs, %rhs: i32
+ // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32
+ %2 = arith.muli %lhs, %rhs: i32
+ // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32
+ %3 = arith.divsi %lhs, %rhs: i32
+ // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32
+ %4 = arith.divui %lhs, %rhs: i32
+ // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
+ %5 = arith.remui %lhs, %rhs: i32
+ // Integer max/min: OpenCL.std s_max/u_max/s_min/u_min.
+ // CHECK: spv.CL.s_max %{{.*}}, %{{.*}}: i32
+ %6 = arith.maxsi %lhs, %rhs : i32
+ // CHECK: spv.CL.u_max %{{.*}}, %{{.*}}: i32
+ %7 = arith.maxui %lhs, %rhs : i32
+ // CHECK: spv.CL.s_min %{{.*}}, %{{.*}}: i32
+ %8 = arith.minsi %lhs, %rhs : i32
+ // CHECK: spv.CL.u_min %{{.*}}, %{{.*}}: i32
+ %9 = arith.minui %lhs, %rhs : i32
+ return
+}
+
+// Check float binary operation conversions.
+// Each arith float binary op here maps one-to-one onto a core SPIR-V op.
+// CHECK-LABEL: @float32_binary_scalar
+func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
+ // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
+ %0 = arith.addf %lhs, %rhs: f32
+ // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32
+ %1 = arith.subf %lhs, %rhs: f32
+ // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32
+ %2 = arith.mulf %lhs, %rhs: f32
+ // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32
+ %3 = arith.divf %lhs, %rhs: f32
+ // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32
+ %4 = arith.remf %lhs, %rhs: f32
+ return
+}
+
+// arith.minf on f32 scalars. spv.CL.fmin returns the non-NaN operand, so the
+// lowering appends IsNan/Select fix-ups to make NaN inputs propagate.
+// CHECK-LABEL: @float32_minf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[MIN:.+]] = spv.CL.fmin %[[LHS]], %[[RHS]] : f32
+ // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+ // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+ // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MIN]]
+ // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]]
+ %0 = arith.minf %arg0, %arg1 : f32
+ // CHECK: return %[[SELECT2]]
+ return %0: f32
+}
+
+// arith.maxf on vector<2xf32> operands (despite the "_scalar" suffix).
+// spv.CL.fmax returns the non-NaN operand, so the lowering appends
+// IsNan/Select fix-ups to make NaN inputs propagate.
+// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[MAX:.+]] = spv.CL.fmax %[[LHS]], %[[RHS]] : vector<2xf32>
+ // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : vector<2xf32>
+ // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : vector<2xf32>
+ // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MAX]]
+ // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]]
+ %0 = arith.maxf %arg0, %arg1 : vector<2xf32>
+ // CHECK: return %[[SELECT2]]
+ return %0: vector<2xf32>
+}
+
// CHECK-LABEL: @scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @scalar_srem(%lhs: i32, %rhs: i32) {
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index d4eba2c29e5f8..c4d5478694245 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -185,3 +185,45 @@ func.func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> ()
%2 = spv.CL.fma %a, %b, %c : vector<3xf32>
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.CL.{F|S|U}{Max|Min}
+//===----------------------------------------------------------------------===//
+
+// Scalar f32 operands for spv.CL.fmax and spv.CL.fmin.
+func.func @fmaxmin(%arg0 : f32, %arg1 : f32) {
+ // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f32
+ %1 = spv.CL.fmax %arg0, %arg1 : f32
+ // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f32
+ %2 = spv.CL.fmin %arg0, %arg1 : f32
+ return
+}
+
+// Vector operands (vector<3xf16>); the op definitions accept float scalars or
+// vectors of 2, 3, 4, 8 or 16 elements.
+func.func @fmaxminvec(%arg0 : vector<3xf16>, %arg1 : vector<3xf16>) {
+ // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : vector<3xf16>
+ %1 = spv.CL.fmax %arg0, %arg1 : vector<3xf16>
+ // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : vector<3xf16>
+ %2 = spv.CL.fmin %arg0, %arg1 : vector<3xf16>
+ return
+}
+
+// Scalar f64 operands for spv.CL.fmax and spv.CL.fmin.
+func.func @fmaxminf64(%arg0 : f64, %arg1 : f64) {
+ // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f64
+ %1 = spv.CL.fmax %arg0, %arg1 : f64
+ // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f64
+ %2 = spv.CL.fmin %arg0, %arg1 : f64
+ return
+}
+
+// Scalar i32 operands for the signed (s_*) and unsigned (u_*) integer
+// max/min ops.
+func.func @iminmax(%arg0: i32, %arg1: i32) {
+ // CHECK: spv.CL.s_max {{%.*}}, {{%.*}} : i32
+ %1 = spv.CL.s_max %arg0, %arg1 : i32
+ // CHECK: spv.CL.u_max {{%.*}}, {{%.*}} : i32
+ %2 = spv.CL.u_max %arg0, %arg1 : i32
+ // CHECK: spv.CL.s_min {{%.*}}, {{%.*}} : i32
+ %3 = spv.CL.s_min %arg0, %arg1 : i32
+ // CHECK: spv.CL.u_min {{%.*}}, {{%.*}} : i32
+ %4 = spv.CL.u_min %arg0, %arg1 : i32
+ return
+}
diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 784f56379e6b4..d5934f074784c 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -44,4 +44,21 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
%13 = spv.CL.fma %arg0, %arg1, %arg2 : f32
spv.Return
}
+
+ // All six OpenCL max/min ops: fmax/fmin on scalar f32 operands and
+ // s_max/u_max/s_min/u_min on scalar i32 operands.
+ spv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" {
+ // CHECK: {{%.*}} = spv.CL.fmax {{%.*}}, {{%.*}} : f32
+ %1 = spv.CL.fmax %arg0, %arg1 : f32
+ // CHECK: {{%.*}} = spv.CL.s_max {{%.*}}, {{%.*}} : i32
+ %2 = spv.CL.s_max %arg2, %arg3 : i32
+ // CHECK: {{%.*}} = spv.CL.u_max {{%.*}}, {{%.*}} : i32
+ %3 = spv.CL.u_max %arg2, %arg3 : i32
+
+ // CHECK: {{%.*}} = spv.CL.fmin {{%.*}}, {{%.*}} : f32
+ %4 = spv.CL.fmin %arg0, %arg1 : f32
+ // CHECK: {{%.*}} = spv.CL.s_min {{%.*}}, {{%.*}} : i32
+ %5 = spv.CL.s_min %arg2, %arg3 : i32
+ // CHECK: {{%.*}} = spv.CL.u_min {{%.*}}, {{%.*}} : i32
+ %6 = spv.CL.u_min %arg2, %arg3 : i32
+ spv.Return
+ }
}
More information about the Mlir-commits
mailing list