[Mlir-commits] [mlir] [mlir][spirv] Add floating point dot product (PR #73466)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 26 14:40:40 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
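Add the `spirv.Dot` op (SPIR-V `OpDot`) and a pattern that lowers a `vector.reduction <add>` over an `arith.mulf` result to it. Because `OpDot` does not require any extra capabilities or extensions, the pattern is enabled by default in the vector-to-SPIR-V conversion.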
---
Full diff: https://github.com/llvm/llvm-project/pull/73466.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+34)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+7-3)
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+35-3)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+25)
- (modified) mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir (+36)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..3e90775790e6aac 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -503,6 +503,40 @@ def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow",
// -----
+def SPIRV_DotOp : SPIRV_Op<"Dot", // spirv.Dot, i.e. SPIR-V OpDot (opcode 148).
+ [Pure, AllTypesMatch<["vector1", "vector2"]>, // Both operands share one vector type.
+ AllElementTypesMatch<["vector1", "result"]>]> { // Result type equals the operand element type.
+ let summary = "Dot product of Vector 1 and Vector 2";
+
+ let description = [{
+ Result Type must be a floating point scalar.
+
+ Vector 1 and Vector 2 must be vectors of the same type, and their component
+ type must be Result Type.
+
+ #### Example:
+
+ ```mlir
+ %0 = spirv.Dot %v1, %v2 : vector<4xf32> -> f32
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_VectorOf<SPIRV_Float>:$vector1, // Float vector of length 2, 3, 4, 8, or 16.
+ SPIRV_VectorOf<SPIRV_Float>:$vector2
+ );
+
+ let results = (outs
+ SPIRV_Float:$result
+ );
+
+ let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)"; // vector2's type is implied by AllTypesMatch.
+
+ let hasVerifier = 0; // No C++ verifier: the traits and type constraints above do the checking.
+}
+
+// -----
+
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8eaf2a98a58560e..f315da356e0d2c6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4205,11 +4205,15 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
"::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
"Joint Matrix">;
+// Vector of `type` with one of the lengths 2, 3, 4, 8, or 16.
+class SPIRV_VectorOf<Type type> :
+ VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
+
class SPIRV_ScalarOrVectorOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+ AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ AnyTypeOf<[type, SPIRV_VectorOf<type>,
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
@@ -4357,6 +4361,7 @@ def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>;
@@ -4526,7 +4531,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
- SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow,
+ SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
+ SPIRV_OC_OpISubBorrow,
SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan,
SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 05ef535dde4b5c7..35b0b17d2d50787 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -646,7 +646,8 @@ struct VectorStoreOpConverter final
}
};
-struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+struct VectorReductionToIntDotProd final
+ : OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ReductionOp op,
@@ -740,6 +741,44 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Lowers a `vector.reduction <add>` whose source is produced by `arith.mulf`
+/// to `spirv.Dot`, followed by `spirv.FAdd` when an accumulator is present.
+struct VectorReductionToFPDotProd final
+ : OpConversionPattern<vector::ReductionOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
+
+ auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
+ if (!resultType)
+ return rewriter.notifyMatchFailure(op, "result is not a float");
+
+ // spirv.Dot only takes vector operands. The SPIR-V type converter may turn
+ // single-element vectors into scalars; leave those to the generic pattern.
+ Value vec = adaptor.getVector();
+ if (!isa<VectorType>(vec.getType()))
+ return rewriter.notifyMatchFailure(op, "reduction source is not a vector");
+
+ auto mul = vec.getDefiningOp<arith::MulFOp>();
+ if (!mul)
+ return rewriter.notifyMatchFailure(
+ op, "reduction operand is not 'arith.mulf'");
+
+ Location loc = op.getLoc();
+ Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
+ mul.getRhs());
+ if (Value acc = adaptor.getAcc())
+ res = rewriter.create<spirv::FAddOp>(loc, acc, res);
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -764,9 +803,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext());
+
+ // The generic VectorReductionPattern also matches `vector.reduction <add>`.
+ // Give the more specialized dot product pattern a higher benefit so it is
+ // tried before the generic one that extracts all elements.
+ patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
+ PatternBenefit(2));
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns) {
- patterns.add<VectorReductionToDotProd>(patterns.getContext());
+ patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 75b2822a8527363..d8585d59770bfdc 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -500,6 +500,31 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
// -----
+// CHECK-LABEL: func @reduction_addf
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
+// CHECK: return %[[DOT]] : f32
+func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 { // An add reduction of an arith.mulf result, no accumulator.
+ %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
+ %red = vector.reduction <add>, %mul : vector<4xf32> into f32
+ return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_acc
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
+// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
+// CHECK: return %[[RES]] : f32
+func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 { // Same shape with an accumulator, which is added with spirv.FAdd.
+ %mul = arith.mulf %arg0, %arg1 : vector<4xf32>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
+ return %red : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_mul
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 9617204d3419c4d..2d0c86e08de5ac4 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -254,6 +254,42 @@ func.func @isub_borrow(%arg: i64) -> !spirv.struct<(i32, i32)> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.Dot
+//===----------------------------------------------------------------------===//
+
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 { // Valid form: matching vector<4xf32> operands, f32 result.
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
+ return %0 : f32
+}
+
+// -----
+
+// expected-note @+1 {{prior use here}}
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 { // Operand types differ; the single spelled type must fit both.
+ // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 { // Result f16 does not match the f32 operand element type.
+ // expected-error @+1 {{'spirv.Dot' op failed to verify that all of {vector1, result} have same element type}}
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f16
+ return %0 : f16
+}
+
+// -----
+
+func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 { // Integer vectors are rejected: operands must be float vectors.
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
+ return %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SMulExtended
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/73466
More information about the Mlir-commits mailing list