[Mlir-commits] [mlir] 0f6103a - [mlir][spirv] Support more max/min vector.reduction
Lei Zhang
llvmlistbot at llvm.org
Fri Sep 2 14:23:51 PDT 2022
Author: Lei Zhang
Date: 2022-09-02T17:21:57-04:00
New Revision: 0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb
URL: https://github.com/llvm/llvm-project/commit/0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb
DIFF: https://github.com/llvm/llvm-project/commit/0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb.diff
LOG: [mlir][spirv] Support more max/min vector.reduction
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D133168
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 483619ba708d9..53492615e0607 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -283,7 +283,8 @@ struct VectorReductionPattern final
Value result = values.front();
for (Value next : llvm::makeArrayRef(values).drop_front()) {
switch (reduceOp.getKind()) {
-#define INT_FLOAT_CASE(kind, iop, fop) \
+
+#define INT_AND_FLOAT_CASE(kind, iop, fop) \
case vector::CombiningKind::kind: \
if (resultType.isa<IntegerType>()) { \
result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
@@ -293,15 +294,21 @@ struct VectorReductionPattern final
} \
break
- INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
- INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
+#define INT_OR_FLOAT_CASE(kind, fop) \
+ case vector::CombiningKind::kind: \
+ result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
+ break
+
+ INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
+ INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
+
+ INT_OR_FLOAT_CASE(MAXF, GLFMaxOp);
+ INT_OR_FLOAT_CASE(MINF, GLFMinOp);
+ INT_OR_FLOAT_CASE(MINUI, GLUMinOp);
+ INT_OR_FLOAT_CASE(MINSI, GLSMinOp);
+ INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp);
+ INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp);
- case vector::CombiningKind::MINUI:
- case vector::CombiningKind::MINSI:
- case vector::CombiningKind::MINF:
- case vector::CombiningKind::MAXUI:
- case vector::CombiningKind::MAXSI:
- case vector::CombiningKind::MAXF:
case vector::CombiningKind::AND:
case vector::CombiningKind::OR:
case vector::CombiningKind::XOR:
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a5af59e41453e..f1de62c622d5f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -254,7 +254,7 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
// -----
-// CHECK-LABEL: func @reduction
+// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
@@ -264,23 +264,119 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
// CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[S2]]
// CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[S3]]
// CHECK: return %[[ADD2]]
-func.func @reduction(%v : vector<4xi32>) -> i32 {
+func.func @reduction_add(%v : vector<4xi32>) -> i32 {
%reduce = vector.reduction <add>, %v : vector<4xi32> into i32
return %reduce : i32
}
// -----
-// CHECK-LABEL: func @reduction
+// CHECK-LABEL: func @reduction_mul
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
-// CHECK: %[[ADD0:.+]] = spv.FMul %[[S0]], %[[S1]]
-// CHECK: %[[ADD1:.+]] = spv.FMul %[[ADD0]], %[[S2]]
-// CHECK: %[[ADD2:.+]] = spv.FMul %[[ADD1]], %[[S]]
-// CHECK: return %[[ADD2]]
-func.func @reduction(%v : vector<3xf32>, %s: f32) -> f32 {
+// CHECK: %[[MUL0:.+]] = spv.FMul %[[S0]], %[[S1]]
+// CHECK: %[[MUL1:.+]] = spv.FMul %[[MUL0]], %[[S2]]
+// CHECK: %[[MUL2:.+]] = spv.FMul %[[MUL1]], %[[S]]
+// CHECK: return %[[MUL2]]
+func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32
return %reduce : f32
}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spv.GL.FMax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spv.GL.FMax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spv.GL.FMax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spv.GL.FMin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spv.GL.FMin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spv.GL.FMin %[[MIN1]], %[[S]]
+// CHECK: return %[[MIN2]]
+func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxsi
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MAX0:.+]] = spv.GL.SMax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spv.GL.SMax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spv.GL.SMax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minsi
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MIN0:.+]] = spv.GL.SMin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spv.GL.SMin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spv.GL.SMin %[[MIN1]], %[[S]]
+// CHECK: return %[[MIN2]]
+func.func @reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxui
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MAX0:.+]] = spv.GL.UMax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spv.GL.UMax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spv.GL.UMax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minui
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MIN0:.+]] = spv.GL.UMin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spv.GL.UMin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spv.GL.UMin %[[MIN1]], %[[S]]
+// CHECK: return %[[MIN2]]
+func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
More information about the Mlir-commits mailing list