[Mlir-commits] [mlir] d343a39 - [mlir][spirv][vector] Enable vector3 when converting to dot product
Lei Zhang
llvmlistbot at llvm.org
Tue Apr 18 13:59:20 PDT 2023
Author: Lei Zhang
Date: 2023-04-18T13:57:44-07:00
New Revision: d343a395431f70f63d66ef31cb69c8c4babdb21f
URL: https://github.com/llvm/llvm-project/commit/d343a395431f70f63d66ef31cb69c8c4babdb21f
DIFF: https://github.com/llvm/llvm-project/commit/d343a395431f70f63d66ef31cb69c8c4babdb21f.diff
LOG: [mlir][spirv][vector] Enable vector3 when converting to dot product
It's common to see such cases in contractions derived from convolutions
with 3 input channels. Although we aren't utilizing all 4 lanes of the
dot product, it should still be better than performing the multiplication
and reduction separately.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D148642
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 20c52f536a23f..dae90f26199a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -457,8 +457,8 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
VectorType inVecTy = op.getSourceVectorType();
- if (inVecTy.getNumElements() != 4 || inVecTy.getShape().size() != 1 ||
- inVecTy.isScalable())
+ if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
+ inVecTy.getShape().size() != 1 || inVecTy.isScalable())
return rewriter.notifyMatchFailure(op, "unsupported vector shape");
auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
@@ -491,15 +491,31 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
PatternRewriter &rewriter) {
auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
- if (!lhs || !getElementTypeOrSelf(lhs.getIn().getType()).isInteger(8))
+ if (!lhs)
+ return failure();
+ Value lhsIn = lhs.getIn();
+ auto lhsInType = cast<VectorType>(lhsIn.getType());
+ if (!lhsInType.getElementType().isInteger(8))
return failure();
auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
- if (!rhs || !getElementTypeOrSelf(rhs.getIn().getType()).isInteger(8))
+ if (!rhs)
return failure();
-
- Value lhsIn = lhs.getIn();
Value rhsIn = rhs.getIn();
+ auto rhsInType = cast<VectorType>(rhsIn.getType());
+ if (!rhsInType.getElementType().isInteger(8))
+ return failure();
+
+ if (op.getSourceVectorType().getNumElements() == 3) {
+ IntegerType i8Type = rewriter.getI8Type();
+ auto v4i8Type = VectorType::get({4}, i8Type);
+ Location loc = op.getLoc();
+ Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
+ lhsIn = rewriter.create<spirv::CompositeConstructOp>(
+ loc, v4i8Type, ValueRange{lhsIn, zero});
+ rhsIn = rewriter.create<spirv::CompositeConstructOp>(
+ loc, v4i8Type, ValueRange{rhsIn, zero});
+ }
// There's no variant of dot prod ops for unsigned LHS and signed RHS, so
// we have to swap operands instead in that case.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
index bfe6d8608a99d..e13a51733ec1e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -123,18 +123,34 @@ func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
return %red : i32
}
+// CHECK-LABEL: func.func @to_sdot_vector3
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<3xi8>)
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i8
+// CHECK: %[[LHS:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
+// CHECK: %[[RHS:.+]] = spirv.CompositeConstruct %[[ARG1]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
+// CHECK: %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : (vector<4xi8>, vector<4xi8>) -> i32
+// CHECK: return %[[SDOT]]
+func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
+ %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
+ %mul = arith.muli %lhs, %rhs : vector<3xi32>
+ %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+ return %red : i32
+}
+
// -----
+
// Negative tests.
// CHECK-LABEL: func.func @too_short
-// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi8>, [[ARG1:%.+]]: vector<2xi8>)
// CHECK: [[RED:%.+]] = vector.reduction
// CHECK-NEXT: return [[RED]] : i32
-func.func @too_short(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
- %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
- %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
- %mul = arith.muli %lhs, %rhs : vector<3xi32>
- %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+func.func @too_short(%arg0: vector<2xi8>, %arg1: vector<2xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<2xi8> to vector<2xi32>
+ %rhs = arith.extsi %arg1 : vector<2xi8> to vector<2xi32>
+ %mul = arith.muli %lhs, %rhs : vector<2xi32>
+ %red = vector.reduction <add>, %mul : vector<2xi32> into i32
return %red : i32
}
More information about the Mlir-commits mailing list