[Mlir-commits] [mlir] f525305 - [mlir][spirv] Lower max/min vector.reduction for OpenCL
Lei Zhang
llvmlistbot at llvm.org
Tue Sep 20 14:22:53 PDT 2022
Author: Stanley Winata
Date: 2022-09-20T17:22:41-04:00
New Revision: f5253058144aca1e9fcacdca53accdc975e804cf
URL: https://github.com/llvm/llvm-project/commit/f5253058144aca1e9fcacdca53accdc975e804cf
DIFF: https://github.com/llvm/llvm-project/commit/f5253058144aca1e9fcacdca53accdc975e804cf.diff
LOG: [mlir][spirv] Lower max/min vector.reduction for OpenCL
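Templatize the vector reduction conversion pattern over the SPIR-V max/min
op types, so that vector.reduction max/min can be lowered to OpenCL (CL) ops
in addition to the existing GLSL (GL) ops.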
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D134313
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 23f664c44e1bd..7b1fad1f7a391 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -272,6 +272,8 @@ struct VectorInsertStridedSliceOpConvert final
}
};
+template <class SPVFMaxOp, class SPVFMinOp, class SPVUMaxOp, class SPVUMinOp,
+ class SPVSMaxOp, class SPVSMinOp>
struct VectorReductionPattern final
: public OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;
@@ -317,18 +319,18 @@ struct VectorReductionPattern final
#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
- result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
+ result = rewriter.create<fop>(loc, resultType, result, next); \
break
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
- INT_OR_FLOAT_CASE(MAXF, GLFMaxOp);
- INT_OR_FLOAT_CASE(MINF, GLFMinOp);
- INT_OR_FLOAT_CASE(MINUI, GLUMinOp);
- INT_OR_FLOAT_CASE(MINSI, GLSMinOp);
- INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp);
- INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp);
+ INT_OR_FLOAT_CASE(MAXF, SPVFMaxOp);
+ INT_OR_FLOAT_CASE(MINF, SPVFMinOp);
+ INT_OR_FLOAT_CASE(MINUI, SPVUMinOp);
+ INT_OR_FLOAT_CASE(MINSI, SPVSMinOp);
+ INT_OR_FLOAT_CASE(MAXUI, SPVUMaxOp);
+ INT_OR_FLOAT_CASE(MAXSI, SPVSMaxOp);
case vector::CombiningKind::AND:
case vector::CombiningKind::OR:
@@ -403,15 +405,23 @@ struct VectorShuffleOpConvert final
};
} // namespace
+#define CL_MAX_MIN_OPS \
+ spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
+ spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_MAX_MIN_OPS \
+ spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
+ spirv::GLSMaxOp, spirv::GLSMinOp
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert,
- VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern>(typeConverter, patterns.getContext());
+ patterns.add<
+ VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+ VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
+ VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index f483fc3e0c531..afce0493ebb77 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -33,6 +33,90 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
return %0 : vector<1xf32>
}
+// CHECK-LABEL: func @cl_reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spv.CL.fmax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spv.CL.fmax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spv.CL.fmax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spv.CL.fmin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spv.CL.fmin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spv.CL.fmin %[[MIN1]], %[[S]]
+// CHECK: return %[[MIN2]]
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxsi
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MAX0:.+]] = spv.CL.s_max %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spv.CL.s_max %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spv.CL.s_max %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @cl_reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_minsi
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MIN0:.+]] = spv.CL.s_min %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spv.CL.s_min %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spv.CL.s_min %[[MIN1]], %[[S]]
+// CHECK: return %[[MIN2]]
+func.func @cl_reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxui
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MAX0:.+]] = spv.CL.u_max %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spv.CL.u_max %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spv.CL.u_max %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @cl_reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_minui
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MIN0:.+]] = spv.CL.u_min %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spv.CL.u_min %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spv.CL.u_min %[[MIN1]], %[[S]]
+// CHECK: return %[[MIN2]]
+func.func @cl_reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
} // end module
// -----
More information about the Mlir-commits mailing list