[Mlir-commits] [mlir] [mlir][spirv] Add floating point dot product (PR #73466)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 26 14:40:13 PST 2023


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/73466

Because `OpDot` does not require any extra capabilities or extensions, enable the floating-point dot product lowering by default in the vector to SPIR-V conversion.

>From f64b841f786489d33acc80a7ed4a6fdafaf55165 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 17:38:22 -0500
Subject: [PATCH] [mlir][spirv] Add floating point dot product

Because `OpDot` does not require any extra capabilities or extensions,
enable the floating-point dot product lowering by default in the vector to
SPIR-V conversion.
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    | 34 +++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 10 +++--
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 38 +++++++++++++++++--
 .../VectorToSPIRV/vector-to-spirv.mlir        | 25 ++++++++++++
 .../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 36 ++++++++++++++++++
 5 files changed, 137 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..3e90775790e6aac 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -503,6 +503,36 @@ def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow",
 
 // -----
 
+def SPIRV_DotOp : SPIRV_Op<"Dot", [
+    Pure,
+    AllTypesMatch<["vector1", "vector2"]>,
+    AllElementTypesMatch<["vector1", "result"]>
+  ]> {
+  let summary = "Dot product of Vector 1 and Vector 2";
+
+  let description = [{
+    Result Type must be a floating point scalar.
+
+    Vector 1 and Vector 2 must be vectors of the same type, and their component
+    type must be Result Type.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.Dot %v1, %v2 : vector<4xf32> -> f32
+    ```
+  }];
+
+  let arguments = (ins SPIRV_VectorOf<SPIRV_Float>:$vector1,
+                       SPIRV_VectorOf<SPIRV_Float>:$vector2);
+  let results = (outs SPIRV_Float:$result);
+
+  let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
+  let hasVerifier = 0;
+}
+
+// -----
+
 def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
                                          SPIRV_Integer,
                                          [UsableInSpecConstantOp]> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8eaf2a98a58560e..f315da356e0d2c6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4205,11 +4205,15 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
     "::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
     "Joint Matrix">;
 
+// Vector of `type` with a length of 2, 3, 4, 8, or 16.
+class SPIRV_VectorOf<Type type> :
+    VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
+
 class SPIRV_ScalarOrVectorOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+    AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+    AnyTypeOf<[type, SPIRV_VectorOf<type>,
                SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
@@ -4357,6 +4361,7 @@ def SPIRV_OC_OpFMod                       : I32EnumAttrCase<"OpFMod", 141>;
 def SPIRV_OC_OpVectorTimesScalar          : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPIRV_OC_OpMatrixTimesScalar          : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPIRV_OC_OpMatrixTimesMatrix          : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPIRV_OC_OpDot                        : I32EnumAttrCase<"OpDot", 148>;
 def SPIRV_OC_OpIAddCarry                  : I32EnumAttrCase<"OpIAddCarry", 149>;
 def SPIRV_OC_OpISubBorrow                 : I32EnumAttrCase<"OpISubBorrow", 150>;
 def SPIRV_OC_OpUMulExtended               : I32EnumAttrCase<"OpUMulExtended", 151>;
@@ -4526,7 +4531,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
       SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
       SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
       SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
-      SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow,
+      SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot,
+      SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow,
       SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan,
       SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 05ef535dde4b5c7..35b0b17d2d50787 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -646,7 +646,8 @@ struct VectorStoreOpConverter final
   }
 };
 
-struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+struct VectorReductionToIntDotProd final
+    : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ReductionOp op,
@@ -740,6 +741,54 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Lowers a floating-point `vector.reduction <add>` whose operand is produced
+/// by `arith.mulf` to `spirv.Dot` on the multiplication operands, followed by
+/// `spirv.FAdd` with the accumulator when one is present. For example:
+///
+///   %m = arith.mulf %a, %b : vector<4xf32>
+///   %r = vector.reduction <add>, %m, %acc : vector<4xf32> into f32
+///
+/// becomes
+///
+///   %d = spirv.Dot %a, %b : vector<4xf32> -> f32
+///   %r = spirv.FAdd %acc, %d : f32
+struct VectorReductionToFPDotProd final
+    : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
+
+    auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "result is not a float");
+
+    auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(
+          op, "reduction operand is not 'arith.mulf'");
+
+    // spirv.Dot requires the component type of its vector operands to be the
+    // result type; bail out instead of building an op that fails to verify.
+    auto lhsType = dyn_cast<VectorType>(mul.getLhs().getType());
+    if (!lhsType || lhsType.getElementType() != resultType)
+      return rewriter.notifyMatchFailure(
+          op, "'arith.mulf' element type does not match the result type");
+
+    Location loc = op.getLoc();
+    Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
+                                              mul.getRhs());
+    if (Value acc = adaptor.getAcc())
+      res = rewriter.create<spirv::FAddOp>(loc, acc, res);
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -760,7 +809,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext());
+
+  // The dot product lowering is more specialized than the generic reduction
+  // patterns registered above, which may also match floating-point 'add'
+  // reductions. Give it a higher benefit so that it is tried first instead of
+  // depending on pattern registration order.
+  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
+                                           PatternBenefit(2));
@@ -768,5 +824,5 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<VectorReductionToDotProd>(patterns.getContext());
+  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
 }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 75b2822a8527363..d8585d59770bfdc 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -500,6 +500,44 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_addf
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
+//       CHECK:   %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
+//       CHECK:   return %[[DOT]] : f32
+func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+  %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
+  %red = vector.reduction <add>, %mul : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_acc
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
+//       CHECK:   %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
+//       CHECK:   %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
+//       CHECK:   return %[[RES]] : f32
+func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
+  %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_no_mul
+//  CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+//   CHECK-NOT:   spirv.Dot
+//       CHECK:   spirv.CompositeExtract %[[V]][0 : i32] : vector<4xf32>
+//   CHECK-NOT:   spirv.Dot
+//       CHECK:   return
+func.func @reduction_addf_no_mul(%v: vector<4xf32>) -> f32 {
+  %red = vector.reduction <add>, %v : vector<4xf32> into f32
+  return %red : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_mul
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 9617204d3419c4d..2d0c86e08de5ac4 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -254,6 +254,43 @@ func.func @isub_borrow(%arg: i64) -> !spirv.struct<(i32, i32)> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.Dot
+//===----------------------------------------------------------------------===//
+
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+  // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xf32> -> f32
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
+  return %0 : f32
+}
+
+// -----
+
+// expected-note @+1 {{prior use here}}
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
+  // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
+  return %0 : f32
+}
+
+// -----
+
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
+  // expected-error @+1 {{'spirv.Dot' op failed to verify that all of {vector1, result} have same element type}}
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f16
+  return %0 : f16
+}
+
+// -----
+
+func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
+  return %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SMulExtended
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list