[Mlir-commits] [mlir] 0f6103a - [mlir][spirv] Support more max/min vector.reduction

Lei Zhang llvmlistbot at llvm.org
Fri Sep 2 14:23:51 PDT 2022


Author: Lei Zhang
Date: 2022-09-02T17:21:57-04:00
New Revision: 0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb

URL: https://github.com/llvm/llvm-project/commit/0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb
DIFF: https://github.com/llvm/llvm-project/commit/0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb.diff

LOG: [mlir][spirv] Support more max/min kinds in vector.reduction lowering

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D133168

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 483619ba708d9..53492615e0607 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -283,7 +283,8 @@ struct VectorReductionPattern final
     Value result = values.front();
     for (Value next : llvm::makeArrayRef(values).drop_front()) {
       switch (reduceOp.getKind()) {
-#define INT_FLOAT_CASE(kind, iop, fop)                                         \
+
+#define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
   case vector::CombiningKind::kind:                                            \
     if (resultType.isa<IntegerType>()) {                                       \
       result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
@@ -293,15 +294,21 @@ struct VectorReductionPattern final
     }                                                                          \
     break
 
-        INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
-        INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
+#define INT_OR_FLOAT_CASE(kind, op)                                            \
+  case vector::CombiningKind::kind: /* op is specific to int or float */       \
+    result = rewriter.create<spirv::op>(loc, resultType, result, next);        \
+    break
+
+        INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
+        INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
+
+        INT_OR_FLOAT_CASE(MAXF, GLFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, GLFMinOp);
+        INT_OR_FLOAT_CASE(MINUI, GLUMinOp);
+        INT_OR_FLOAT_CASE(MINSI, GLSMinOp);
+        INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp);
+        INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp);
 
-      case vector::CombiningKind::MINUI:
-      case vector::CombiningKind::MINSI:
-      case vector::CombiningKind::MINF:
-      case vector::CombiningKind::MAXUI:
-      case vector::CombiningKind::MAXSI:
-      case vector::CombiningKind::MAXF:
       case vector::CombiningKind::AND:
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a5af59e41453e..f1de62c622d5f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -254,7 +254,7 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
 
 // -----
 
-// CHECK-LABEL: func @reduction
+// CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
 //       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
@@ -264,23 +264,119 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
 //       CHECK:   %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[S2]]
 //       CHECK:   %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[S3]]
 //       CHECK:   return %[[ADD2]]
-func.func @reduction(%v : vector<4xi32>) -> i32 {
+func.func @reduction_add(%v : vector<4xi32>) -> i32 {
   %reduce = vector.reduction <add>, %v : vector<4xi32> into i32
   return %reduce : i32
 }
 
 // -----
 
-// CHECK-LABEL: func @reduction
+// CHECK-LABEL: func @reduction_mul
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
 //       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
-//       CHECK:   %[[ADD0:.+]] = spv.FMul %[[S0]], %[[S1]]
-//       CHECK:   %[[ADD1:.+]] = spv.FMul %[[ADD0]], %[[S2]]
-//       CHECK:   %[[ADD2:.+]] = spv.FMul %[[ADD1]], %[[S]]
-//       CHECK:   return %[[ADD2]]
-func.func @reduction(%v : vector<3xf32>, %s: f32) -> f32 {
+//       CHECK:   %[[PROD0:.+]] = spv.FMul %[[S0]], %[[S1]]
+//       CHECK:   %[[PROD1:.+]] = spv.FMul %[[PROD0]], %[[S2]]
+//       CHECK:   %[[PROD2:.+]] = spv.FMul %[[PROD1]], %[[S]]
+//       CHECK:   return %[[PROD2]]
+func.func @reduction_mul(%v: vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxf
+//  CHECK-SAME: (%[[VEC:.+]]: vector<3xf32>, %[[ACC:.+]]: f32)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[RED0:.+]] = spv.GL.FMax %[[E0]], %[[E1]]
+//       CHECK:   %[[RED1:.+]] = spv.GL.FMax %[[RED0]], %[[E2]]
+//       CHECK:   %[[RED2:.+]] = spv.GL.FMax %[[RED1]], %[[ACC]]
+//       CHECK:   return %[[RED2]]
+func.func @reduction_maxf(%vec: vector<3xf32>, %acc: f32) -> f32 {
+  %r = vector.reduction <maxf>, %vec, %acc : vector<3xf32> into f32
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minf
+//  CHECK-SAME: (%[[VEC:.+]]: vector<3xf32>, %[[ACC:.+]]: f32)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[RED0:.+]] = spv.GL.FMin %[[E0]], %[[E1]]
+//       CHECK:   %[[RED1:.+]] = spv.GL.FMin %[[RED0]], %[[E2]]
+//       CHECK:   %[[RED2:.+]] = spv.GL.FMin %[[RED1]], %[[ACC]]
+//       CHECK:   return %[[RED2]]
+func.func @reduction_minf(%vec: vector<3xf32>, %acc: f32) -> f32 {
+  %r = vector.reduction <minf>, %vec, %acc : vector<3xf32> into f32
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxsi
+//  CHECK-SAME: (%[[VEC:.+]]: vector<3xi32>, %[[ACC:.+]]: i32)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[RED0:.+]] = spv.GL.SMax %[[E0]], %[[E1]]
+//       CHECK:   %[[RED1:.+]] = spv.GL.SMax %[[RED0]], %[[E2]]
+//       CHECK:   %[[RED2:.+]] = spv.GL.SMax %[[RED1]], %[[ACC]]
+//       CHECK:   return %[[RED2]]
+func.func @reduction_maxsi(%vec: vector<3xi32>, %acc: i32) -> i32 {
+  %r = vector.reduction <maxsi>, %vec, %acc : vector<3xi32> into i32
+  return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minsi
+//  CHECK-SAME: (%[[VEC:.+]]: vector<3xi32>, %[[ACC:.+]]: i32)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[RED0:.+]] = spv.GL.SMin %[[E0]], %[[E1]]
+//       CHECK:   %[[RED1:.+]] = spv.GL.SMin %[[RED0]], %[[E2]]
+//       CHECK:   %[[RED2:.+]] = spv.GL.SMin %[[RED1]], %[[ACC]]
+//       CHECK:   return %[[RED2]]
+func.func @reduction_minsi(%vec: vector<3xi32>, %acc: i32) -> i32 {
+  %r = vector.reduction <minsi>, %vec, %acc : vector<3xi32> into i32
+  return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxui
+//  CHECK-SAME: (%[[VEC:.+]]: vector<3xi32>, %[[ACC:.+]]: i32)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[RED0:.+]] = spv.GL.UMax %[[E0]], %[[E1]]
+//       CHECK:   %[[RED1:.+]] = spv.GL.UMax %[[RED0]], %[[E2]]
+//       CHECK:   %[[RED2:.+]] = spv.GL.UMax %[[RED1]], %[[ACC]]
+//       CHECK:   return %[[RED2]]
+func.func @reduction_maxui(%vec: vector<3xi32>, %acc: i32) -> i32 {
+  %r = vector.reduction <maxui>, %vec, %acc : vector<3xi32> into i32
+  return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minui
+//  CHECK-SAME: (%[[VEC:.+]]: vector<3xi32>, %[[ACC:.+]]: i32)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[RED0:.+]] = spv.GL.UMin %[[E0]], %[[E1]]
+//       CHECK:   %[[RED1:.+]] = spv.GL.UMin %[[RED0]], %[[E2]]
+//       CHECK:   %[[RED2:.+]] = spv.GL.UMin %[[RED1]], %[[ACC]]
+//       CHECK:   return %[[RED2]]
+func.func @reduction_minui(%vec: vector<3xi32>, %acc: i32) -> i32 {
+  %r = vector.reduction <minui>, %vec, %acc : vector<3xi32> into i32
+  return %r : i32
+}


        


More information about the Mlir-commits mailing list