[Mlir-commits] [mlir] 6b9c186 - [mlir][spirv] Handle non-innerprod float vector add reductions (#73476)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 08:10:29 PST 2023


Author: Jakub Kuderski
Date: 2023-11-27T11:10:25-05:00
New Revision: 6b9c186b2d4c8bd315034a9655a28d32bcf745ab

URL: https://github.com/llvm/llvm-project/commit/6b9c186b2d4c8bd315034a9655a28d32bcf745ab
DIFF: https://github.com/llvm/llvm-project/commit/6b9c186b2d4c8bd315034a9655a28d32bcf745ab.diff

LOG: [mlir][spirv] Handle non-innerprod float vector add reductions (#73476)

Instead of extracting all individual vector components and performing a
scalar summation, use `spirv.Dot` with the original reduction operand
and a vector constant of all ones.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index ade41b0372c82f1..e48f29a4f170290 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
@@ -755,16 +756,43 @@ struct VectorReductionToFPDotProd final
     if (!resultType)
       return rewriter.notifyMatchFailure(op, "result is not a float");
 
-    auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
-    if (!mul)
-      return rewriter.notifyMatchFailure(
-          op, "reduction operand is not 'arith.mulf'");
+    Value vec = adaptor.getVector();
+    Value acc = adaptor.getAcc();
+
+    auto vectorType = dyn_cast<VectorType>(vec.getType());
+    if (!vectorType) {
+      assert(isa<FloatType>(vec.getType()) &&
+             "Expected the vector to be scalarized");
+      if (acc) {
+        rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
+        return success();
+      }
+
+      rewriter.replaceOp(op, vec);
+      return success();
+    }
 
     Location loc = op.getLoc();
-    Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
-                                              mul.getRhs());
-    if (op.getAcc())
-      res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
+    Value lhs;
+    Value rhs;
+    if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
+      lhs = mul.getLhs();
+      rhs = mul.getRhs();
+    } else {
+      // If the operand is not a mul, use a vector of ones for the dot operand
+      // to just sum up all values.
+      lhs = vec;
+      Attribute oneAttr =
+          rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
+      oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
+      rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
+    }
+    assert(lhs);
+    assert(rhs);
+
+    Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
+    if (acc)
+      res = rewriter.create<spirv::FAddOp>(loc, acc, res);
 
     rewriter.replaceOp(op, res);
     return success();

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8585d59770bfdc..c9984091d5acc6a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -500,11 +500,11 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
 
 // -----
 
-// CHECK-LABEL: func @reduction_addf
+// CHECK-LABEL: func @reduction_addf_mulf
 //  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
 //  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
 //  CHECK:       return %[[DOT]] : f32
-func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+func.func @reduction_addf_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
   %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
   %red = vector.reduction <add>, %mul : vector<4xf32> into f32
   return %red : f32
@@ -512,12 +512,12 @@ func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: func @reduction_addf_acc
+// CHECK-LABEL: func @reduction_addf_acc_mulf
 //  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
 //  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
 //  CHECK:       %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
 //  CHECK:       return %[[RES]] : f32
-func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
+func.func @reduction_addf_acc_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
   %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
   %red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
   return %red : f32
@@ -525,6 +525,54 @@ func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc:
 
 // -----
 
+// CHECK-LABEL: func @reduction_addf
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>)
+//  CHECK:       %[[ONE:.+]] = spirv.Constant dense<1.0{{.+}}> : vector<4xf32>
+//  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
+//  CHECK:       return %[[DOT]] : f32
+func.func @reduction_addf(%arg0: vector<4xf32>) -> f32 {
+  %red = vector.reduction <add>, %arg0 : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_acc
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
+//  CHECK:       %[[ONE:.+]] = spirv.Constant dense<1.0{{.*}}> : vector<4xf32>
+//  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
+//  CHECK:       %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
+//  CHECK:       return %[[RES]] : f32
+func.func @reduction_addf_acc(%arg0: vector<4xf32>, %acc: f32) -> f32 {
+  %red = vector.reduction <add>, %arg0, %acc : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_one_elem
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<1xf32>)
+//  CHECK:       %[[RES:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
+//  CHECK:       return %[[RES]] : f32
+func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
+  %red = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_one_elem_acc
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<1xf32>, %[[ACC:.+]]: f32)
+//  CHECK:       %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
+//  CHECK:       %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[RHS]] : f32
+//  CHECK:       return %[[RES]] : f32
+func.func @reduction_addf_one_elem_acc(%arg0: vector<1xf32>, %acc: f32) -> f32 {
+  %red = vector.reduction <add>, %arg0, %acc : vector<1xf32> into f32
+  return %red : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_mul
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>


        


More information about the Mlir-commits mailing list