[Mlir-commits] [mlir] 6867e49 - [mlir][spirv] Implement vector type legalization for function signatures (#98337)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 17 10:09:19 PDT 2024
Author: Angel Zhang
Date: 2024-07-17T13:09:15-04:00
New Revision: 6867e49fc80c8468f9a5a8376ce7d3b89fd4fb51
URL: https://github.com/llvm/llvm-project/commit/6867e49fc80c8468f9a5a8376ce7d3b89fd4fb51
DIFF: https://github.com/llvm/llvm-project/commit/6867e49fc80c8468f9a5a8376ce7d3b89fd4fb51.diff
LOG: [mlir][spirv] Implement vector type legalization for function signatures (#98337)
### Description
This PR implements a minimal version of function signature conversion that
unrolls vectors into 1-D vectors of a size supported by SPIR-V (2, 3, or 4,
depending on the original dimension). This PR also includes new unit
tests that only check the function signature conversion.
### Future Plans
- Check for capabilities that support vectors of size 8 or 16.
- Set up `OneToNTypeConversion` and `DialectConversion` to replace the
current implementation that uses `GreedyPatternRewriteDriver`.
- Introduce other vector unrolling patterns to cancel out the
`vector.insert_strided_slice` and `vector.extract_strided_slice` ops and
fully legalize the vector types in the function body.
- Handle `func::CallOp` and declarations.
- Restructure the code in `SPIRVConversion.cpp`.
- Create test passes for testing sets of patterns in isolation.
- Optimize the way the original shape is split into target shapes, e.g.
`vector<5xi32>` can be split into `vector<4xi32>` and
`vector<1xi32>`.
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/ConvertToSPIRV/arith.mlir
mlir/test/Conversion/ConvertToSPIRV/combined.mlir
mlir/test/Conversion/ConvertToSPIRV/index.mlir
mlir/test/Conversion/ConvertToSPIRV/scf.mlir
mlir/test/Conversion/ConvertToSPIRV/simple.mlir
mlir/test/Conversion/ConvertToSPIRV/ub.mlir
mlir/test/Conversion/ConvertToSPIRV/vector.mlir
mlir/test/lib/Conversion/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 54b94bbfb93d1..748646e605827 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let description = [{
This is a generic pass to convert to SPIR-V.
}];
- let dependentDialects = ["spirv::SPIRVDialect"];
+ let dependentDialects = [
+ "spirv::SPIRVDialect",
+ "vector::VectorDialect",
+ ];
+ let options = [
+ Option<"runSignatureConversion", "run-signature-conversion", "bool",
+ /*default=*/"true",
+ "Run function signature conversion to convert vector types">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..9ad3d5fc85dd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,7 +17,9 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/SmallSet.h"
namespace mlir {
@@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
namespace spirv {
class AccessChainOp;
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..003a5feea9e9b 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,18 +39,31 @@ namespace {
/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+ using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
+ if (runSignatureConversion) {
+ // Unroll vectors in function signatures to native vector size.
+ RewritePatternSet patterns(context);
+ populateFuncOpVectorRewritePatterns(patterns);
+ populateReturnOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ }
+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ std::unique_ptr<ConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
-
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
- // Populate patterns.
+ // Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +73,6 @@ struct ConvertToSPIRVPass final
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
- std::unique_ptr<ConversionTarget> target =
- SPIRVConversionTarget::get(targetAttr);
-
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..4de9b4729e720 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -16,9 +16,15 @@ add_mlir_dialect_library(MLIRSPIRVConversion
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRDialectUtils
MLIRFuncDialect
+ MLIRIR
MLIRSPIRVDialect
+ MLIRSupport
MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorTransforms
)
add_mlir_dialect_library(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..e3a09ef1ff684 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,14 +11,24 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
@@ -34,6 +44,43 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+static int getComputeVectorSize(int64_t size) {
+ for (int i : {4, 3, 2}) {
+ if (size % i == 0)
+ return i;
+ }
+ return 1;
+}
+
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+ LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
+ if (vecType.isScalable()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "--scalable vectors are not supported -> BAIL\n");
+ return std::nullopt;
+ }
+ SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
+ std::optional<SmallVector<int64_t>> targetShape =
+ SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+ if (!targetShape) {
+ LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
+ return std::nullopt;
+ }
+ auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
+ if (!maybeShapeRatio) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "--could not compute integral shape ratio -> BAIL\n");
+ return std::nullopt;
+ }
+ if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+ LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+ return std::nullopt;
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+ return targetShape;
+}
+
/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `targetEnv`.
///
@@ -813,6 +860,249 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature to convert vector arguments of
+/// functions to be of valid types
+struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::FuncOp funcOp,
+ PatternRewriter &rewriter) const override {
+ FunctionType fnType = funcOp.getFunctionType();
+
+ // TODO: Handle declarations.
+ if (funcOp.isDeclaration()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << fnType << " illegal: declarations are unsupported\n");
+ return failure();
+ }
+
+ // Create a new func op with the original type and copy the function body.
+ auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+ funcOp.getName(), fnType);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+
+ Location loc = newFuncOp.getBody().getLoc();
+
+ Block &entryBlock = newFuncOp.getBlocks().front();
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&entryBlock);
+
+ OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+ // For arguments that are of illegal types and require unrolling.
+ // `unrolledInputNums` stores the indices of arguments that result from
+ // unrolling in the new function signature. `newInputNo` is a counter.
+ SmallVector<size_t> unrolledInputNums;
+ size_t newInputNo = 0;
+
+ // For arguments that are of legal types and do not require unrolling.
+ // `tmpOps` stores a mapping from temporary operations that serve as
+ // placeholders for new arguments that will be added later. These operations
+ // will be erased once the entry block's argument list is updated.
+ llvm::SmallDenseMap<Operation *, size_t> tmpOps;
+
+ // This counts the number of new operations created.
+ size_t newOpCount = 0;
+
+ // Enumerate through the arguments.
+ for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
+ // Check whether the argument is of vector type.
+ auto origVecType = dyn_cast<VectorType>(origType);
+ if (!origVecType) {
+ // We need a placeholder for the old argument that will be erased later.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
+ oneToNTypeMapping.addInputs(origInputNo, origType);
+ ++newInputNo;
+ ++newOpCount;
+ continue;
+ }
+ // Check whether the vector needs unrolling.
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ // We need a placeholder for the old argument that will be erased later.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
+ oneToNTypeMapping.addInputs(origInputNo, origType);
+ ++newInputNo;
+ ++newOpCount;
+ continue;
+ }
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+ auto originalShape =
+ llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
+
+ // Prepare the result vector.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origVecType, rewriter.getZeroAttr(origVecType));
+ ++newOpCount;
+ // Prepare the placeholder for the new arguments that will be added later.
+ Value dummy = rewriter.create<arith::ConstantOp>(
+ loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+ ++newOpCount;
+
+ // Create the `vector.insert_strided_slice` ops.
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, dummy, result, offsets, strides);
+ newTypes.push_back(unrolledType);
+ unrolledInputNums.push_back(newInputNo);
+ ++newInputNo;
+ ++newOpCount;
+ }
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ oneToNTypeMapping.addInputs(origInputNo, newTypes);
+ }
+
+ // Change the function signature.
+ auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+ auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
+ rewriter.modifyOpInPlace(newFuncOp,
+ [&] { newFuncOp.setFunctionType(newFnType); });
+
+ // Update the arguments in the entry block.
+ entryBlock.eraseArguments(0, fnType.getNumInputs());
+ SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+ entryBlock.addArguments(convertedTypes, locs);
+
+ // Replace the placeholder values with the new arguments. We assume there is
+ // only one block for now.
+ size_t unrolledInputIdx = 0;
+ for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ // We first look for operands that are placeholders for initially legal
+ // arguments.
+ Operation &curOp = op;
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ Operation *operandOp = operandVal.getDefiningOp();
+ if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
+ size_t idx = operandIdx;
+ rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
+ curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+ });
+ }
+ }
+ // Since all newly created operations are in the beginning, reaching the
+ // end of them means that any later `vector.insert_strided_slice` should
+ // not be touched.
+ if (count >= newOpCount)
+ continue;
+ if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+ size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
+ rewriter.modifyOpInPlace(&curOp, [&] {
+ curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ });
+ ++unrolledInputIdx;
+ }
+ }
+
+ // Erase the original funcOp. The `tmpOps` do not need to be erased since
+ // they have no uses and will be handled by dead-code elimination.
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const override {
+ // Check whether the parent funcOp is valid.
+ auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+ if (!funcOp)
+ return failure();
+
+ FunctionType fnType = funcOp.getFunctionType();
+ OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+ Location loc = returnOp.getLoc();
+
+ // For the new return op.
+ SmallVector<Value> newOperands;
+
+ // Enumerate through the results.
+ for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
+ // Check whether the argument is of vector type.
+ auto origVecType = dyn_cast<VectorType>(origType);
+ if (!origVecType) {
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ // Check whether the vector needs unrolling.
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ // The original argument can be used.
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+
+ // Create `vector.extract_strided_slice` ops to form legal vectors from
+ // the original operand of illegal type.
+ auto originalShape =
+ llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
+ Value returnValue = returnOp.getOperand(origResultNo);
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, returnValue, offsets, *targetShape, strides);
+ newOperands.push_back(result);
+ newTypes.push_back(unrolledType);
+ }
+ oneToNTypeMapping.addInputs(origResultNo, newTypes);
+ }
+
+ // Change the function signature.
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+ TypeRange(oneToNTypeMapping.getConvertedTypes()));
+ rewriter.modifyOpInPlace(funcOp,
+ [&] { funcOp.setFunctionType(newFnType); });
+
+ // Replace the return op using the new operands. This will automatically
+ // update the entry block as well.
+ rewriter.replaceOp(returnOp,
+ rewriter.create<func::ReturnOp>(loc, newOperands));
+
+ return success();
+ }
+};
+} // namespace
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index a2adc0ad9c7a5..1a844a7cd018b 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
//===----------------------------------------------------------------------===//
// arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 9e908465cb142..02b938be775a3 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @combined
// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..347d282f9ee0c
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt -test-spirv-func-signature-conversion -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+ // CHECK: return %[[ARG0]] : i32
+ return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+ // CHECK: return %[[ARG0]] : vector<4xi32>
+ return %arg0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @simple_vector_5(%arg0 : vector<5xi32>) -> vector<5xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<5xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT4:.*]] = vector.insert_strided_slice %[[ARG4]], %[[INSERT3]] {offsets = [4], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [1], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [2], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [3], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT4:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [4], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]], %[[EXTRACT4]] : vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>
+ return %arg0 : vector<5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_6
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
+func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<3xi32>
+ return %arg0 : vector<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
+func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xi32>, vector<4xi32>
+ return %arg0 : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_6and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi32>, vector<3xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<6xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_3and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
+func.func @vector_3and8(%arg0 : vector<3xi32>, %arg1 : vector<8xi32>) -> (vector<3xi32>, vector<8xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[ARG0]], %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<3xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @scalar_vector
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: i32)
+func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i32) -> (vector<8xi32>, vector<3xi32>, i32) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>, vector<4xi32>, vector<3xi32>, i32
+ return %arg0, %arg1, %arg2 : vector<8xi32>, vector<3xi32>, i32
+}
+
+// -----
+
+// CHECK-LABEL: @reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
+func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[ADDI:.*]] = arith.addi %[[INSERT1]], %[[INSERT3]] : vector<8xi32>
+ // CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, %[[ADDI]] : vector<8xi32> into i32
+ // CHECK: %[[RET:.*]] = arith.addi %[[REDUCTION]], %[[ARG4]] : i32
+ // CHECK: return %[[RET]] : i32
+ %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+ %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+ %2 = arith.addi %1, %arg2 : i32
+ return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @unsupported_decl(vector<8xi32>)
+func.func private @unsupported_decl(vector<8xi32>)
+
+// -----
+
+// CHECK-LABEL: @unsupported_scalable
+// CHECK-SAME: (%[[ARG0:.+]]: vector<[8]xi32>)
+func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
+ // CHECK: return %[[ARG0]] : vector<[8]xi32>
+ return %arg0 : vector<[8]xi32>
+}
+
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index db747625bc7b3..e1cb18aac5d01 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
// CHECK-LABEL: @basic
func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index f619ca5771824..58ec6ac61f6ac 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @if_yield
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index 20b2a42bc3975..c5e0e6603d94a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @return_scalar
// CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index 66528b68f58cf..a83bfb6f405a0 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @ub
// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 336f0fe10c27e..c63dd030f4747 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index 754c9866d18e4..19975f671b081 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToLLVM)
add_subdirectory(MathToVCIX)
add_subdirectory(OneToNTypeConversion)
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000000000..69b5787f7e851
--- /dev/null
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestConvertToSPIRV
+ TestSPIRVFuncSignatureConversion.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRPass
+ MLIRSPIRVConversion
+ MLIRSPIRVDialect
+ MLIRTransformUtils
+ MLIRTransforms
+ MLIRVectorDialect
+ )
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
new file mode 100644
index 0000000000000..ec67f85f6f27b
--- /dev/null
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
@@ -0,0 +1,57 @@
+//===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+struct TestSPIRVFuncSignatureConversion final
+ : PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
+
+ StringRef getArgument() const final {
+ return "test-spirv-func-signature-conversion";
+ }
+
+ StringRef getDescription() const final {
+ return "Test patterns that convert vector inputs and results in function "
+ "signatures";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
+ vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateFuncOpVectorRewritePatterns(patterns);
+ populateReturnOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config);
+ }
+};
+
+} // namespace
+
+namespace test {
+void registerTestSPIRVFuncSignatureConversion() {
+ PassRegistration<TestSPIRVFuncSignatureConversion>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index e8091bca3326c..8b79de58fa102 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -36,6 +36,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRSPIRVTestPasses
MLIRTensorTestPasses
MLIRTestAnalysis
+ MLIRTestConvertToSPIRV
MLIRTestDialect
MLIRTestDynDialect
MLIRTestIR
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8cafb0afac9ae..149f9d59961b8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -141,6 +141,7 @@ void registerTestSCFWhileOpBuilderPass();
void registerTestSCFWrapInZeroTripCheckPasses();
void registerTestShapeMappingPass();
void registerTestSliceAnalysisPass();
+void registerTestSPIRVFuncSignatureConversion();
void registerTestTensorCopyInsertionPass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
@@ -273,6 +274,7 @@ void registerTestPasses() {
mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
mlir::test::registerTestShapeMappingPass();
mlir::test::registerTestSliceAnalysisPass();
+ mlir::test::registerTestSPIRVFuncSignatureConversion();
mlir::test::registerTestTensorCopyInsertionPass();
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8d2b2be67ad79..0f9f688403530 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7207,10 +7207,15 @@ cc_library(
hdrs = ["include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"],
includes = ["include"],
deps = [
+ ":ArithDialect",
+ ":DialectUtils",
":FuncDialect",
":IR",
":SPIRVDialect",
+ ":Support",
":TransformUtils",
+ ":VectorDialect",
+ ":VectorTransforms",
"//llvm:Support",
],
)
@@ -9588,6 +9593,7 @@ cc_binary(
"//mlir/test:TestArmSME",
"//mlir/test:TestBufferization",
"//mlir/test:TestControlFlow",
+ "//mlir/test:TestConvertToSPIRV",
"//mlir/test:TestDLTI",
"//mlir/test:TestDialect",
"//mlir/test:TestFunc",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 1d59370057d1c..a1d2b20a106e6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -656,6 +656,21 @@ cc_library(
],
)
+cc_library(
+ name = "TestConvertToSPIRV",
+ srcs = glob(["lib/Conversion/ConvertToSPIRV/*.cpp"]),
+ deps = [
+ "//mlir:ArithDialect",
+ "//mlir:FuncDialect",
+ "//mlir:Pass",
+ "//mlir:SPIRVConversion",
+ "//mlir:SPIRVDialect",
+ "//mlir:TransformUtils",
+ "//mlir:Transforms",
+ "//mlir:VectorDialect",
+ ],
+)
+
cc_library(
name = "TestAffine",
srcs = glob([
More information about the Mlir-commits
mailing list