[Mlir-commits] [mlir] 7d23d1e - [mlir][spirv] Lower arith max/min ops to OpenCL ones

Lei Zhang llvmlistbot at llvm.org
Mon Sep 19 10:34:29 PDT 2022


Author: Stanley Winata
Date: 2022-09-19T13:34:09-04:00
New Revision: 7d23d1e640dcde0e90a42353102198b95e20e5f4

URL: https://github.com/llvm/llvm-project/commit/7d23d1e640dcde0e90a42353102198b95e20e5f4
DIFF: https://github.com/llvm/llvm-project/commit/7d23d1e640dcde0e90a42353102198b95e20e5f4.diff

LOG: [mlir][spirv] Lower arith max/min ops to OpenCL ones

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132881

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
    mlir/test/Target/SPIRV/ocl-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 524c375247860..d95ed47bf486e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -332,6 +332,67 @@ def SPV_CLFmaOp : SPV_CLTernaryArithmeticOp<"fma", 26, SPV_Float> {
 
 // -----
 
+def SPV_CLFMaxOp : SPV_CLBinaryArithmeticOp<"fmax", 27, SPV_Float> {
+  let summary = "Return maximum of two floating-point operands";
+
+  let description = [{
+    Returns y if x < y, otherwise it returns x. If one argument is a NaN,
+    Fmax returns the other argument. If both arguments are NaNs, Fmax returns a NaN.
+
+    Result Type, x and y must be floating-point or vector(2,3,4,8,16)
+    of floating-point values.
+
+    All of the operands, including the Result Type operand,
+    must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmax-op ::= ssa-id `=` `spv.CL.fmax` ssa-use `:`
+                float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.CL.fmax %0, %1 : f32
+    %3 = spv.CL.fmax %0, %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_CLFMinOp : SPV_CLBinaryArithmeticOp<"fmin", 28, SPV_Float> {
+  let summary = "Return minimum of two floating-point operands";
+
+  let description = [{
+    Returns y if y < x, otherwise it returns x. If one argument is a NaN,
+    Fmin returns the other argument. If both arguments are NaNs, Fmin returns a NaN.
+
+    Result Type, x and y must be floating-point or vector(2,3,4,8,16)
+    of floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmin-op ::= ssa-id `=` `spv.CL.fmin` ssa-use `:`
+                float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.CL.fmin %0, %1 : f32
+    %3 = spv.CL.fmin %0, %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_CLLogOp : SPV_CLUnaryArithmeticOp<"log", 37, SPV_Float> {
   let summary = "Compute the natural logarithm of x.";
 
@@ -573,4 +634,112 @@ def SPV_CLSAbsOp : SPV_CLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> {
   }];
 }
 
+// -----
+
+def SPV_CLSMaxOp : SPV_CLBinaryArithmeticOp<"s_max", 156, SPV_Integer> {
+  let summary = "Return maximum of two signed integer operands";
+
+  let description = [{
+    Returns y if x < y, otherwise it returns x, where x and y are treated as signed integers.
+
+    Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+    All of the operands, including the Result Type operand, must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    smax-op ::= ssa-id `=` `spv.CL.s_max` ssa-use `:`
+                integer-scalar-vector-type
+    ```
+    #### Example:
+    ```mlir
+    %2 = spv.CL.s_max %0, %1 : i32
+    %3 = spv.CL.s_max %0, %1 : vector<3xi16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_CLUMaxOp : SPV_CLBinaryArithmeticOp<"u_max", 157, SPV_Integer> {
+  let summary = "Return maximum of two unsigned integer operands";
+
+  let description = [{
+    Returns y if x < y, otherwise it returns x, where x and y are treated as unsigned integers.
+
+    Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+    All of the operands, including the Result Type operand, must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    umax-op ::= ssa-id `=` `spv.CL.u_max` ssa-use `:`
+                integer-scalar-vector-type
+    ```
+    #### Example:
+    ```mlir
+    %2 = spv.CL.u_max %0, %1 : i32
+    %3 = spv.CL.u_max %0, %1 : vector<3xi16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_CLSMinOp : SPV_CLBinaryArithmeticOp<"s_min", 158, SPV_Integer> {
+  let summary = "Return minimum of two signed integer operands";
+
+  let description = [{
+    Returns y if y < x, otherwise it returns x, where x and y are treated as signed integers.
+
+    Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+    All of the operands, including the Result Type operand, must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    smin-op ::= ssa-id `=` `spv.CL.s_min` ssa-use `:`
+                integer-scalar-vector-type
+    ```
+    #### Example:
+    ```mlir
+    %2 = spv.CL.s_min %0, %1 : i32
+    %3 = spv.CL.s_min %0, %1 : vector<3xi16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_CLUMinOp : SPV_CLBinaryArithmeticOp<"u_min", 159, SPV_Integer> {
+  let summary = "Return minimum of two unsigned integer operands";
+
+  let description = [{
+    Returns y if y < x, otherwise it returns x, where x and y are treated as unsigned integers.
+
+    Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values.
+
+    All of the operands, including the Result Type operand, must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    umin-op ::= ssa-id `=` `spv.CL.u_min` ssa-use `:`
+                integer-scalar-vector-type
+    ```
+    #### Example:
+    ```mlir
+    %2 = spv.CL.u_min %0, %1 : i32
+    %3 = spv.CL.u_min %0, %1 : vector<3xi16>
+    ```
+  }];
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_CL_OPS

diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 18c3abe7a7793..c6c3f2b366ea8 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -219,7 +219,7 @@ class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts arith.maxf to spv.GL.FMax.
+/// Converts arith.maxf/minf to spv.GL.FMax/FMin or spv.CL.fmax/fmin.
 template <typename Op, typename SPIRVOp>
 class MinMaxFOpPattern final : public OpConversionPattern<Op> {
 public:
@@ -926,9 +926,11 @@ LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
 
   // arith.maxf/minf:
   //   "if one of the arguments is NaN, then the result is also NaN."
   // spv.GL.FMax/FMin:
   //   "which operand is the result is undefined if one of the operands
   //   is a NaN."
+  // spv.CL.fmax/fmin:
+  //   "If one argument is a NaN, Fmax/Fmin returns the other argument."
 
   Location loc = op.getLoc();
   Value spirvOp = rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
@@ -998,7 +1000,14 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
-    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>
+    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,
+
+    MinMaxFOpPattern<arith::MaxFOp, spirv::CLFMaxOp>,
+    MinMaxFOpPattern<arith::MinFOp, spirv::CLFMinOp>,
+    spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
+    spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
+    spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
+    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp>
   >(typeConverter, patterns.getContext());
   // clang-format on
 

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 9430183d37d0b..8b4de3da02083 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -970,12 +970,80 @@ func.func @sitofp(%arg0 : i64) -> f64 {
 
 // -----
 
-// Check OpenCL lowering of arith.remsi
+// Check various lowerings for OpenCL.
 module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.0, [Int16, Kernel], []>, #spv.resource_limits<>>
 } {
 
+// Integer binary ops, including the OpenCL s/u max/min lowerings.
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%a: i32, %b: i32) {
+  // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32
+  %sum = arith.addi %a, %b: i32
+  // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32
+  %diff = arith.subi %a, %b: i32
+  // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32
+  %prod = arith.muli %a, %b: i32
+  // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32
+  %squot = arith.divsi %a, %b: i32
+  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32
+  %uquot = arith.divui %a, %b: i32
+  // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
+  %urem = arith.remui %a, %b: i32
+  // CHECK: spv.CL.s_max %{{.*}}, %{{.*}}: i32
+  %smax = arith.maxsi %a, %b : i32
+  // CHECK: spv.CL.u_max %{{.*}}, %{{.*}}: i32
+  %umax = arith.maxui %a, %b : i32
+  // CHECK: spv.CL.s_min %{{.*}}, %{{.*}}: i32
+  %smin = arith.minsi %a, %b : i32
+  // CHECK: spv.CL.u_min %{{.*}}, %{{.*}}: i32
+  %umin = arith.minui %a, %b : i32
+  return
+}
+
+// Float binary ops that each lower to a single SPIR-V op.
+// CHECK-LABEL: @float32_binary_scalar
+func.func @float32_binary_scalar(%x: f32, %y: f32) {
+  // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
+  %sum = arith.addf %x, %y: f32
+  // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32
+  %diff = arith.subf %x, %y: f32
+  // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32
+  %prod = arith.mulf %x, %y: f32
+  // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32
+  %quot = arith.divf %x, %y: f32
+  // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32
+  %rem = arith.remf %x, %y: f32
+  return
+}
+
+// CHECK-LABEL: @float32_minf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spv.CL.fmin %[[LHS]], %[[RHS]] : f32
+  // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+  // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+  // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MIN]]
+  // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]]
+  %0 = arith.minf %arg0, %arg1 : f32
+  // CHECK: return %[[SELECT2]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maxf_vector
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxf_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spv.CL.fmax %[[LHS]], %[[RHS]] : vector<2xf32>
+  // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : vector<2xf32>
+  // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : vector<2xf32>
+  // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MAX]]
+  // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]]
+  %0 = arith.maxf %arg0, %arg1 : vector<2xf32>
+  // CHECK: return %[[SELECT2]]
+  return %0: vector<2xf32>
+}
+
 // CHECK-LABEL: @scalar_srem
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
 func.func @scalar_srem(%lhs: i32, %rhs: i32) {

diff  --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index d4eba2c29e5f8..c4d5478694245 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -185,3 +185,45 @@ func.func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> ()
   %2 = spv.CL.fma %a, %b, %c : vector<3xf32>
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.CL.{fmax|fmin|s_max|s_min|u_max|u_min}
+//===----------------------------------------------------------------------===//
+
+func.func @fmaxmin(%lhs : f32, %rhs : f32) {
+  // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f32
+  %max = spv.CL.fmax %lhs, %rhs : f32
+  // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f32
+  %min = spv.CL.fmin %lhs, %rhs : f32
+  return
+}
+
+func.func @fmaxminvec(%lhs : vector<3xf16>, %rhs : vector<3xf16>) {
+  // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : vector<3xf16>
+  %max = spv.CL.fmax %lhs, %rhs : vector<3xf16>
+  // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : vector<3xf16>
+  %min = spv.CL.fmin %lhs, %rhs : vector<3xf16>
+  return
+}
+
+func.func @fmaxminf64(%lhs : f64, %rhs : f64) {
+  // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f64
+  %max = spv.CL.fmax %lhs, %rhs : f64
+  // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f64
+  %min = spv.CL.fmin %lhs, %rhs : f64
+  return
+}
+
+func.func @iminmax(%lhs: i32, %rhs: i32) {
+  // CHECK: spv.CL.s_max {{%.*}}, {{%.*}} : i32
+  %smax = spv.CL.s_max %lhs, %rhs : i32
+  // CHECK: spv.CL.u_max {{%.*}}, {{%.*}} : i32
+  %umax = spv.CL.u_max %lhs, %rhs : i32
+  // CHECK: spv.CL.s_min {{%.*}}, {{%.*}} : i32
+  %smin = spv.CL.s_min %lhs, %rhs : i32
+  // CHECK: spv.CL.u_min {{%.*}}, {{%.*}} : i32
+  %umin = spv.CL.u_min %lhs, %rhs : i32
+  return
+}

diff  --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 784f56379e6b4..d5934f074784c 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -44,4 +44,21 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
     %13 = spv.CL.fma %arg0, %arg1, %arg2 : f32
     spv.Return
   }
+
+  spv.func @maxmin(%fa : f32, %fb : f32, %ia : i32, %ib : i32) "None" {
+    // CHECK: {{%.*}} = spv.CL.fmax {{%.*}}, {{%.*}} : f32
+    %fmax = spv.CL.fmax %fa, %fb : f32
+    // CHECK: {{%.*}} = spv.CL.s_max {{%.*}}, {{%.*}} : i32
+    %smax = spv.CL.s_max %ia, %ib : i32
+    // CHECK: {{%.*}} = spv.CL.u_max {{%.*}}, {{%.*}} : i32
+    %umax = spv.CL.u_max %ia, %ib : i32
+
+    // CHECK: {{%.*}} = spv.CL.fmin {{%.*}}, {{%.*}} : f32
+    %fmin = spv.CL.fmin %fa, %fb : f32
+    // CHECK: {{%.*}} = spv.CL.s_min {{%.*}}, {{%.*}} : i32
+    %smin = spv.CL.s_min %ia, %ib : i32
+    // CHECK: {{%.*}} = spv.CL.u_min {{%.*}}, {{%.*}} : i32
+    %umin = spv.CL.u_min %ia, %ib : i32
+    spv.Return
+  }
 }


        


More information about the Mlir-commits mailing list