[Mlir-commits] [mlir] f525305 - [mlir][spirv] Lower max/min vector.reduction for OpenCL

Lei Zhang llvmlistbot at llvm.org
Tue Sep 20 14:22:53 PDT 2022


Author: Stanley Winata
Date: 2022-09-20T17:22:41-04:00
New Revision: f5253058144aca1e9fcacdca53accdc975e804cf

URL: https://github.com/llvm/llvm-project/commit/f5253058144aca1e9fcacdca53accdc975e804cf
DIFF: https://github.com/llvm/llvm-project/commit/f5253058144aca1e9fcacdca53accdc975e804cf.diff

LOG: [mlir][spirv] Lower max/min vector.reduction for OpenCL

Templatize the vector reduction pattern on the max/min op classes so that
vector.reduction max/min can be lowered to CL ops in addition to GL ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134313

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 23f664c44e1bd..7b1fad1f7a391 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -272,6 +272,8 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
+template <class SPVFMaxOp, class SPVFMinOp, class SPVUMaxOp, class SPVUMinOp,
+          class SPVSMaxOp, class SPVSMinOp>
 struct VectorReductionPattern final
     : public OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -317,18 +319,18 @@ struct VectorReductionPattern final
 
 #define INT_OR_FLOAT_CASE(kind, fop)                                           \
   case vector::CombiningKind::kind:                                            \
-    result = rewriter.create<spirv::fop>(loc, resultType, result, next);       \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
     break
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
 
-        INT_OR_FLOAT_CASE(MAXF, GLFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, GLFMinOp);
-        INT_OR_FLOAT_CASE(MINUI, GLUMinOp);
-        INT_OR_FLOAT_CASE(MINSI, GLSMinOp);
-        INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp);
-        INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp);
+        INT_OR_FLOAT_CASE(MAXF, SPVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPVFMinOp);
+        INT_OR_FLOAT_CASE(MINUI, SPVUMinOp);
+        INT_OR_FLOAT_CASE(MINSI, SPVSMinOp);
+        INT_OR_FLOAT_CASE(MAXUI, SPVUMaxOp);
+        INT_OR_FLOAT_CASE(MAXSI, SPVSMaxOp);
 
       case vector::CombiningKind::AND:
       case vector::CombiningKind::OR:
@@ -403,15 +405,23 @@ struct VectorShuffleOpConvert final
 };
 
 } // namespace
+#define CL_MAX_MIN_OPS                                                         \
+  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
+      spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_MAX_MIN_OPS                                                         \
+  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
+      spirv::GLSMaxOp, spirv::GLSMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
-               VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert,
-               VectorFmaOpConvert<spirv::GLFmaOp>,
-               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-               VectorInsertOpConvert, VectorReductionPattern,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+  patterns.add<
+      VectorBitcastConvert, VectorBroadcastConvert,
+      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
+      VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
+                                                  patterns.getContext());
 }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index f483fc3e0c531..afce0493ebb77 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -33,6 +33,90 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
   return %0 : vector<1xf32>
 }
 
+// CHECK-LABEL: func @cl_reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spv.CL.fmax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.CL.fmax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.CL.fmax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spv.CL.fmin %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.CL.fmin %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.CL.fmin %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxsi
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MAX0:.+]] = spv.CL.s_max %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.CL.s_max %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.CL.s_max %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_minsi
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MIN0:.+]] = spv.CL.s_min %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.CL.s_min %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.CL.s_min %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @cl_reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxui
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MAX0:.+]] = spv.CL.u_max %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.CL.u_max %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.CL.u_max %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_minui
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MIN0:.+]] = spv.CL.u_min %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.CL.u_min %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.CL.u_min %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @cl_reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list