[Mlir-commits] [mlir] b194ef6 - [mlir][spirv][vector] Add pattern to convert reduction to SPIR-V dot prod
Jakub Kuderski
llvmlistbot at llvm.org
Fri Mar 10 10:56:53 PST 2023
Author: Jakub Kuderski
Date: 2023-03-10T13:54:16-05:00
New Revision: b194ef692cf3965bac141af31a69428b4f6ae2df
URL: https://github.com/llvm/llvm-project/commit/b194ef692cf3965bac141af31a69428b4f6ae2df
DIFF: https://github.com/llvm/llvm-project/commit/b194ef692cf3965bac141af31a69428b4f6ae2df.diff
LOG: [mlir][spirv][vector] Add pattern to convert reduction to SPIR-V dot prod
This converts a specific form of `vector.reduction` to SPIR-V integer
dot product ops.
Add a new test pass to exercise this outside of the main vector to
spirv conversion pass.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D145760
Added:
mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
Modified:
mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/lib/Conversion/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index 7b8882b36cd5d..f8c02c54066b8 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -23,6 +23,15 @@ class SPIRVTypeConverter;
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Appends patterns to convert vector reduction of the form:
+/// ```
+/// vector.reduction <add>, (muli (ext %lhs), (ext %rhs)), [%acc]
+/// ```
+///
+/// to SPIR-V integer dot product ops.
+///
+/// Only 1-D, fixed-length, 4-element reductions are matched, where both
+/// `ext` ops extend from an i8 element type and the result is i32 or i64.
+/// Each `ext` may be `arith.extsi` or `arith.extui`; the combination selects
+/// the signed, unsigned, or mixed-signedness dot product op. When `%acc` is
+/// present, the accumulating `*AccSat` dot product ops are emitted.
+/// NOTE(review): `*AccSat` saturates the final accumulation, while an integer
+/// `vector.reduction <add>` presumably wraps -- confirm this is intended.
+void populateVectorReductionToSPIRVDotProductPatterns(
+ RewritePatternSet &patterns);
+
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index 7794ce71b024d..bb9f793d7fe0f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
intrinsics_gen
LINK_LIBS PUBLIC
+ MLIRArithDialect
MLIRSPIRVDialect
MLIRSPIRVConversion
MLIRVectorDialect
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a3a3a612ed147..20c52f536a23f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -20,6 +21,9 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -436,6 +440,84 @@ struct VectorShuffleOpConvert final
}
};
+/// Rewrites `vector.reduction <add>` over an `arith.muli` of two extended
+/// i8 vectors into a SPIR-V integer dot product op.
+///
+/// Matched form (the result type must be i32 or i64):
+/// ```
+///   %lhs = arith.ext{s,u}i %a : vector<4xi8> to vector<4xiN>
+///   %rhs = arith.ext{s,u}i %b : vector<4xi8> to vector<4xiN>
+///   %mul = arith.muli %lhs, %rhs : vector<4xiN>
+///   %red = vector.reduction <add>, %mul [, %acc] : vector<4xiN> into iN
+/// ```
+///
+/// The two extension kinds select SDot / UDot / SUDot. When an accumulator is
+/// present, the corresponding `*AccSat` op is emitted instead.
+// NOTE(review): the `*AccSat` ops saturate the final accumulation, whereas an
+// integer `vector.reduction <add>` with an accumulator presumably wraps on
+// overflow; the results only agree when `acc + dot` does not overflow.
+// Confirm this is intended, or lower to the plain dot op plus `spirv.IAdd`.
+struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
+
+    auto resultType = dyn_cast<IntegerType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "result is not an integer");
+
+    int64_t resultBitwidth = resultType.getWidth();
+    if (!llvm::is_contained({32, 64}, resultBitwidth))
+      return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
+
+    // Only 1-D, fixed-length, 4-element source vectors are handled.
+    VectorType inVecTy = op.getSourceVectorType();
+    if (inVecTy.getRank() != 1 || inVecTy.getNumElements() != 4 ||
+        inVecTy.isScalable())
+      return rewriter.notifyMatchFailure(op, "unsupported vector shape");
+
+    auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(
+          op, "reduction operand is not 'arith.muli'");
+
+    // Try each signedness combination of the two extensions. A case either
+    // replaces `op` and succeeds, or fails without touching the IR.
+    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
+                             spirv::SDotAccSatOp, false>(op, mul, rewriter)))
+      return success();
+
+    if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
+                             spirv::UDotAccSatOp, false>(op, mul, rewriter)))
+      return success();
+
+    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
+                             spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
+      return success();
+
+    // Unsigned LHS with signed RHS: reuse SUDot with the operands swapped.
+    if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
+                             spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
+      return success();
+
+    return rewriter.notifyMatchFailure(
+        op, "'arith.muli' operands are not i8 'arith.extsi'/'arith.extui'");
+  }
+
+private:
+  /// Attempts to rewrite `op` assuming `mul`'s LHS is produced by
+  /// `LhsExtensionOp` and its RHS by `RhsExtensionOp`, both extending from an
+  /// i8 element type. On success, replaces `op` with `DotOp` (no accumulator)
+  /// or `DotAccOp` (with accumulator), fed by the pre-extension values; when
+  /// `SwapOperands` is true those values are passed in reverse order.
+  /// Returns failure without modifying the IR if the operands do not match.
+  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
+            typename DotAccOp, bool SwapOperands>
+  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
+                                  PatternRewriter &rewriter) {
+    auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
+    if (!lhs || !getElementTypeOrSelf(lhs.getIn().getType()).isInteger(8))
+      return failure();
+
+    auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
+    if (!rhs || !getElementTypeOrSelf(rhs.getIn().getType()).isInteger(8))
+      return failure();
+
+    Value lhsIn = lhs.getIn();
+    Value rhsIn = rhs.getIn();
+
+    // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
+    // we have to swap operands instead in that case.
+    if constexpr (SwapOperands)
+      std::swap(lhsIn, rhsIn);
+
+    if (Value acc = op.getAcc()) {
+      rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
+                                            nullptr);
+    } else {
+      rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
+                                         nullptr);
+    }
+
+    return success();
+  }
+};
+
} // namespace
#define CL_MAX_MIN_OPS \
spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
@@ -457,3 +539,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
patterns.getContext());
}
+
+/// Registers the `vector.reduction` -> SPIR-V integer dot product rewrite
+/// into `patterns`, using the pattern set's own context.
+void mlir::populateVectorReductionToSPIRVDotProductPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<VectorReductionToDotProd>(context);
+}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
new file mode 100644
index 0000000000000..bfe6d8608a99d
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics \
+// RUN: --test-vector-reduction-to-spirv-dot-prod %s -o - | FileCheck %s
+
+// Positive tests: each function below is expected to collapse into a single
+// SPIR-V dot product op that consumes the original i8 vector arguments.
+
// CHECK-LABEL: func.func @to_sdot
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @to_sdot_acc
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @to_sdot_i64
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i64
// CHECK-NEXT: return [[DOT]] : i64
func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
+ %mul = arith.muli %lhs, %rhs : vector<4xi64>
+ %red = vector.reduction <add>, %mul : vector<4xi64> into i64
+ return %red : i64
+}
+
// CHECK-LABEL: func.func @to_sdot_acc_i64
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
// CHECK-NEXT: return [[DOT]] : i64
func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
+ %mul = arith.muli %lhs, %rhs : vector<4xi64>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xi64> into i64
+ return %red : i64
+}
+
// CHECK-LABEL: func.func @to_udot
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT: [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+ %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @to_udot_acc
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT: [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+ %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @to_signed_unsigned_dot
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+ return %red : i32
+}
+
+// Unsigned LHS / signed RHS has no dedicated dot product op, so the pattern
+// is expected to swap the operands onto SUDot / SUDotAccSat.
// CHECK-LABEL: func.func @to_unsigned_signed_dot
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : (vector<4xi8>, vector<4xi8>) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+ %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
// CHECK-NEXT: return [[DOT]] : i32
func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+ %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+ return %red : i32
+}
+
+// -----
+// Negative tests: none of these should be rewritten; the original
+// vector.reduction must survive.
+
// CHECK-LABEL: func.func @too_short
// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
// CHECK: [[RED:%.+]] = vector.reduction
// CHECK-NEXT: return [[RED]] : i32
func.func @too_short(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
+ %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
+ %mul = arith.muli %lhs, %rhs : vector<3xi32>
+ %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @too_long
// CHECK-SAME: ([[ARG0:%.+]]: vector<6xi8>, [[ARG1:%.+]]: vector<6xi8>)
// CHECK: [[RED:%.+]] = vector.reduction
// CHECK-NEXT: return [[RED]] : i32
func.func @too_long(%arg0: vector<6xi8>, %arg1: vector<6xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<6xi8> to vector<6xi32>
+ %rhs = arith.extsi %arg1 : vector<6xi8> to vector<6xi32>
+ %mul = arith.muli %lhs, %rhs : vector<6xi32>
+ %red = vector.reduction <add>, %mul : vector<6xi32> into i32
+ return %red : i32
+}
+
// CHECK-LABEL: func.func @wrong_reduction_kind
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK: [[RED:%.+]] = vector.reduction <mul>
// CHECK-NEXT: return [[RED]] : i32
func.func @wrong_reduction_kind(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <mul>, %mul : vector<4xi32> into i32
+ return %red : i32
+}
+
+// Uses <add> so that only the non-`arith.muli` operand blocks the rewrite.
// CHECK-LABEL: func.func @wrong_arith_op
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK: [[ADD:%.+]] = arith.addi
// CHECK: [[RED:%.+]] = vector.reduction <add>, [[ADD]]
// CHECK-NEXT: return [[RED]] : i32
func.func @wrong_arith_op(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+ %add = arith.addi %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %add : vector<4xi32> into i32
+ return %red : i32
+}
+
+// Extension sources must be i8.
// CHECK-LABEL: func.func @wrong_ext_source_type
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi16>, [[ARG1:%.+]]: vector<4xi16>)
// CHECK: [[RED:%.+]] = vector.reduction <add>
// CHECK-NEXT: return [[RED]] : i32
func.func @wrong_ext_source_type(%arg0: vector<4xi16>, %arg1: vector<4xi16>) -> i32 {
+ %lhs = arith.extsi %arg0 : vector<4xi16> to vector<4xi32>
+ %rhs = arith.extsi %arg1 : vector<4xi16> to vector<4xi32>
+ %mul = arith.muli %lhs, %rhs : vector<4xi32>
+ %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+ return %red : i32
+}
+
+// Only i32 and i64 results are supported.
// CHECK-LABEL: func.func @unsupported_result_bitwidth
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK: [[RED:%.+]] = vector.reduction <add>
// CHECK-NEXT: return [[RED]] : i16
func.func @unsupported_result_bitwidth(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i16 {
+ %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi16>
+ %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi16>
+ %mul = arith.muli %lhs, %rhs : vector<4xi16>
+ %red = vector.reduction <add>, %mul : vector<4xi16> into i16
+ return %red : i16
+}
diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index bc6103dedd490..14f0e0dbe1802 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(FuncToLLVM)
+add_subdirectory(VectorToSPIRV)
diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000000000..09ed283ac97bc
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestVectorToSPIRV
+ TestVectorReductionToSPIRVDotProd.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRVectorToSPIRV
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRSPIRVDialect
+ MLIRVectorDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
new file mode 100644
index 0000000000000..1864d2f7f5036
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
@@ -0,0 +1,55 @@
+//===- TestVectorReductionToSPIRVDotProd.cpp - Test reduction to dot prod -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+/// Test pass that applies only the vector.reduction -> SPIR-V integer dot
+/// product patterns to each `func.func`, so they can be exercised outside of
+/// the main vector-to-SPIR-V conversion pass.
+struct TestVectorReductionToSPIRVDotProd
+    : PassWrapper<TestVectorReductionToSPIRVDotProd,
+                  OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorReductionToSPIRVDotProd)
+
+  StringRef getArgument() const final {
+    return "test-vector-reduction-to-spirv-dot-prod";
+  }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns that convert vector.reduction to SPIR-V "
+           "integer dot product ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorReductionToSPIRVDotProductPatterns(patterns);
+    // The convergence result is intentionally discarded.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace test {
+/// Registers the pass as `--test-vector-reduction-to-spirv-dot-prod`.
+void registerTestVectorReductionToSPIRVDotProd() {
+  PassRegistration<TestVectorReductionToSPIRVDotProd>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index fb7b4e77c5ec8..f84fbe631cf16 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -41,6 +41,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestTransforms
MLIRTilingInterfaceTestPasses
MLIRVectorTestPasses
+ MLIRTestVectorToSPIRV
MLIRLLVMTestPasses
)
endif()
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b56c883da587b..568b4710552f5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -124,6 +124,7 @@ void registerTestTransformDialectEraseSchedulePass();
void registerTestTransformDialectInterpreterPass();
void registerTestWrittenToPass();
void registerTestVectorLowerings();
+void registerTestVectorReductionToSPIRVDotProd();
void registerTestNvgpuLowerings();
} // namespace test
} // namespace mlir
@@ -231,6 +232,7 @@ void registerTestPasses() {
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestTransformDialectInterpreterPass();
mlir::test::registerTestVectorLowerings();
+ mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestNvgpuLowerings();
mlir::test::registerTestWrittenToPass();
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1ed68d055ebe6..3f98d8247bdcc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4263,9 +4263,11 @@ cc_library(
]),
includes = ["include"],
deps = [
+ ":ArithDialect",
":ConversionPassIncGen",
":IR",
":Pass",
+ ":Support",
":SPIRVConversion",
":SPIRVDialect",
":Transforms",
@@ -7077,6 +7079,7 @@ cc_binary(
"//mlir/test:TestTransforms",
"//mlir/test:TestTypeDialect",
"//mlir/test:TestVector",
+ "//mlir/test:TestVectorToSPIRV",
],
)
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index f1672b1ddd7aa..e3c5db0f189f9 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -507,6 +507,20 @@ cc_library(
],
)
+cc_library(
+ name = "TestVectorToSPIRV",
+ srcs = glob(["lib/Conversion/VectorToSPIRV/*.cpp"]),
+ deps = [
+ "//mlir:ArithDialect",
+ "//mlir:FuncDialect",
+ "//mlir:Pass",
+ "//mlir:SPIRVDialect",
+ "//mlir:Transforms",
+ "//mlir:VectorDialect",
+ "//mlir:VectorToSPIRV",
+ ],
+)
+
cc_library(
name = "TestAffine",
srcs = glob([
More information about the Mlir-commits
mailing list