[Mlir-commits] [mlir] 6867e49 - [mlir][spirv] Implement vector type legalization for function signatures (#98337)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 17 10:09:19 PDT 2024


Author: Angel Zhang
Date: 2024-07-17T13:09:15-04:00
New Revision: 6867e49fc80c8468f9a5a8376ce7d3b89fd4fb51

URL: https://github.com/llvm/llvm-project/commit/6867e49fc80c8468f9a5a8376ce7d3b89fd4fb51
DIFF: https://github.com/llvm/llvm-project/commit/6867e49fc80c8468f9a5a8376ce7d3b89fd4fb51.diff

LOG: [mlir][spirv] Implement vector type legalization for function signatures (#98337)

### Description
This PR implements a minimal version of function signature conversion that
unrolls vector arguments and results into 1-D vectors of a size supported by
SPIR-V (2, 3, or 4, depending on the original dimension). This PR also includes
new unit tests that only check the function signature conversion.

### Future Plans
- Check for capabilities that support vectors of size 8 or 16.
- Set up `OneToNTypeConversion` and `DialectConversion` to replace the
current implementation that uses `GreedyPatternRewriteDriver`.
- Introduce other vector unrolling patterns to cancel out the
`vector.insert_strided_slice` and `vector.extract_strided_slice` ops and
fully legalize the vector types in the function body.
- Handle `func::CallOp` and declarations.
- Restructure the code in `SPIRVConversion.cpp`.
- Create test passes for testing sets of patterns in isolation.
- Optimize the way the original shape is split into target shapes, e.g.
`vector<5xi32>` can be split into `vector<4xi32>` and
`vector<1xi32>`.

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
    mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
    mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/ConvertToSPIRV/arith.mlir
    mlir/test/Conversion/ConvertToSPIRV/combined.mlir
    mlir/test/Conversion/ConvertToSPIRV/index.mlir
    mlir/test/Conversion/ConvertToSPIRV/scf.mlir
    mlir/test/Conversion/ConvertToSPIRV/simple.mlir
    mlir/test/Conversion/ConvertToSPIRV/ub.mlir
    mlir/test/Conversion/ConvertToSPIRV/vector.mlir
    mlir/test/lib/Conversion/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 54b94bbfb93d1..748646e605827 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let description = [{
     This is a generic pass to convert to SPIR-V.
   }];
-  let dependentDialects = ["spirv::SPIRVDialect"];
+  let dependentDialects = [
+    "spirv::SPIRVDialect",
+    "vector::VectorDialect",
+  ];
+  let options = [
+    Option<"runSignatureConversion", "run-signature-conversion", "bool",
+    /*default=*/"true",
+    "Run function signature conversion to convert vector types">
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..9ad3d5fc85dd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,7 +17,9 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
@@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
 void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns);
 
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
 namespace spirv {
 class AccessChainOp;
 

diff  --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..003a5feea9e9b 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,18 +39,31 @@ namespace {
 /// A pass to perform the SPIR-V conversion.
 struct ConvertToSPIRVPass final
     : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+  using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
+    if (runSignatureConversion) {
+      // Unroll vectors in function signatures to native vector size.
+      RewritePatternSet patterns(context);
+      populateFuncOpVectorRewritePatterns(patterns);
+      populateReturnOpVectorRewritePatterns(patterns);
+      GreedyRewriteConfig config;
+      config.strictMode = GreedyRewriteStrictness::ExistingOps;
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        return signalPassFailure();
+    }
+
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
     SPIRVTypeConverter typeConverter(targetAttr);
-
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    // Populate patterns.
+    // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +73,6 @@ struct ConvertToSPIRVPass final
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
 
-    std::unique_ptr<ConversionTarget> target =
-        SPIRVConversionTarget::get(targetAttr);
-
     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
       return signalPassFailure();
   }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..4de9b4729e720 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -16,9 +16,15 @@ add_mlir_dialect_library(MLIRSPIRVConversion
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRDialectUtils
   MLIRFuncDialect
+  MLIRIR
   MLIRSPIRVDialect
+  MLIRSupport
   MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorTransforms
 )
 
 add_mlir_dialect_library(MLIRSPIRVTransforms

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..e3a09ef1ff684 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,14 +11,24 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
@@ -34,6 +44,43 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+static int getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
+  if (vecType.isScalable()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--scalable vectors are not supported -> BAIL\n");
+    return std::nullopt;
+  }
+  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
+  std::optional<SmallVector<int64_t>> targetShape =
+      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  if (!targetShape) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
+    return std::nullopt;
+  }
+  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+  return targetShape;
+}
+
 /// Checks that `candidates` extension requirements are possible to be satisfied
 /// with the given `targetEnv`.
 ///
@@ -813,6 +860,249 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature to convert vector arguments of
+/// functions to be of valid types
+struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    FunctionType fnType = funcOp.getFunctionType();
+
+    // TODO: Handle declarations.
+    if (funcOp.isDeclaration()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << fnType << " illegal: declarations are unsupported\n");
+      return failure();
+    }
+
+    // Create a new func op with the original type and copy the function body.
+    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+                                                   funcOp.getName(), fnType);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    Location loc = newFuncOp.getBody().getLoc();
+
+    Block &entryBlock = newFuncOp.getBlocks().front();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&entryBlock);
+
+    OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+    // For arguments that are of illegal types and require unrolling.
+    // `unrolledInputNums` stores the indices of arguments that result from
+    // unrolling in the new function signature. `newInputNo` is a counter.
+    SmallVector<size_t> unrolledInputNums;
+    size_t newInputNo = 0;
+
+    // For arguments that are of legal types and do not require unrolling.
+    // `tmpOps` stores a mapping from temporary operations that serve as
+    // placeholders for new arguments that will be added later. These operations
+    // will be erased once the entry block's argument list is updated.
+    llvm::SmallDenseMap<Operation *, size_t> tmpOps;
+
+    // This counts the number of new operations created.
+    size_t newOpCount = 0;
+
+    // Enumerate through the arguments.
+    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
+      // Check whether the argument is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        // We need a placeholder for the old argument that will be erased later.
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        ++newInputNo;
+        ++newOpCount;
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // We need a placeholder for the old argument that will be erased later.
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        ++newInputNo;
+        ++newOpCount;
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+      auto originalShape =
+          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
+
+      // Prepare the result vector.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origVecType, rewriter.getZeroAttr(origVecType));
+      ++newOpCount;
+      // Prepare the placeholder for the new arguments that will be added later.
+      Value dummy = rewriter.create<arith::ConstantOp>(
+          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+      ++newOpCount;
+
+      // Create the `vector.insert_strided_slice` ops.
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<Type> newTypes;
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(originalShape, *targetShape)) {
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, dummy, result, offsets, strides);
+        newTypes.push_back(unrolledType);
+        unrolledInputNums.push_back(newInputNo);
+        ++newInputNo;
+        ++newOpCount;
+      }
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      oneToNTypeMapping.addInputs(origInputNo, newTypes);
+    }
+
+    // Change the function signature.
+    auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+    auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
+    rewriter.modifyOpInPlace(newFuncOp,
+                             [&] { newFuncOp.setFunctionType(newFnType); });
+
+    // Update the arguments in the entry block.
+    entryBlock.eraseArguments(0, fnType.getNumInputs());
+    SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+    entryBlock.addArguments(convertedTypes, locs);
+
+    // Replace the placeholder values with the new arguments. We assume there is
+    // only one block for now.
+    size_t unrolledInputIdx = 0;
+    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+      // We first look for operands that are placeholders for initially legal
+      // arguments.
+      Operation &curOp = op;
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+        Operation *operandOp = operandVal.getDefiningOp();
+        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
+          size_t idx = operandIdx;
+          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
+            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+          });
+        }
+      }
+      // Since all newly created operations are in the beginning, reaching the
+      // end of them means that any later `vector.insert_strided_slice` should
+      // not be touched.
+      if (count >= newOpCount)
+        continue;
+      if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+        size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
+        rewriter.modifyOpInPlace(&curOp, [&] {
+          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        });
+        ++unrolledInputIdx;
+      }
+    }
+
+    // Erase the original funcOp. The `tmpOps` do not need to be erased since
+    // they have no uses and will be handled by dead-code elimination.
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+                                PatternRewriter &rewriter) const override {
+    // Check whether the parent funcOp is valid.
+    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+    if (!funcOp)
+      return failure();
+
+    FunctionType fnType = funcOp.getFunctionType();
+    OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+    Location loc = returnOp.getLoc();
+
+    // For the new return op.
+    SmallVector<Value> newOperands;
+
+    // Enumerate through the results.
+    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
+      // Check whether the argument is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // The original argument can be used.
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+
+      // Create `vector.extract_strided_slice` ops to form legal vectors from
+      // the original operand of illegal type.
+      auto originalShape =
+          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<Type> newTypes;
+      Value returnValue = returnOp.getOperand(origResultNo);
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(originalShape, *targetShape)) {
+        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, returnValue, offsets, *targetShape, strides);
+        newOperands.push_back(result);
+        newTypes.push_back(unrolledType);
+      }
+      oneToNTypeMapping.addInputs(origResultNo, newTypes);
+    }
+
+    // Change the function signature.
+    auto newFnType =
+        FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+                          TypeRange(oneToNTypeMapping.getConvertedTypes()));
+    rewriter.modifyOpInPlace(funcOp,
+                             [&] { funcOp.setFunctionType(newFnType); });
+
+    // Replace the return op using the new operands. This will automatically
+    // update the entry block as well.
+    rewriter.replaceOp(returnOp,
+                       rewriter.create<func::ReturnOp>(loc, newOperands));
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index a2adc0ad9c7a5..1a844a7cd018b 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // arithmetic ops

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 9e908465cb142..02b938be775a3 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @combined
 // CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..347d282f9ee0c
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt -test-spirv-func-signature-conversion -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+  // CHECK: return %[[ARG0]] : i32
+  return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+  // CHECK: return %[[ARG0]] : vector<4xi32>
+  return %arg0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @simple_vector_5(%arg0 : vector<5xi32>) -> vector<5xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<5xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert_strided_slice %[[ARG4]], %[[INSERT3]] {offsets = [4], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [1], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [2], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [3], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT4:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [4], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]], %[[EXTRACT4]] : vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>
+  return %arg0 : vector<5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_6
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
+func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<3xi32>
+  return %arg0 : vector<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
+func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xi32>, vector<4xi32>
+  return %arg0 : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_6and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi32>, vector<3xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<6xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_3and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
+func.func @vector_3and8(%arg0 : vector<3xi32>, %arg1 : vector<8xi32>) -> (vector<3xi32>, vector<8xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[ARG0]], %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<3xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @scalar_vector
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: i32)
+func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i32) -> (vector<8xi32>, vector<3xi32>, i32) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>, vector<4xi32>, vector<3xi32>, i32
+  return %arg0, %arg1, %arg2 : vector<8xi32>, vector<3xi32>, i32
+}
+
+// -----
+
+// CHECK-LABEL: @reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
+func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[ADDI:.*]] = arith.addi %[[INSERT1]], %[[INSERT3]] : vector<8xi32>
+  // CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, %[[ADDI]] : vector<8xi32> into i32
+  // CHECK: %[[RET:.*]] = arith.addi %[[REDUCTION]], %[[ARG4]] : i32
+  // CHECK: return %[[RET]] : i32
+  %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+  %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+  %2 = arith.addi %1, %arg2 : i32
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @unsupported_decl(vector<8xi32>)
+func.func private @unsupported_decl(vector<8xi32>)
+
+// -----
+
+// CHECK-LABEL: @unsupported_scalable
+// CHECK-SAME: (%[[ARG0:.+]]: vector<[8]xi32>)
+func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
+  // CHECK: return %[[ARG0]] : vector<[8]xi32>
+  return %arg0 : vector<[8]xi32>
+}
+

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index db747625bc7b3..e1cb18aac5d01 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
 
 // CHECK-LABEL: @basic
 func.func @basic(%a: index, %b: index) {

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index f619ca5771824..58ec6ac61f6ac 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @if_yield
 // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index 20b2a42bc3975..c5e0e6603d94a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @return_scalar
 // CHECK-SAME: %[[ARG0:.*]]: i32

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index 66528b68f58cf..a83bfb6f405a0 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @ub
 // CHECK: %[[UNDEF:.*]] = spirv.Undef : i32

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 336f0fe10c27e..c63dd030f4747 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>

diff  --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index 754c9866d18e4..19975f671b081 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(ConvertToSPIRV)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(MathToVCIX)
 add_subdirectory(OneToNTypeConversion)

diff  --git a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000000000..69b5787f7e851
--- /dev/null
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Test-only pass library; EXCLUDE_FROM_LIBMLIR keeps it out of libMLIR.so.
+add_mlir_library(MLIRTestConvertToSPIRV
+  TestSPIRVFuncSignatureConversion.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRSPIRVConversion
+  MLIRSPIRVDialect
+  MLIRTransformUtils
+  MLIRTransforms
+  MLIRVectorDialect
+  )

diff  --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
new file mode 100644
index 0000000000000..ec67f85f6f27b
--- /dev/null
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
@@ -0,0 +1,57 @@
+//===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+struct TestSPIRVFuncSignatureConversion final // Test pass anchored on ModuleOp.
+    : PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
+
+  StringRef getArgument() const final { // Pass argument string.
+    return "test-spirv-func-signature-conversion";
+  }
+
+  StringRef getDescription() const final { // Short description of the pass.
+    return "Test patterns that convert vector inputs and results in function "
+           "signatures";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
+                    vector::VectorDialect>(); // arith, func, spirv and vector.
+  }
+
+  void runOnOperation() override { // Runs only the signature patterns.
+    RewritePatternSet patterns(&getContext());
+    populateFuncOpVectorRewritePatterns(patterns);
+    populateReturnOpVectorRewritePatterns(patterns);
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps; // Non-default.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                       config); // NOTE(review): result dropped.
+  }
+};
+
+} // namespace
+
+namespace test {
+void registerTestSPIRVFuncSignatureConversion() { // Registers the test pass.
+  PassRegistration<TestSPIRVFuncSignatureConversion>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index e8091bca3326c..8b79de58fa102 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -36,6 +36,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRSPIRVTestPasses
     MLIRTensorTestPasses
     MLIRTestAnalysis
+    MLIRTestConvertToSPIRV
     MLIRTestDialect
     MLIRTestDynDialect
     MLIRTestIR

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8cafb0afac9ae..149f9d59961b8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -141,6 +141,7 @@ void registerTestSCFWhileOpBuilderPass();
 void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
+void registerTestSPIRVFuncSignatureConversion();
 void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
@@ -273,6 +274,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
+  mlir::test::registerTestSPIRVFuncSignatureConversion();
   mlir::test::registerTestTensorCopyInsertionPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8d2b2be67ad79..0f9f688403530 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7207,10 +7207,15 @@ cc_library(
     hdrs = ["include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"],
     includes = ["include"],
     deps = [
+        ":ArithDialect",
+        ":DialectUtils",
         ":FuncDialect",
         ":IR",
         ":SPIRVDialect",
+        ":Support",
         ":TransformUtils",
+        ":VectorDialect",
+        ":VectorTransforms",
         "//llvm:Support",
     ],
 )
@@ -9588,6 +9593,7 @@ cc_binary(
         "//mlir/test:TestArmSME",
         "//mlir/test:TestBufferization",
         "//mlir/test:TestControlFlow",
+        "//mlir/test:TestConvertToSPIRV",
         "//mlir/test:TestDLTI",
         "//mlir/test:TestDialect",
         "//mlir/test:TestFunc",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 1d59370057d1c..a1d2b20a106e6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -656,6 +656,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TestConvertToSPIRV",
+    srcs = glob(["lib/Conversion/ConvertToSPIRV/*.cpp"]),  # Every .cpp in the ConvertToSPIRV test directory.
+    deps = [
+        "//mlir:ArithDialect",
+        "//mlir:FuncDialect",
+        "//mlir:Pass",
+        "//mlir:SPIRVConversion",
+        "//mlir:SPIRVDialect",
+        "//mlir:TransformUtils",
+        "//mlir:Transforms",
+        "//mlir:VectorDialect",
+    ],
+)
+
 cc_library(
     name = "TestAffine",
     srcs = glob([


        


More information about the Mlir-commits mailing list