[Mlir-commits] [mlir] d343a39 - [mlir][spirv][vector] Enable vector3 when converting to dot product

Lei Zhang llvmlistbot at llvm.org
Tue Apr 18 13:59:20 PDT 2023


Author: Lei Zhang
Date: 2023-04-18T13:57:44-07:00
New Revision: d343a395431f70f63d66ef31cb69c8c4babdb21f

URL: https://github.com/llvm/llvm-project/commit/d343a395431f70f63d66ef31cb69c8c4babdb21f
DIFF: https://github.com/llvm/llvm-project/commit/d343a395431f70f63d66ef31cb69c8c4babdb21f.diff

LOG: [mlir][spirv][vector] Enable vector3 when converting to dot product

Such cases are common for contractions that originate from convolutions
with 3 input channels. The 3-element operands are padded with a zero to
form 4-element vectors. Although this does not utilize all 4 lanes of the
dot product, it should still be better than performing the multiplication
and the reduction separately.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D148642

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 20c52f536a23f..dae90f26199a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -457,8 +457,8 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
       return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
 
     VectorType inVecTy = op.getSourceVectorType();
-    if (inVecTy.getNumElements() != 4 || inVecTy.getShape().size() != 1 ||
-        inVecTy.isScalable())
+    if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
+        inVecTy.getShape().size() != 1 || inVecTy.isScalable())
       return rewriter.notifyMatchFailure(op, "unsupported vector shape");
 
     auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
@@ -491,15 +491,31 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
   static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
                                   PatternRewriter &rewriter) {
     auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
-    if (!lhs || !getElementTypeOrSelf(lhs.getIn().getType()).isInteger(8))
+    if (!lhs)
+      return failure();
+    Value lhsIn = lhs.getIn();
+    auto lhsInType = cast<VectorType>(lhsIn.getType());
+    if (!lhsInType.getElementType().isInteger(8))
       return failure();
 
     auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
-    if (!rhs || !getElementTypeOrSelf(rhs.getIn().getType()).isInteger(8))
+    if (!rhs)
       return failure();
-
-    Value lhsIn = lhs.getIn();
     Value rhsIn = rhs.getIn();
+    auto rhsInType = cast<VectorType>(rhsIn.getType());
+    if (!rhsInType.getElementType().isInteger(8))
+      return failure();
+
+    if (op.getSourceVectorType().getNumElements() == 3) {
+      IntegerType i8Type = rewriter.getI8Type();
+      auto v4i8Type = VectorType::get({4}, i8Type);
+      Location loc = op.getLoc();
+      Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
+      lhsIn = rewriter.create<spirv::CompositeConstructOp>(
+          loc, v4i8Type, ValueRange{lhsIn, zero});
+      rhsIn = rewriter.create<spirv::CompositeConstructOp>(
+          loc, v4i8Type, ValueRange{rhsIn, zero});
+    }
 
     // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
     // we have to swap operands instead in that case.

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
index bfe6d8608a99d..e13a51733ec1e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -123,18 +123,34 @@ func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
   return %red : i32
 }
 
+// CHECK-LABEL: func.func @to_sdot_vector3
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<3xi8>)
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 : i8
+//       CHECK:   %[[LHS:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
+//       CHECK:   %[[RHS:.+]] = spirv.CompositeConstruct %[[ARG1]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
+//       CHECK:   %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : (vector<4xi8>, vector<4xi8>) -> i32
+//       CHECK:   return %[[SDOT]]
+func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
+  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
+  %mul = arith.muli %lhs, %rhs : vector<3xi32>
+  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+  return %red : i32
+}
+
 // -----
+
 // Negative tests.
 
 // CHECK-LABEL: func.func @too_short
-//  CHECK-SAME:   ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<2xi8>, [[ARG1:%.+]]: vector<2xi8>)
 //  CHECK:        [[RED:%.+]] = vector.reduction
 //  CHECK-NEXT:   return [[RED]] : i32
-func.func @too_short(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
-  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
-  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
-  %mul = arith.muli %lhs, %rhs : vector<3xi32>
-  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+func.func @too_short(%arg0: vector<2xi8>, %arg1: vector<2xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<2xi8> to vector<2xi32>
+  %rhs = arith.extsi %arg1 : vector<2xi8> to vector<2xi32>
+  %mul = arith.muli %lhs, %rhs : vector<2xi32>
+  %red = vector.reduction <add>, %mul : vector<2xi32> into i32
   return %red : i32
 }
 


        


More information about the Mlir-commits mailing list