[Mlir-commits] [mlir] b194ef6 - [mlir][spirv][vector] Add pattern to convert reduction to SPIR-V dot prod

Jakub Kuderski llvmlistbot at llvm.org
Fri Mar 10 10:56:53 PST 2023


Author: Jakub Kuderski
Date: 2023-03-10T13:54:16-05:00
New Revision: b194ef692cf3965bac141af31a69428b4f6ae2df

URL: https://github.com/llvm/llvm-project/commit/b194ef692cf3965bac141af31a69428b4f6ae2df
DIFF: https://github.com/llvm/llvm-project/commit/b194ef692cf3965bac141af31a69428b4f6ae2df.diff

LOG: [mlir][spirv][vector] Add pattern to convert reduction to SPIR-V dot prod

This converts a specific form of `vector.reduction` to SPIR-V integer
dot product ops.

Add a new test pass to exercise this outside of the main vector to
SPIR-V conversion pass.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D145760

Added: 
    mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
    mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
    mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp

Modified: 
    mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
    mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/lib/Conversion/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index 7b8882b36cd5d..f8c02c54066b8 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -23,6 +23,15 @@ class SPIRVTypeConverter;
 void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns);
 
+/// Appends patterns to convert vector reduction of the form:
+/// ```
+///   vector.reduction <add>, (muli (ext %lhs), (ext %rhs)), [%acc]
+/// ```
+/// to SPIR-V integer dot product ops (saturating `*AccSat` ops when %acc is
+/// present). %lhs/%rhs must be `vector<4xi8>` and the result i32 or i64.
+void populateVectorReductionToSPIRVDotProductPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index 7794ce71b024d..bb9f793d7fe0f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
   intrinsics_gen
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRSPIRVDialect
   MLIRSPIRVConversion
   MLIRVectorDialect

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a3a3a612ed147..20c52f536a23f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -20,6 +21,9 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -436,6 +440,84 @@ struct VectorShuffleOpConvert final
   }
 };
 
+/// Converts `vector.reduction <add>` of `arith.muli (ext %lhs), (ext %rhs)`
+/// (`vector<4xi8>` sources, i32/i64 result) to a SPIR-V integer dot product.
+/// NOTE(review): an accumulator maps to a saturating `*AccSat` op, while an
+/// integer `vector.reduction <add>` presumably wraps; confirm this is OK.
+struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
+
+    auto resultType = dyn_cast<IntegerType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "result is not an integer");
+
+    int64_t resultBitwidth = resultType.getWidth();
+    if (!llvm::is_contained({32, 64}, resultBitwidth))
+      return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
+
+    VectorType inVecTy = op.getSourceVectorType();
+    if (inVecTy.getNumElements() != 4 || inVecTy.getRank() != 1 ||
+        inVecTy.isScalable())
+      return rewriter.notifyMatchFailure(op, "unsupported vector shape");
+
+    auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(
+          op, "reduction operand is not 'arith.muli'");
+
+    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
+                             spirv::SDotAccSatOp, false>(op, mul, rewriter)) ||
+        succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
+                             spirv::UDotAccSatOp, false>(op, mul, rewriter)) ||
+        succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
+                             spirv::SUDotAccSatOp, false>(op, mul, rewriter)) ||
+        succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
+                             spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
+      return success();
+
+    return rewriter.notifyMatchFailure(op, "operands are not i8 extensions");
+  }
+
+private:
+  /// Matches `mul` as `LhsExtensionOp(i8) * RhsExtensionOp(i8)` and replaces
+  /// `op` with `DotOp`/`DotAccOp`; fails without touching the IR otherwise.
+  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
+            typename DotAccOp, bool SwapOperands>
+  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
+                                  PatternRewriter &rewriter) {
+    auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
+    if (!lhs || !getElementTypeOrSelf(lhs.getIn().getType()).isInteger(8))
+      return failure();
+
+    auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
+    if (!rhs || !getElementTypeOrSelf(rhs.getIn().getType()).isInteger(8))
+      return failure();
+
+    Value lhsIn = lhs.getIn();
+    Value rhsIn = rhs.getIn();
+
+    // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
+    // we have to swap operands instead in that case.
+    if (SwapOperands)
+      std::swap(lhsIn, rhsIn);
+
+    if (Value acc = op.getAcc()) {
+      rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
+                                            nullptr);
+    } else {
+      rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
+                                         nullptr);
+    }
+
+    return success();
+  }
+};
+
 } // namespace
 #define CL_MAX_MIN_OPS                                                         \
   spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
@@ -457,3 +539,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
                                                   patterns.getContext());
 }
+
+void mlir::populateVectorReductionToSPIRVDotProductPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorReductionToDotProd>(patterns.getContext(), /*benefit=*/1);
+}

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
new file mode 100644
index 0000000000000..bfe6d8608a99d
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -0,0 +1,200 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics \
+// RUN:   --test-vector-reduction-to-spirv-dot-prod %s -o - | FileCheck %s
+
+// Positive tests.
+
+// CHECK-LABEL: func.func @to_sdot
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_sdot_acc
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_sdot_i64
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i64
+//  CHECK-NEXT:   return [[DOT]] : i64
+func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
+  %mul = arith.muli %lhs, %rhs : vector<4xi64>
+  %red = vector.reduction <add>, %mul : vector<4xi64> into i64
+  return %red : i64
+}
+
+// CHECK-LABEL: func.func @to_sdot_acc_i64
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
+//  CHECK-NEXT:   return [[DOT]] : i64
+func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
+  %mul = arith.muli %lhs, %rhs : vector<4xi64>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi64> into i64
+  return %red : i64
+}
+
+// CHECK-LABEL: func.func @to_udot
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_udot_acc
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_signed_unsigned_dot
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_unsigned_signed_dot
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : (vector<4xi8>, vector<4xi8>) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+//  CHECK-NEXT:   [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+//  CHECK-NEXT:   return [[DOT]] : i32
+func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// -----
+// Negative tests.
+
+// CHECK-LABEL: func.func @too_short
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
+//  CHECK:        [[RED:%.+]] = vector.reduction
+//  CHECK-NEXT:   return [[RED]] : i32
+func.func @too_short(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
+  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
+  %mul = arith.muli %lhs, %rhs : vector<3xi32>
+  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @too_long
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<6xi8>, [[ARG1:%.+]]: vector<6xi8>)
+//  CHECK:        [[RED:%.+]] = vector.reduction
+//  CHECK-NEXT:   return [[RED]] : i32
+func.func @too_long(%arg0: vector<6xi8>, %arg1: vector<6xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<6xi8> to vector<6xi32>
+  %rhs = arith.extsi %arg1 : vector<6xi8> to vector<6xi32>
+  %mul = arith.muli %lhs, %rhs : vector<6xi32>
+  %red = vector.reduction <add>, %mul : vector<6xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @wrong_reduction_kind
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK:        [[RED:%.+]] = vector.reduction <mul>
+//  CHECK-NEXT:   return [[RED]] : i32
+func.func @wrong_reduction_kind(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <mul>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @wrong_arith_op
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK:        [[ADD:%.+]] = arith.addi
+//  CHECK:        [[RED:%.+]] = vector.reduction <add>, [[ADD]]
+//  CHECK-NEXT:   return [[RED]] : i32
+func.func @wrong_arith_op(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %add = arith.addi %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %add : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @wrong_source_type
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi16>, [[ARG1:%.+]]: vector<4xi16>)
+//  CHECK:        [[RED:%.+]] = vector.reduction
+//  CHECK-NEXT:   return [[RED]] : i32
+func.func @wrong_source_type(%arg0: vector<4xi16>, %arg1: vector<4xi16>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi16> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi16> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @wrong_result_type
+//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+//  CHECK:        [[RED:%.+]] = vector.reduction
+//  CHECK-NEXT:   return [[RED]] : i16
+func.func @wrong_result_type(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i16 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi16>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi16>
+  %mul = arith.muli %lhs, %rhs : vector<4xi16>
+  %red = vector.reduction <add>, %mul : vector<4xi16> into i16
+  return %red : i16
+}

diff  --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index bc6103dedd490..14f0e0dbe1802 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(FuncToLLVM)
+add_subdirectory(VectorToSPIRV)

diff  --git a/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000000000..09ed283ac97bc
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestVectorToSPIRV
+  TestVectorReductionToSPIRVDotProd.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRPass
+  MLIRSPIRVDialect
+  MLIRTransformUtils
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRVectorToSPIRV
+  )

diff  --git a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
new file mode 100644
index 0000000000000..1864d2f7f5036
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
@@ -0,0 +1,60 @@
+//===- TestVectorReductionToSPIRVDotProd.cpp - Test reduction to dot prod -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+/// Applies only the vector.reduction -> SPIR-V integer dot product patterns
+/// (`populateVectorReductionToSPIRVDotProductPatterns`) to each function with
+/// the greedy rewrite driver, so they can be tested in isolation.
+struct TestVectorReductionToSPIRVDotProd
+    : PassWrapper<TestVectorReductionToSPIRVDotProd,
+                  OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorReductionToSPIRVDotProd)
+
+  StringRef getArgument() const final {
+    return "test-vector-reduction-to-spirv-dot-prod";
+  }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns that convert vector.reduction to SPIR-V "
+           "integer dot product ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorReductionToSPIRVDotProductPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace test {
+/// Registers the pass so mlir-opt can run it as
+/// `--test-vector-reduction-to-spirv-dot-prod`.
+void registerTestVectorReductionToSPIRVDotProd() {
+  PassRegistration<TestVectorReductionToSPIRVDotProd>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index fb7b4e77c5ec8..f84fbe631cf16 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -41,6 +41,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestTransforms
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
+    MLIRTestVectorToSPIRV
     MLIRLLVMTestPasses
     )
 endif()

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b56c883da587b..568b4710552f5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -124,6 +124,7 @@ void registerTestTransformDialectEraseSchedulePass();
 void registerTestTransformDialectInterpreterPass();
 void registerTestWrittenToPass();
 void registerTestVectorLowerings();
+void registerTestVectorReductionToSPIRVDotProd();
 void registerTestNvgpuLowerings();
 } // namespace test
 } // namespace mlir
@@ -231,6 +232,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestNvgpuLowerings();
   mlir::test::registerTestWrittenToPass();
 }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1ed68d055ebe6..3f98d8247bdcc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4263,9 +4263,11 @@ cc_library(
     ]),
     includes = ["include"],
     deps = [
+        ":ArithDialect",
         ":ConversionPassIncGen",
         ":IR",
         ":Pass",
+        ":Support",
         ":SPIRVConversion",
         ":SPIRVDialect",
         ":Transforms",
@@ -7077,6 +7079,7 @@ cc_binary(
         "//mlir/test:TestTransforms",
         "//mlir/test:TestTypeDialect",
         "//mlir/test:TestVector",
+        "//mlir/test:TestVectorToSPIRV",
     ],
 )
 

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index f1672b1ddd7aa..e3c5db0f189f9 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -507,6 +507,22 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TestVectorToSPIRV",
+    srcs = glob(["lib/Conversion/VectorToSPIRV/*.cpp"]),
+    deps = [
+        "//mlir:ArithDialect",
+        "//mlir:FuncDialect",
+        "//mlir:IR",
+        "//mlir:Pass",
+        "//mlir:SPIRVDialect",
+        "//mlir:TransformUtils",
+        "//mlir:Transforms",
+        "//mlir:VectorDialect",
+        "//mlir:VectorToSPIRV",
+    ],
+)
+
 cc_library(
     name = "TestAffine",
     srcs = glob([


        


More information about the Mlir-commits mailing list