[Mlir-commits] [mlir] [mlir][spirv] Add floating point dot product (PR #73466)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Nov 26 15:22:06 PST 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73466
>From f64b841f786489d33acc80a7ed4a6fdafaf55165 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 17:38:22 -0500
Subject: [PATCH 1/2] [mlir][spirv] Add floating point dot product
Because `OpDot` does not require any extra capabilities or extensions,
enable it by default in the vector to spirv conversion.
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 34 +++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 10 +++--
.../VectorToSPIRV/VectorToSPIRV.cpp | 38 +++++++++++++++++--
.../VectorToSPIRV/vector-to-spirv.mlir | 25 ++++++++++++
.../test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 36 ++++++++++++++++++
5 files changed, 137 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..3e90775790e6aac 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -503,6 +503,40 @@ def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow",
// -----
+// Floating-point vector dot product (`spirv.Dot`): takes two float vectors of
+// the same type and returns a scalar of their element type.
+def SPIRV_DotOp : SPIRV_Op<"Dot",
+ [Pure, AllTypesMatch<["vector1", "vector2"]>,
+ AllElementTypesMatch<["vector1", "result"]>]> {
+ let summary = "Dot product of Vector 1 and Vector 2";
+
+ let description = [{
+ Result Type must be a floating point scalar.
+
+ Vector 1 and Vector 2 must be vectors of the same type, and their component
+ type must be Result Type.
+
+ #### Example:
+
+ ```mlir
+ %0 = spirv.Dot %v1, %v2 : vector<4xf32> -> f32
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_VectorOf<SPIRV_Float>:$vector1,
+ SPIRV_VectorOf<SPIRV_Float>:$vector2
+ );
+
+ let results = (outs
+ SPIRV_Float:$result
+ );
+
+ // Only the first operand's type is spelled; AllTypesMatch makes it apply to
+ // both operands.
+ let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
+
+ // Every rule in the description is already enforced by the traits and
+ // operand/result type constraints above, so no custom verifier is needed.
+ let hasVerifier = 0;
+}
+
+// -----
+
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8eaf2a98a58560e..f315da356e0d2c6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4205,11 +4205,14 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
"::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
"Joint Matrix">;
+// Vector whose element type is `type` and whose length is one of 2, 3, 4, 8,
+// or 16.
+class SPIRV_VectorOf<Type type> :
+ VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
+
class SPIRV_ScalarOrVectorOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+ AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ AnyTypeOf<[type, SPIRV_VectorOf<type>,
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
@@ -4357,6 +4360,7 @@ def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>;
@@ -4526,7 +4530,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
- SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+ SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpDot,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow,
SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan,
SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 05ef535dde4b5c7..35b0b17d2d50787 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -646,7 +646,8 @@ struct VectorStoreOpConverter final
}
};
-struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+struct VectorReductionToIntDotProd final
+ : OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ReductionOp op,
@@ -740,6 +741,36 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Converts a `vector.reduction <add>` whose operand is produced by
+/// `arith.mulf` into a `spirv.Dot` of the multiplication operands. When the
+/// reduction has an accumulator, the pattern adds it with `spirv.FAdd`.
+///
+/// `spirv.Dot` requires both operands to be SPIR-V vectors (length
+/// 2/3/4/8/16) whose element type equals the result type. For inputs that
+/// violate this, such as a `vector<1xf32>` multiplication, the pattern fails
+/// to match instead of creating an invalid `spirv.Dot`.
+struct VectorReductionToFPDotProd final
+    : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
+
+    auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "result is not a float");
+
+    auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(
+          op, "reduction operand is not 'arith.mulf'");
+
+    // The mulf operands are used directly as `spirv.Dot` operands. They must
+    // therefore already be a valid SPIR-V vector whose element type matches
+    // the converted result type.
+    auto mulType = dyn_cast<VectorType>(mul.getType());
+    if (!mulType || !spirv::CompositeType::isValid(mulType))
+      return rewriter.notifyMatchFailure(
+          op, "multiplication is not on a valid SPIR-V vector type");
+    if (mulType.getElementType() != resultType)
+      return rewriter.notifyMatchFailure(
+          op, "element type does not match the result type");
+
+    Location loc = op.getLoc();
+    Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
+                                              mul.getRhs());
+    // Add the optional accumulator after the dot product.
+    if (op.getAcc())
+      res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -760,7 +791,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
- VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>,
+ VectorReductionToFPDotProd, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext());
@@ -768,5 +800,5 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns) {
- patterns.add<VectorReductionToDotProd>(patterns.getContext());
+ patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 75b2822a8527363..d8585d59770bfdc 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -500,6 +500,31 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
// -----
+// A `vector.reduction <add>` fed by an `arith.mulf` lowers to `spirv.Dot` on
+// the multiplication operands.
+// CHECK-LABEL: func @reduction_addf
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
+// CHECK: return %[[DOT]] : f32
+func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+ %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
+ %red = vector.reduction <add>, %mul : vector<4xf32> into f32
+ return %red : f32
+}
+
+// -----
+
+// With an accumulator, the dot product result is added to it via `spirv.FAdd`.
+// CHECK-LABEL: func @reduction_addf_acc
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
+// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
+// CHECK: return %[[RES]] : f32
+func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
+ %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
+ return %red : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_mul
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 9617204d3419c4d..2d0c86e08de5ac4 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -254,6 +254,42 @@ func.func @isub_borrow(%arg: i64) -> !spirv.struct<(i32, i32)> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.Dot
+//===----------------------------------------------------------------------===//
+
+// Valid: identical float vector operand types; result is their element type.
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
+ return %0 : f32
+}
+
+// -----
+
+// The assembly format spells one type for both operands, so mismatched vector
+// types are rejected by the parser.
+// expected-note @+1 {{prior use here}}
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
+ // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
+ return %0 : f32
+}
+
+// -----
+
+// The result type must equal the operands' element type.
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
+ // expected-error @+1 {{'spirv.Dot' op failed to verify that all of {vector1, result} have same element type}}
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f16
+ return %0 : f16
+}
+
+// -----
+
+// Integer vectors are rejected; operands must be float vectors.
+func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
+ return %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SMulExtended
//===----------------------------------------------------------------------===//
>From fd25a163c53fabf3811874ef2da20217954a0180 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 18:21:50 -0500
Subject: [PATCH 2/2] Apply pattern benefit
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 35b0b17d2d50787..76fa43062d10b82 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -791,11 +791,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
- VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>,
- VectorReductionToFPDotProd, VectorShapeCast,
+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext());
+ typeConverter, patterns.getContext(), PatternBenefit(1));
+
+ // Make sure that the more specialized dot product pattern has a higher benefit
+ // than the generic one that extracts all elements.
+ patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
+ PatternBenefit(2));
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
More information about the Mlir-commits
mailing list