[Mlir-commits] [mlir] [mlir][spirv] Implement vector type legalization in function signatures (PR #98337)

Angel Zhang llvmlistbot at llvm.org
Thu Jul 11 08:00:01 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/98337

>From b33d3726d916dc03927d01616db195f292ccf410 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 27 Jun 2024 15:44:27 +0000
Subject: [PATCH 01/14] [mlir][spirv] Implement vector type legalization in
 function signatures

---
 mlir/include/mlir/Conversion/Passes.td        |   5 +-
 .../Vector/Transforms/VectorRewritePatterns.h |   4 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 145 ++++++++++++-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Vector/Transforms/VectorUnroll.cpp        | 201 ++++++++++++++++++
 5 files changed, 353 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..8d83343f5b736 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let description = [{
     This is a generic pass to convert to SPIR-V.
   }];
-  let dependentDialects = ["spirv::SPIRVDialect"];
+  let dependentDialects = [
+    "spirv::SPIRVDialect",
+    "vector::VectorDialect",
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 8e6d36f0b5f09..5c06d6d4d6ad3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -293,6 +293,10 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
+/// Collect a pattern that rewrites func.func ops so that vector-typed
+/// arguments are unrolled into several arguments of the native shape given by
+/// `options`, with the original vector reassembled inside the function body.
+void populateVectorUnrollFuncSignaturePatterns(RewritePatternSet &patterns,
+                                               const UnrollVectorOptions &options,
+                                               PatternBenefit benefit = 1);
+
 /// Collect a set of vector.shape_cast folding patterns.
 void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                       PatternBenefit benefit = 1);
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..54152c5be26fa 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -34,6 +34,66 @@ namespace mlir {
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Vector Lowering
+//===----------------------------------------------------------------------===//
+
+/// Returns the largest of 4, 3 and 2 that evenly divides `size`, or 1 if none
+/// does. Used as the innermost native vector size when unrolling.
+static int getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+/// For vector.multi_reduction, keep the source shape but unroll every
+/// reduction dimension to size 1.
+static SmallVector<int64_t>
+getNativeVectorShapeImpl(vector::MultiDimReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  auto nativeSize = llvm::to_vector(srcVectorType.getShape());
+  auto dims = op.getReductionDims().getAsValueRange<IntegerAttr>();
+  for (const auto &dimAttr : dims)
+    nativeSize[dimAttr.getZExtValue()] = 1;
+  return nativeSize;
+}
+
+/// vector.reduction has a 1-D source; unroll it to the compute vector size.
+static SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  assert(srcVectorType.getRank() == 1 && "vector.reduction source is 1-D");
+  int64_t vectorSize = getComputeVectorSize(srcVectorType.getDimSize(0));
+  return {vectorSize};
+}
+
+/// For vector.transpose, unroll every result dimension to 1 except the
+/// innermost one, which is tiled to the compute vector size.
+static SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+/// For vector.gather, unroll every dimension of the gathered vector to 1
+/// except the innermost one, which is tiled to the compute vector size.
+static SmallVector<int64_t> getNativeVectorShapeImpl(vector::GatherOp op) {
+  VectorType vectorType = op.getVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+/// Returns the native vector shape to unroll `op` to, or std::nullopt if `op`
+/// should not be unrolled. Single-result elementwise ops producing a vector of
+/// rank >= 1 keep only a compute-sized innermost dimension; a few vector ops
+/// are handled by the getNativeVectorShapeImpl overloads.
+static std::optional<SmallVector<int64_t>>
+getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0]);
+    // A 0-D vector has no innermost dimension to tile (calling `back()` on an
+    // empty shape is invalid), so it falls through and is not unrolled.
+    if (vecType && vecType.getRank() > 0) {
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() = getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::MultiDimReductionOp, vector::ReductionOp,
+            vector::TransposeOp, vector::GatherOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
 namespace {
 
 /// A pass to perform the SPIR-V conversion.
@@ -47,13 +107,94 @@ struct ConvertToSPIRVPass final
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
     SPIRVTypeConverter typeConverter(targetAttr);
 
+    // Unroll vectors in function signature to native vector size.
+    {
+      llvm::errs() << "Start unrolling function signature\n";
+      RewritePatternSet patterns(context);
+      // TODO: This is hardcoded to unroll with size 1. Change this later
+      SmallVector<int64_t> nativeShape(1, 1);
+      auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
+      populateVectorUnrollFuncSignaturePatterns(patterns, options);
+      GreedyRewriteConfig config;
+      config.strictMode = GreedyRewriteStrictness::ExistingOps;
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        return signalPassFailure();
+      llvm::errs() << "Finish unrolling function signature\n";
+    }
+
+    // Unroll vectors to native vector size.
+    {
+      RewritePatternSet patterns(context);
+      auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+          [=](auto op) { return getNativeVectorShape(op); });
+      populateVectorUnrollPatterns(patterns, options);
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Next run canonicalization to cast away leading size-1 dimensions.
+    {
+      RewritePatternSet patterns(context);
+
+      // We need to pull in casting way leading one dims to allow cancelling
+      // some read/write ops.
+      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+
+      // We may have vector.insert_strided_slice inserting 1-D native vectors
+      // into n-D larger vectors with the above. Break that down too. This is a
+      // companion transformation of unrolling.
+      vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+          patterns);
+      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+      // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
+      // them up.
+      vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+      vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
+      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Convert vector.extract_strided_slice into a chain of vector.extract and
+    // then a chain of vector.insert ops. This helps to cancel with previous
+    // vector.insert/extract ops, especially for fP16 cases where we have
+    // mismatched vector size for transfer and compute.
+    {
+      RewritePatternSet patterns(context);
+      vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+          patterns, [](vector::ExtractStridedSliceOp op) {
+            return op.getSourceVectorType().getNumElements() > 4;
+          });
+      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Run all sorts of canonicalization patterns to clean up again.
+    {
+      RewritePatternSet patterns(context);
+      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
+      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    // Populate patterns.
+    // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+    // populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
     populateFuncToSPIRVPatterns(typeConverter, patterns);
     index::populateIndexToSPIRVPatterns(typeConverter, patterns);
     populateVectorToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 723b2f62d65d4..1538c7eed6e76 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -43,6 +43,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRMemRefDialect
   MLIRMemRefUtils
   MLIRSCFDialect
+  MLIRSPIRVDialect
   MLIRSideEffectInterfaces
   MLIRSubsetOpInterface
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b3f558c3bac12..b63cb502b76e8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -11,12 +11,26 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/Block.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include <numeric>
 #include <optional>
@@ -65,6 +79,32 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                         resultTypes, op->getAttrs());
 }
 
+/// Returns the native shape `vecType` should be unrolled to according to
+/// `options`, or std::nullopt if no native shape is provided, `vecType` is not
+/// an integral multiple of it, or no unrolling is needed. `funcOp` is only
+/// forwarded to the `options.nativeShape` callback.
+static std::optional<SmallVector<int64_t>>
+getTargetShape(const vector::UnrollVectorOptions &options, func::FuncOp funcOp,
+               VectorType vecType) {
+  assert(options.nativeShape &&
+         "vector unrolling expects the native shape or native "
+         "shape call back function to be set");
+  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(funcOp);
+  if (!targetShape) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
+    return std::nullopt;
+  }
+  std::optional<SmallVector<int64_t>> maybeShapeRatio =
+      computeShapeRatio(vecType.getShape(), *targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+  return targetShape;
+}
+
 /// Return the target shape for unrolling for the given `op`. Return
 /// std::nullopt if the op shouldn't be or cannot be unrolled.
 static std::optional<SmallVector<int64_t>>
@@ -617,6 +657,160 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Rewrites a func.func so that each vector argument for which `getTargetShape`
+/// yields a native shape is replaced by one argument per native-shaped tile.
+/// At the top of the new body the original vector is reassembled from the
+/// tiles with vector.insert_strided_slice, and the original body (moved over
+/// unchanged, including nested regions and extra blocks) uses that value.
+struct UnrollFuncSignaturePattern : OpRewritePattern<func::FuncOp> {
+  UnrollFuncSignaturePattern(MLIRContext *context,
+                             const vector::UnrollVectorOptions &options,
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern<func::FuncOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    // Declarations have no body whose argument uses could be rewired.
+    if (funcOp.isExternal())
+      return rewriter.notifyMatchFailure(funcOp, "function has no body");
+
+    FunctionType fnType = funcOp.getFunctionType();
+
+    // Target shape per original input (std::nullopt = keep the input as is).
+    // An unrolled vector input becomes one new input per tile.
+    SmallVector<std::optional<SmallVector<int64_t>>> targetShapes;
+    SmallVector<Type> newInputTypes;
+    for (Type type : fnType.getInputs()) {
+      std::optional<SmallVector<int64_t>> targetShape;
+      auto vecType = dyn_cast<VectorType>(type);
+      if (vecType)
+        targetShape = getTargetShape(options, funcOp, vecType);
+      targetShapes.push_back(targetShape);
+      if (!targetShape) {
+        newInputTypes.push_back(type);
+        continue;
+      }
+      auto unrolledType =
+          VectorType::get(*targetShape, vecType.getElementType());
+      // getTargetShape already verified that the shape ratio exists.
+      int64_t numTiles =
+          computeProduct(*computeShapeRatio(vecType.getShape(), *targetShape));
+      newInputTypes.append(numTiles, unrolledType);
+    }
+    // Nothing to unroll: fail so the driver does not record a bogus change.
+    if (llvm::none_of(targetShapes,
+                      [](const auto &shape) { return shape.has_value(); }))
+      return rewriter.notifyMatchFailure(funcOp, "no vector input to unroll");
+
+    // Results are not unrolled; all of them (possibly none) are kept.
+    // TODO: Copy over attributes other than the function name and type.
+    auto newFuncOp = rewriter.create<func::FuncOp>(
+        funcOp.getLoc(), funcOp.getName(),
+        rewriter.getFunctionType(newInputTypes, fnType.getResults()));
+    Block *entryBlock = newFuncOp.addEntryBlock();
+    rewriter.setInsertionPointToStart(entryBlock);
+
+    // Build the value replacing each original argument: the matching new
+    // argument, or the original vector reassembled from its unrolled tiles.
+    Location loc = funcOp.getLoc();
+    SmallVector<Value> argReplacements;
+    unsigned newArgIndex = 0;
+    for (unsigned i = 0, e = fnType.getNumInputs(); i < e; ++i) {
+      const auto &targetShape = targetShapes[i];
+      if (!targetShape) {
+        argReplacements.push_back(newFuncOp.getArgument(newArgIndex++));
+        continue;
+      }
+      auto vecType = cast<VectorType>(fnType.getInput(i));
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, vecType, rewriter.getZeroAttr(vecType));
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(vecType.getShape(), *targetShape))
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, newFuncOp.getArgument(newArgIndex++), result, offsets,
+            strides);
+      argReplacements.push_back(result);
+    }
+
+    // Move the whole original body over, then splice the old entry block into
+    // the new one (after the reassembly ops), replacing its arguments. This
+    // keeps every SSA use, nested region and successor intact.
+    Block *oldEntryBlock = &funcOp.getBody().front();
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.getBody().end());
+    rewriter.mergeBlocks(oldEntryBlock, entryBlock, argReplacements);
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -628,3 +822,10 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollTransposePattern, UnrollGatherPattern>(
       patterns.getContext(), options, benefit);
 }
+
+/// Registers UnrollFuncSignaturePattern, configured with `options`, into
+/// `patterns` at the given `benefit`.
+void mlir::vector::populateVectorUnrollFuncSignaturePatterns(
+    RewritePatternSet &patterns, const UnrollVectorOptions &options,
+    PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<UnrollFuncSignaturePattern>(ctx, options, benefit);
+}
\ No newline at end of file

>From a5913ba5744fdd56c29284d6b1c2daecd6d7b73d Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Fri, 5 Jul 2024 15:45:34 +0000
Subject: [PATCH 02/14] Function input vector unrolling working and moved
 pattern to SPIRV

---
 .../SPIRV/Transforms/SPIRVConversion.h        |   4 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     |  34 ++--
 .../Dialect/SPIRV/Transforms/CMakeLists.txt   |   1 +
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 161 ++++++++++++++++++
 4 files changed, 188 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..1206603edcb6d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,8 +17,10 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallSet.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 namespace mlir {
 
@@ -134,6 +136,8 @@ class SPIRVConversionTarget : public ConversionTarget {
 void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns);
 
+/// Appends a pattern that rewrites func.func ops so that each vector-typed
+/// input that needs unrolling is split into several smaller vector inputs,
+/// with the original vector reassembled at the start of the function body.
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
 namespace spirv {
 class AccessChainOp;
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 54152c5be26fa..adb903d3f448c 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 #include <memory>
 
 #define DEBUG_TYPE "convert-to-spirv"
@@ -105,23 +106,23 @@ struct ConvertToSPIRVPass final
     Operation *op = getOperation();
 
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
-    SPIRVTypeConverter typeConverter(targetAttr);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
 
-    // Unroll vectors in function signature to native vector size.
+    // Unroll vectors in function inputs to native vector size.
     {
-      llvm::errs() << "Start unrolling function signature\n";
+      llvm::errs() << "Start unrolling function inputs\n";
       RewritePatternSet patterns(context);
-      // TODO: This is hardcoded to unroll with size 1. Change this later
-      SmallVector<int64_t> nativeShape(1, 1);
-      auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
-      populateVectorUnrollFuncSignaturePatterns(patterns, options);
+      populateFuncOpVectorRewritePatterns(patterns);
       GreedyRewriteConfig config;
       config.strictMode = GreedyRewriteStrictness::ExistingOps;
       if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
         return signalPassFailure();
-      llvm::errs() << "Finish unrolling function signature\n";
+      llvm::errs() << "Finish unrolling function inputs\n";
     }
 
+    SPIRVTypeConverter typeConverter(targetAttr);
+
     // Unroll vectors to native vector size.
     {
       RewritePatternSet patterns(context);
@@ -132,6 +133,9 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    llvm::errs() << "After unrolling vectors to native vector size\n";
+    op->dump();
+
     // Next run canonicalization to cast away leading size-1 dimensions.
     {
       RewritePatternSet patterns(context);
@@ -159,6 +163,9 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    llvm::errs() << "After running canonicalization to cast away leading size-1 dimensions\n";
+    op->dump();
+
     // Convert vector.extract_strided_slice into a chain of vector.extract and
     // then a chain of vector.insert ops. This helps to cancel with previous
     // vector.insert/extract ops, especially for fP16 cases where we have
@@ -175,6 +182,9 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    llvm::errs() << "After converting vector.extract_strided_slice into a chain of vector.extract and then a chain of vector.insert ops\n";
+    op->dump();
+
     // Run all sorts of canonicalization patterns to clean up again.
     {
       RewritePatternSet patterns(context);
@@ -188,22 +198,22 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    llvm::errs() << "After running canonicalization patterns to clean up again\n";
+    op->dump();
+
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
     // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-    // populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
     populateFuncToSPIRVPatterns(typeConverter, patterns);
     index::populateIndexToSPIRVPatterns(typeConverter, patterns);
     populateVectorToSPIRVPatterns(typeConverter, patterns);
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
 
-    std::unique_ptr<ConversionTarget> target =
-        SPIRVConversionTarget::get(targetAttr);
-
     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
       return signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..11af020b6c188 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
   MLIRFuncDialect
   MLIRSPIRVDialect
   MLIRTransformUtils
+  MLIRVectorTransforms
 )
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..616eb6104b705 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,8 +17,15 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
@@ -813,6 +820,160 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites a func.func signature so that each vector-typed input that needs
+/// unrolling is split into several inputs of a smaller vector type.
+struct FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Returns the shape `vecType` should be unrolled to, or std::nullopt if it is
+/// not an integral multiple of that shape or needs no unrolling.
+/// TODO: The native shape is hardcoded to [1]; derive it from the target.
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  SmallVector<int64_t> targetShape = {1};
+  std::optional<SmallVector<int64_t>> maybeShapeRatio =
+      computeShapeRatio(vecType.getShape(), targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+  return targetShape;
+}
+
+/// Replaces `funcOp` with a function whose unrollable vector inputs are split
+/// into one input per tile (see getTargetShape). The original vector is
+/// reassembled with vector.insert_strided_slice at the start of the body, and
+/// all other inputs are forwarded to their new counterparts.
+LogicalResult
+FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
+                                             PatternRewriter &rewriter) const {
+  // External declarations have no body whose argument uses could be rewired.
+  if (funcOp.isExternal())
+    return rewriter.notifyMatchFailure(funcOp, "function has no body");
+
+  FunctionType fnType = funcOp.getFunctionType();
+
+  // Decide per original input whether it is unrolled and to which shape.
+  // Unrolled vector inputs become one new input per tile; every other input
+  // is carried over unchanged.
+  SmallVector<std::optional<SmallVector<int64_t>>> targetShapes;
+  SmallVector<Type> newInputTypes;
+  for (Type origType : fnType.getInputs()) {
+    std::optional<SmallVector<int64_t>> targetShape;
+    auto origVecType = dyn_cast<VectorType>(origType);
+    if (origVecType)
+      targetShape = getTargetShape(origVecType);
+    targetShapes.push_back(targetShape);
+    if (!targetShape) {
+      newInputTypes.push_back(origType);
+      continue;
+    }
+    auto unrolledType =
+        VectorType::get(*targetShape, origVecType.getElementType());
+    // getTargetShape already verified that the shape ratio exists.
+    int64_t numTiles = computeProduct(
+        *computeShapeRatio(origVecType.getShape(), *targetShape));
+    newInputTypes.append(numTiles, unrolledType);
+  }
+  // Nothing to unroll: fail so the driver does not record a bogus change.
+  if (llvm::none_of(targetShapes,
+                    [](const auto &shape) { return shape.has_value(); }))
+    return rewriter.notifyMatchFailure(funcOp, "no vector input to unroll");
+
+  // Results are kept as they are (all of them, including none).
+  // TODO: Copy over attributes other than the function name and type.
+  auto newFuncOp = rewriter.create<func::FuncOp>(
+      funcOp.getLoc(), funcOp.getName(),
+      rewriter.getFunctionType(newInputTypes, fnType.getResults()));
+  Block *entryBlock = newFuncOp.addEntryBlock();
+  rewriter.setInsertionPointToStart(entryBlock);
+
+  // Build the value replacing each original argument: the matching new
+  // argument, or the original vector reassembled from its unrolled tiles.
+  Location loc = funcOp.getLoc();
+  SmallVector<Value> argReplacements;
+  unsigned newInputNo = 0;
+  for (unsigned origInputNo = 0, e = fnType.getNumInputs(); origInputNo < e;
+       ++origInputNo) {
+    const auto &targetShape = targetShapes[origInputNo];
+    if (!targetShape) {
+      argReplacements.push_back(newFuncOp.getArgument(newInputNo++));
+      continue;
+    }
+    auto origVecType = cast<VectorType>(fnType.getInput(origInputNo));
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, origVecType, rewriter.getZeroAttr(origVecType));
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(origVecType.getShape(), *targetShape))
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, newFuncOp.getArgument(newInputNo++), result, offsets, strides);
+    argReplacements.push_back(result);
+  }
+
+  // Move the whole original body over, then splice its entry block into the
+  // new entry block (after the reassembly ops), replacing the old arguments
+  // through the rewriter so every use is updated and tracked.
+  Block *oldEntryBlock = &funcOp.getBody().front();
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.getBody().end());
+  rewriter.mergeBlocks(oldEntryBlock, entryBlock, argReplacements);
+  rewriter.eraseOp(funcOp);
+  return success();
+}
+
+/// Adds the func.func vector-input unrolling pattern to `patterns`.
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FuncOpVectorTypesConversion>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//

>From 63744d1555a745d21a3691d9fa6e0319d4b2374b Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 8 Jul 2024 15:29:52 +0000
Subject: [PATCH 03/14] Implement function result and ReturnOp vector unrolling

---
 .../SPIRV/Transforms/SPIRVConversion.h        |   2 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     |  12 ++
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 132 +++++++++++++++---
 3 files changed, 128 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 1206603edcb6d..112c404527927 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -138,6 +138,8 @@ void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
 void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
 
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
 namespace spirv {
 class AccessChainOp;
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index adb903d3f448c..cd1344569e503 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -121,6 +121,18 @@ struct ConvertToSPIRVPass final
       llvm::errs() << "Finish unrolling function inputs\n";
     }
 
+    // Unroll vectors in function outputs to native vector size.
+    {
+      llvm::errs() << "Start unrolling function outputs\n";
+      RewritePatternSet patterns(context);
+      populateReturnOpVectorRewritePatterns(patterns);
+      GreedyRewriteConfig config;
+      config.strictMode = GreedyRewriteStrictness::ExistingOps;
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        return signalPassFailure();
+      llvm::errs() << "Finish unrolling function inputs\n";
+    }
+
     SPIRVTypeConverter typeConverter(targetAttr);
 
     // Unroll vectors to native vector size.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 616eb6104b705..37a8071cbf9b6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -22,8 +22,10 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
@@ -863,17 +865,18 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
                                              PatternRewriter &rewriter) const {
   auto fnType = funcOp.getFunctionType();
 
+  // First create a new func op and copy the function body.
   auto newFuncOp =
       rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
-
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
-
+  llvm::errs() << "After creating new func op and copying the function body\n";
   newFuncOp.dump();
 
-  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
   Location loc = newFuncOp.getBody().getLoc();
-  rewriter.setInsertionPointToStart(&newFuncOp.getBody().getBlocks().front());
+  Block &entryBlock = newFuncOp.getBlocks().front();
+  rewriter.setInsertionPointToStart(&entryBlock);
+  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
   SmallVector<size_t> unrolledInputNums;
   size_t newInputNo = 0;
 
@@ -928,21 +931,17 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
   }
 
   llvm::errs() << "After enumerating through the arguments\n";
-  newFuncOp->dump();
-
-  // Assume there is a single result for now.
-  Type originalResultType = fnType.getResult(0);
+  newFuncOp.dump();
 
-  // Change function signature
+  // Change function signature.
   auto newFnType = FunctionType::get(
       rewriter.getContext(), TypeRange(oneToNTypeMapping.getConvertedTypes()),
-      TypeRange(originalResultType));
+      TypeRange(fnType.getResults()));
   rewriter.modifyOpInPlace(newFuncOp,
                            [&] { newFuncOp.setFunctionType(newFnType); });
-  llvm::errs() << "After changing function signature\n";
-  newFuncOp->dump();
 
-  Block &entryBlock = newFuncOp.getBlocks().front();
+  llvm::errs() << "After changing function signature\n";
+  newFuncOp.dump();
 
   // Update the arguments in the entry block.
   entryBlock.eraseArguments(0, fnType.getNumInputs());
@@ -950,18 +949,19 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
                              newFuncOp.getLoc());
   entryBlock.addArguments(oneToNTypeMapping.getConvertedTypes(), locs);
 
-  llvm::errs() << "After modifying the entry block\n";
-  newFuncOp->dump();
+  llvm::errs() << "After updating the arguments in the entry block\n";
+  newFuncOp.dump();
 
-  size_t i = 0;
   // Relace the dummy values with actual arguments.
+  size_t i = 0;
   for (auto &op : entryBlock.getOperations()) {
     op.dump();
     auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
     if (vecOp) {
       size_t unrolledInputNo = unrolledInputNums[i];
-      rewriter.modifyOpInPlace(
-          &op, [&] { op.setOperand(0, newFuncOp.getArgument(unrolledInputNo)); });
+      rewriter.modifyOpInPlace(&op, [&] {
+        op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+      });
       i++;
     }
   }
@@ -974,6 +974,102 @@ void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
   patterns.add<FuncOpVectorTypesConversion>(patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+class ReturnOpVectorTypesConversion : public OpRewritePattern<func::ReturnOp> {
+public:
+  using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
+    func::ReturnOp returnOp, PatternRewriter &rewriter) const {
+
+  func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+  if (!funcOp)
+    return failure();
+
+  auto fnType = funcOp.getFunctionType();
+  OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+  Location loc = returnOp.getLoc();
+  SmallVector<Value> newOperands;
+
+  // Enumerate through the results.
+  for (const auto &argType : enumerate(fnType.getResults())) {
+    size_t origResultNo = argType.index();
+    Type origType = argType.value();
+    auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    if (!origVecType) {
+      oneToNTypeMapping.addInputs(origResultNo, origType);
+      newOperands.push_back(returnOp.getOperand(origResultNo));
+      continue;
+    }
+    llvm::errs() << "Try vector unrolling\n";
+    SmallVector<int64_t> nativeShape(1, 1);
+    auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
+    auto targetShape = getTargetShape(origVecType);
+    if (!targetShape) {
+      llvm::errs() << "No target shape\n";
+      oneToNTypeMapping.addInputs(origResultNo, origType);
+      newOperands.push_back(returnOp.getOperand(origResultNo));
+      continue;
+    }
+    llvm::errs() << "Got target shape\n";
+    VectorType unrolledType =
+        VectorType::get(*targetShape, origVecType.getElementType());
+    llvm::errs() << "Unrolled type is ";
+    unrolledType.dump();
+    SmallVector<int64_t> originalShape =
+        llvm::to_vector<4>(origVecType.getShape());
+    SmallVector<Type> newTypes;
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    Value returnValue = returnOp.getOperand(0);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      auto result = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, returnValue, offsets, *targetShape, strides);
+      result.dump();
+      newOperands.push_back(result);
+      newTypes.push_back(unrolledType);
+    }
+    oneToNTypeMapping.addInputs(origResultNo, newTypes);
+  }
+
+  llvm::errs() << "After enumerating through the arguments\n";
+  funcOp.dump();
+
+  for (auto operand : newOperands)
+    operand.dump();
+
+  // Change function signature.
+  auto newFnType =
+      FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+                        TypeRange(oneToNTypeMapping.getConvertedTypes()));
+  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setFunctionType(newFnType); });
+  llvm::errs() << "After changing function signature\n";
+  funcOp.dump();
+
+  // Replace the return op using the new operands.
+  rewriter.replaceOp(returnOp,
+                     rewriter.create<func::ReturnOp>(loc, newOperands));
+  llvm::errs() << "After replacing return op\n";
+  funcOp.dump();
+
+  return success();
+}
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<ReturnOpVectorTypesConversion>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//

>From ba04f4f8b60cbbe9cfb8769aa877f598467de5a0 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 8 Jul 2024 18:44:29 +0000
Subject: [PATCH 04/14] Compute the target shape based on original vector shape

---
 .../SPIRV/Transforms/SPIRVConversion.cpp       | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 37a8071cbf9b6..2af947e94a475 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -838,11 +838,20 @@ class FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
 };
 } // namespace
 
+static int getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
   llvm::errs() << "Get target shape\n";
   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
   // TODO: This is hardcoded to unroll with size 1. Change this later
-  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(1, 1);
+  std::optional<SmallVector<int64_t>> targetShape =
+      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
   if (!targetShape) {
     llvm::errs() << "--no unrolling target shape defined\n";
     return std::nullopt;
@@ -870,12 +879,14 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
       rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
+  rewriter.eraseOp(funcOp);
   llvm::errs() << "After creating new func op and copying the function body\n";
   newFuncOp.dump();
 
   Location loc = newFuncOp.getBody().getLoc();
   Block &entryBlock = newFuncOp.getBlocks().front();
   rewriter.setInsertionPointToStart(&entryBlock);
+
   OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
   SmallVector<size_t> unrolledInputNums;
   size_t newInputNo = 0;
@@ -891,8 +902,6 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
       continue;
     }
     llvm::errs() << "Try vector unrolling\n";
-    SmallVector<int64_t> nativeShape(1, 1);
-    auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
     auto targetShape = getTargetShape(origVecType);
     if (!targetShape) {
       llvm::errs() << "No target shape\n";
@@ -966,7 +975,6 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
     }
   }
 
-  rewriter.eraseOp(funcOp);
   return success();
 }
 
@@ -1013,8 +1021,6 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
       continue;
     }
     llvm::errs() << "Try vector unrolling\n";
-    SmallVector<int64_t> nativeShape(1, 1);
-    auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
     auto targetShape = getTargetShape(origVecType);
     if (!targetShape) {
       llvm::errs() << "No target shape\n";

>From 6e99d24e626aa8282340df52fc9e54605e08ca4d Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 8 Jul 2024 20:58:14 +0000
Subject: [PATCH 05/14] Fix bug in function output unrolling

---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2af947e94a475..99de9b8195d35 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1037,7 +1037,7 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
         llvm::to_vector<4>(origVecType.getShape());
     SmallVector<Type> newTypes;
     SmallVector<int64_t> strides(targetShape->size(), 1);
-    Value returnValue = returnOp.getOperand(0);
+    Value returnValue = returnOp.getOperand(origResultNo);
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalShape, *targetShape)) {
       auto result = rewriter.create<vector::ExtractStridedSliceOp>(

>From 49b5a4b399f8e60eaafc631f4fda52f1b0469182 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Tue, 9 Jul 2024 15:19:07 +0000
Subject: [PATCH 06/14] Working for signatures with legal and illegal types

---
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 21 ++++++----
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 42 ++++++++++++++++++-
 2 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index cd1344569e503..1c11076c4b5b9 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,7 +39,7 @@ using namespace mlir;
 // Vector Lowering
 //===----------------------------------------------------------------------===//
 
-int getComputeVectorSize(int64_t size) {
+static int getComputeVectorSize(int64_t size) {
   for (int i : {4, 3, 2}) {
     if (size % i == 0)
       return i;
@@ -110,28 +110,29 @@ struct ConvertToSPIRVPass final
         SPIRVConversionTarget::get(targetAttr);
 
     // Unroll vectors in function inputs to native vector size.
+    llvm::errs() << "Start unrolling function inputs\n";
     {
-      llvm::errs() << "Start unrolling function inputs\n";
       RewritePatternSet patterns(context);
       populateFuncOpVectorRewritePatterns(patterns);
       GreedyRewriteConfig config;
       config.strictMode = GreedyRewriteStrictness::ExistingOps;
       if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
         return signalPassFailure();
-      llvm::errs() << "Finish unrolling function inputs\n";
     }
+    llvm::errs() << "Finish unrolling function inputs\n";
+    op->dump();
 
     // Unroll vectors in function outputs to native vector size.
+    llvm::errs() << "Start unrolling function outputs\n";
     {
-      llvm::errs() << "Start unrolling function outputs\n";
       RewritePatternSet patterns(context);
       populateReturnOpVectorRewritePatterns(patterns);
       GreedyRewriteConfig config;
       config.strictMode = GreedyRewriteStrictness::ExistingOps;
       if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
         return signalPassFailure();
-      llvm::errs() << "Finish unrolling function inputs\n";
     }
+    llvm::errs() << "Finish unrolling function outputs\n";
 
     SPIRVTypeConverter typeConverter(targetAttr);
 
@@ -175,7 +176,8 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
-    llvm::errs() << "After running canonicalization to cast away leading size-1 dimensions\n";
+    llvm::errs() << "After running canonicalization to cast away leading "
+                    "size-1 dimensions\n";
     op->dump();
 
     // Convert vector.extract_strided_slice into a chain of vector.extract and
@@ -194,7 +196,9 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
-    llvm::errs() << "After converting vector.extract_strided_slice into a chain of vector.extract and then a chain of vector.insert ops\n";
+    llvm::errs()
+        << "After converting vector.extract_strided_slice into a chain of "
+           "vector.extract and then a chain of vector.insert ops\n";
     op->dump();
 
     // Run all sorts of canonicalization patterns to clean up again.
@@ -210,7 +214,8 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
-    llvm::errs() << "After running canonicalization patterns to clean up again\n";
+    llvm::errs()
+        << "After running canonicalization patterns to clean up again\n";
     op->dump();
 
     RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 99de9b8195d35..304e1f7756580 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
@@ -21,19 +22,24 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 
+#include <cctype>
 #include <functional>
 #include <optional>
+#include <unordered_set>
 
 #define DEBUG_TYPE "mlir-spirv-conversion"
 
@@ -891,22 +897,35 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
   SmallVector<size_t> unrolledInputNums;
   size_t newInputNo = 0;
 
+  std::unordered_map<Operation *, size_t> tmpOps;
+  size_t newOpCount = 0;
+
   // Enumerate through the arguments.
   for (const auto &argType : enumerate(fnType.getInputs())) {
     size_t origInputNo = argType.index();
     Type origType = argType.value();
     auto origVecType = llvm::dyn_cast<VectorType>(origType);
     if (!origVecType) {
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origType, rewriter.getZeroAttr(origType));
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      tmpOps.insert({result.getDefiningOp(), newInputNo});
       oneToNTypeMapping.addInputs(origInputNo, origType);
       newInputNo++;
+      newOpCount++;
       continue;
     }
     llvm::errs() << "Try vector unrolling\n";
     auto targetShape = getTargetShape(origVecType);
     if (!targetShape) {
       llvm::errs() << "No target shape\n";
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origType, rewriter.getZeroAttr(origType));
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      tmpOps.insert({result.getDefiningOp(), newInputNo});
       oneToNTypeMapping.addInputs(origInputNo, origType);
       newInputNo++;
+      newOpCount++;
       continue;
     }
     llvm::errs() << "Got target shape\n";
@@ -921,10 +940,12 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
     Value result = rewriter.create<arith::ConstantOp>(
         loc, origVecType, rewriter.getZeroAttr(origVecType));
     result.dump();
+    newOpCount++;
     // Prepare the placeholder
     Value dummy = rewriter.create<arith::ConstantOp>(
         loc, unrolledType, rewriter.getZeroAttr(unrolledType));
     dummy.dump();
+    newOpCount++;
     SmallVector<int64_t> strides(targetShape->size(), 1);
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalShape, *targetShape)) {
@@ -934,6 +955,7 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
       newTypes.push_back(unrolledType);
       unrolledInputNums.push_back(newInputNo);
       newInputNo++;
+      newOpCount++;
     }
     rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
     oneToNTypeMapping.addInputs(origInputNo, newTypes);
@@ -961,10 +983,25 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
   llvm::errs() << "After updating the arguments in the entry block\n";
   newFuncOp.dump();
 
-  // Relace the dummy values with actual arguments.
+  // Replace the dummy values with actual arguments.
   size_t i = 0;
-  for (auto &op : entryBlock.getOperations()) {
+  for (auto pair : llvm::enumerate(entryBlock.getOperations())) {
+    size_t count = pair.index();
+    Operation &op = pair.value();
     op.dump();
+    for (auto pair : llvm::enumerate(op.getOperands())) {
+      Operation *operandOp = pair.value().getDefiningOp();
+      if (tmpOps.find(operandOp) != tmpOps.end()) {
+        rewriter.modifyOpInPlace(&op, [&] {
+          op.setOperand(pair.index(), newFuncOp.getArgument(tmpOps[operandOp]));
+        });
+        rewriter.eraseOp(operandOp);
+        count++;
+        continue;
+      }
+    }
+    if (count == newOpCount)
+      continue;
     auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
     if (vecOp) {
       size_t unrolledInputNo = unrolledInputNums[i];
@@ -973,6 +1010,7 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
       });
       i++;
     }
+    count++;
   }
 
   return success();

>From b0fc3ab9439bc5b6092b1470d4642801c626dfaa Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Tue, 9 Jul 2024 20:13:56 +0000
Subject: [PATCH 07/14] Only keep the signature conversion, and refactor code

---
 .../Vector/Transforms/VectorRewritePatterns.h |   4 -
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 152 +------------
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 196 +++++++++--------
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 -
 .../Vector/Transforms/VectorUnroll.cpp        | 201 ------------------
 5 files changed, 108 insertions(+), 446 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 5c06d6d4d6ad3..8e6d36f0b5f09 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -293,10 +293,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
-void populateVectorUnrollFuncSignaturePatterns(RewritePatternSet &patterns,
-                                               const UnrollVectorOptions &options,
-                                               PatternBenefit benefit = 1);
-
 /// Collect a set of vector.shape_cast folding patterns.
 void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                       PatternBenefit benefit = 1);
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 1c11076c4b5b9..ddfbb0a76ad11 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -23,7 +23,6 @@
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
 #include <memory>
 
 #define DEBUG_TYPE "convert-to-spirv"
@@ -35,66 +34,6 @@ namespace mlir {
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// Vector Lowering
-//===----------------------------------------------------------------------===//
-
-static int getComputeVectorSize(int64_t size) {
-  for (int i : {4, 3, 2}) {
-    if (size % i == 0)
-      return i;
-  }
-  return 1;
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::MultiDimReductionOp op) {
-  // Unroll all reduction dimensions by size 1 for vector.multi_reduction.
-  VectorType srcVectorType = op.getSourceVectorType();
-  auto nativeSize = llvm::to_vector(srcVectorType.getShape());
-  auto dims = op.getReductionDims().getAsValueRange<IntegerAttr>();
-  for (const auto &dimAttr : dims) {
-    nativeSize[dimAttr.getZExtValue()] = 1;
-  }
-  return nativeSize;
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op) {
-  VectorType srcVectorType = op.getSourceVectorType();
-  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
-  int64_t vectorSize = getComputeVectorSize(srcVectorType.getDimSize(0));
-  return {vectorSize};
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op) {
-  VectorType vectorType = op.getResultVectorType();
-  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
-  nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
-  return nativeSize;
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::GatherOp op) {
-  VectorType vectorType = op.getVectorType();
-  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
-  nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
-  return nativeSize;
-}
-
-std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op) {
-  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
-    if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
-      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
-      nativeSize.back() = getComputeVectorSize(vecType.getShape().back());
-      return nativeSize;
-    }
-  }
-
-  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
-      .Case<vector::MultiDimReductionOp, vector::ReductionOp,
-            vector::TransposeOp, vector::GatherOp>(
-          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
-      .Default([](Operation *) { return std::nullopt; });
-}
-
 namespace {
 
 /// A pass to perform the SPIR-V conversion.
@@ -105,10 +44,6 @@ struct ConvertToSPIRVPass final
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
-    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
-    std::unique_ptr<ConversionTarget> target =
-        SPIRVConversionTarget::get(targetAttr);
-
     // Unroll vectors in function inputs to native vector size.
     llvm::errs() << "Start unrolling function inputs\n";
     {
@@ -120,7 +55,6 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
     llvm::errs() << "Finish unrolling function inputs\n";
-    op->dump();
 
     // Unroll vectors in function outputs to native vector size.
     llvm::errs() << "Start unrolling function outputs\n";
@@ -134,90 +68,10 @@ struct ConvertToSPIRVPass final
     }
     llvm::errs() << "Finish unrolling function outputs\n";
 
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
     SPIRVTypeConverter typeConverter(targetAttr);
-
-    // Unroll vectors to native vector size.
-    {
-      RewritePatternSet patterns(context);
-      auto options = vector::UnrollVectorOptions().setNativeShapeFn(
-          [=](auto op) { return getNativeVectorShape(op); });
-      populateVectorUnrollPatterns(patterns, options);
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
-
-    llvm::errs() << "After unrolling vectors to native vector size\n";
-    op->dump();
-
-    // Next run canonicalization to cast away leading size-1 dimensions.
-    {
-      RewritePatternSet patterns(context);
-
-      // We need to pull in casting way leading one dims to allow cancelling
-      // some read/write ops.
-      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-
-      // We may have vector.insert_strided_slice inserting 1-D native vectors
-      // into n-D larger vectors with the above. Break that down too. This is a
-      // companion transformation of unrolling.
-      vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
-          patterns);
-      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-
-      // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
-      // them up.
-      vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
-      vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
-
-      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
-      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
-
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
-
-    llvm::errs() << "After running canonicalization to cast away leading "
-                    "size-1 dimensions\n";
-    op->dump();
-
-    // Convert vector.extract_strided_slice into a chain of vector.extract and
-    // then a chain of vector.insert ops. This helps to cancel with previous
-    // vector.insert/extract ops, especially for fP16 cases where we have
-    // mismatched vector size for transfer and compute.
-    {
-      RewritePatternSet patterns(context);
-      vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
-          patterns, [](vector::ExtractStridedSliceOp op) {
-            return op.getSourceVectorType().getNumElements() > 4;
-          });
-      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
-      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
-
-    llvm::errs()
-        << "After converting vector.extract_strided_slice into a chain of "
-           "vector.extract and then a chain of vector.insert ops\n";
-    op->dump();
-
-    // Run all sorts of canonicalization patterns to clean up again.
-    {
-      RewritePatternSet patterns(context);
-      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
-      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
-      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
-      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
-
-    llvm::errs()
-        << "After running canonicalization patterns to clean up again\n";
-    op->dump();
-
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 304e1f7756580..6e793573f0262 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -27,19 +27,14 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
-#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 
 #include <cctype>
-#include <functional>
 #include <optional>
-#include <unordered_set>
 
 #define DEBUG_TYPE "mlir-spirv-conversion"
 
@@ -49,6 +44,36 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Picks the largest of 4, 3 and 2 that evenly divides `size`, else 1.
+static int getComputeVectorSize(int64_t size) {
+  for (int width : {4, 3, 2})
+    if (size % width == 0)
+      return width;
+  return 1;
+}
+
+/// Returns the 1-D unroll shape for `vecType`, or std::nullopt if none applies.
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
+  if (vecType.getRank() == 0 || vecType.isScalable())
+    return std::nullopt;
+  SmallVector<int64_t> targetShape =
+      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  auto maybeShapeRatio = computeShapeRatio(vecType.getShape(), targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+  return targetShape;
+}
+
 /// Checks that `candidates` extension requirements are possible to be satisfied
 /// with the given `targetEnv`.
 ///
@@ -835,7 +860,7 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 namespace {
 /// A pattern for rewriting function signature to convert vector arguments of
 /// functions to be of valid types
-class FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
 public:
   using OpRewritePattern<func::FuncOp>::OpRewritePattern;
 
@@ -844,48 +869,17 @@ class FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
 };
 } // namespace
 
-static int getComputeVectorSize(int64_t size) {
-  for (int i : {4, 3, 2}) {
-    if (size % i == 0)
-      return i;
-  }
-  return 1;
-}
-
-static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
-  llvm::errs() << "Get target shape\n";
-  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
-  // TODO: This is hardcoded to unroll with size 1. Change this later
-  std::optional<SmallVector<int64_t>> targetShape =
-      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
-  if (!targetShape) {
-    llvm::errs() << "--no unrolling target shape defined\n";
-    return std::nullopt;
-  }
-  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
-  if (!maybeShapeRatio) {
-    llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
-    return std::nullopt;
-  }
-  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
-    llvm::errs() << "--no unrolling needed -> SKIP\n";
-    return std::nullopt;
-  }
-  llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
-  return targetShape;
-}
-
 LogicalResult
-FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
-                                             PatternRewriter &rewriter) const {
+FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
+                                    PatternRewriter &rewriter) const {
   auto fnType = funcOp.getFunctionType();
 
-  // First create a new func op and copy the function body.
+  // Create a new func op with the original type and copy the function body.
   auto newFuncOp =
       rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
-  rewriter.eraseOp(funcOp);
+
   llvm::errs() << "After creating new func op and copying the function body\n";
   newFuncOp.dump();
 
@@ -894,18 +888,30 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
   rewriter.setInsertionPointToStart(&entryBlock);
 
   OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+  // For arguments that are of illegal types and require unrolling.
+  // `unrolledInputNums` stores the indices of arguments that result from
+  // unrolling in the new function signature. `newInputNo` is a counter.
   SmallVector<size_t> unrolledInputNums;
   size_t newInputNo = 0;
 
-  std::unordered_map<Operation *, size_t> tmpOps;
+  // For arguments that are of legal types and do not require unrolling.
+  // `tmpOps` maps each temporary placeholder operation to the index of the
+  // new argument it stands in for. Their uses are redirected to the new
+  // arguments after the entry block is updated; they are not erased here.
+  DenseMap<Operation *, size_t> tmpOps;
+
+  // This counts the number of new operations created.
   size_t newOpCount = 0;
 
   // Enumerate through the arguments.
   for (const auto &argType : enumerate(fnType.getInputs())) {
     size_t origInputNo = argType.index();
     Type origType = argType.value();
+    // Check whether the argument is of vector type.
     auto origVecType = llvm::dyn_cast<VectorType>(origType);
     if (!origVecType) {
+      // We need a placeholder for the old argument that will be erased later.
       Value result = rewriter.create<arith::ConstantOp>(
           loc, origType, rewriter.getZeroAttr(origType));
       rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
@@ -915,10 +921,10 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
       newOpCount++;
       continue;
     }
-    llvm::errs() << "Try vector unrolling\n";
+    // Check whether the vector needs unrolling.
     auto targetShape = getTargetShape(origVecType);
     if (!targetShape) {
-      llvm::errs() << "No target shape\n";
+      // We need a placeholder for the old argument that will be erased later.
       Value result = rewriter.create<arith::ConstantOp>(
           loc, origType, rewriter.getZeroAttr(origType));
       rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
@@ -935,23 +941,23 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
     unrolledType.dump();
     SmallVector<int64_t> originalShape =
         llvm::to_vector<4>(origVecType.getShape());
-    SmallVector<Type> newTypes;
-    // Prepare the result vector
+
+    // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, origVecType, rewriter.getZeroAttr(origVecType));
-    result.dump();
     newOpCount++;
-    // Prepare the placeholder
+    // Prepare the placeholder for the new arguments that will be added later.
     Value dummy = rewriter.create<arith::ConstantOp>(
         loc, unrolledType, rewriter.getZeroAttr(unrolledType));
-    dummy.dump();
     newOpCount++;
+
+    // Create the `vector.insert_strided_slice` ops.
     SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<Type> newTypes;
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalShape, *targetShape)) {
       result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
                                                              offsets, strides);
-      result.dump();
       newTypes.push_back(unrolledType);
       unrolledInputNums.push_back(newInputNo);
       newInputNo++;
@@ -964,10 +970,11 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
   llvm::errs() << "After enumerating through the arguments\n";
   newFuncOp.dump();
 
-  // Change function signature.
-  auto newFnType = FunctionType::get(
-      rewriter.getContext(), TypeRange(oneToNTypeMapping.getConvertedTypes()),
-      TypeRange(fnType.getResults()));
+  // Change the function signature.
+  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+  auto newFnType =
+      FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
+                        TypeRange(fnType.getResults()));
   rewriter.modifyOpInPlace(newFuncOp,
                            [&] { newFuncOp.setFunctionType(newFnType); });
 
@@ -976,48 +983,52 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
 
   // Update the arguments in the entry block.
   entryBlock.eraseArguments(0, fnType.getNumInputs());
-  SmallVector<Location> locs(oneToNTypeMapping.getConvertedTypes().size(),
-                             newFuncOp.getLoc());
-  entryBlock.addArguments(oneToNTypeMapping.getConvertedTypes(), locs);
+  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+  entryBlock.addArguments(convertedTypes, locs);
 
   llvm::errs() << "After updating the arguments in the entry block\n";
   newFuncOp.dump();
 
-  // Replace the dummy values with actual arguments.
-  size_t i = 0;
-  for (auto pair : llvm::enumerate(entryBlock.getOperations())) {
-    size_t count = pair.index();
-    Operation &op = pair.value();
-    op.dump();
-    for (auto pair : llvm::enumerate(op.getOperands())) {
-      Operation *operandOp = pair.value().getDefiningOp();
-      if (tmpOps.find(operandOp) != tmpOps.end()) {
+  // Replace the placeholder values with the new arguments. We assume there is
+  // only one block for now.
+  size_t idx = 0;
+  for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
+    size_t count = opPair.index();
+    Operation &op = opPair.value();
+    // We first look for operands that are placeholders for initially legal
+    // arguments.
+    for (auto operandPair : llvm::enumerate(op.getOperands())) {
+      Operation *operandOp = operandPair.value().getDefiningOp();
+      if (tmpOps.find(operandOp) != tmpOps.end())
         rewriter.modifyOpInPlace(&op, [&] {
-          op.setOperand(pair.index(), newFuncOp.getArgument(tmpOps[operandOp]));
+          op.setOperand(operandPair.index(),
+                        newFuncOp.getArgument(tmpOps[operandOp]));
         });
-        rewriter.eraseOp(operandOp);
-        count++;
-        continue;
-      }
     }
-    if (count == newOpCount)
+    // All newly created operations sit at the start of the block, so any
+    // `vector.insert_strided_slice` found after them belongs to the original
+    // body and must not be touched.
+    if (count >= newOpCount)
       continue;
     auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
     if (vecOp) {
-      size_t unrolledInputNo = unrolledInputNums[i];
+      size_t unrolledInputNo = unrolledInputNums[idx];
       rewriter.modifyOpInPlace(&op, [&] {
         op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
       });
-      i++;
+      idx++;
     }
     count++;
   }
 
+  // Erase the original funcOp. The `tmpOps` do not need to be erased since
+  // they have no uses and will be handled by dead-code elimination.
+  rewriter.eraseOp(funcOp);
   return success();
 }
 
 void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
-  patterns.add<FuncOpVectorTypesConversion>(patterns.getContext());
+  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1027,7 +1038,7 @@ void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
 namespace {
 /// A pattern for rewriting function signature and the return op to convert
 /// vectors to be of valid types.
-class ReturnOpVectorTypesConversion : public OpRewritePattern<func::ReturnOp> {
+class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
 public:
   using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
 
@@ -1036,9 +1047,11 @@ class ReturnOpVectorTypesConversion : public OpRewritePattern<func::ReturnOp> {
 };
 } // namespace
 
-LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
-    func::ReturnOp returnOp, PatternRewriter &rewriter) const {
+LogicalResult
+ReturnOpVectorUnroll::matchAndRewrite(func::ReturnOp returnOp,
+                                      PatternRewriter &rewriter) const {
 
+  // Bail out unless the parent op is a `func.func`.
   func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
   if (!funcOp)
     return failure();
@@ -1046,6 +1059,8 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
   auto fnType = funcOp.getFunctionType();
   OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
   Location loc = returnOp.getLoc();
+
+  // Operands for the new return op.
   SmallVector<Value> newOperands;
 
   // Enumerate through the results.
@@ -1053,15 +1068,16 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
     size_t origResultNo = argType.index();
     Type origType = argType.value();
     auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    // Check whether the result is of vector type.
     if (!origVecType) {
       oneToNTypeMapping.addInputs(origResultNo, origType);
       newOperands.push_back(returnOp.getOperand(origResultNo));
       continue;
     }
-    llvm::errs() << "Try vector unrolling\n";
+    // Check whether the vector needs unrolling.
     auto targetShape = getTargetShape(origVecType);
     if (!targetShape) {
-      llvm::errs() << "No target shape\n";
+      // The original operand can be returned unchanged.
       oneToNTypeMapping.addInputs(origResultNo, origType);
       newOperands.push_back(returnOp.getOperand(origResultNo));
       continue;
@@ -1071,16 +1087,18 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
         VectorType::get(*targetShape, origVecType.getElementType());
     llvm::errs() << "Unrolled type is ";
     unrolledType.dump();
+
+    // Create `vector.extract_strided_slice` ops to form legal vectors from the
+    // original operand of illegal type.
     SmallVector<int64_t> originalShape =
         llvm::to_vector<4>(origVecType.getShape());
-    SmallVector<Type> newTypes;
     SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<Type> newTypes;
     Value returnValue = returnOp.getOperand(origResultNo);
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalShape, *targetShape)) {
-      auto result = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value result = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, returnValue, offsets, *targetShape, strides);
-      result.dump();
       newOperands.push_back(result);
       newTypes.push_back(unrolledType);
     }
@@ -1090,10 +1108,7 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
   llvm::errs() << "After enumerating through the arguments\n";
   funcOp.dump();
 
-  for (auto operand : newOperands)
-    operand.dump();
-
-  // Change function signature.
+  // Change the function signature.
   auto newFnType =
       FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
                         TypeRange(oneToNTypeMapping.getConvertedTypes()));
@@ -1101,17 +1116,16 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
   llvm::errs() << "After changing function signature\n";
   funcOp.dump();
 
-  // Replace the return op using the new operands.
+  // Replace the return op using the new operands. This will automatically
+  // update the entry block as well.
   rewriter.replaceOp(returnOp,
                      rewriter.create<func::ReturnOp>(loc, newOperands));
-  llvm::errs() << "After replacing return op\n";
-  funcOp.dump();
 
   return success();
 }
 
 void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
-  patterns.add<ReturnOpVectorTypesConversion>(patterns.getContext());
+  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 1538c7eed6e76..723b2f62d65d4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -43,7 +43,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRMemRefDialect
   MLIRMemRefUtils
   MLIRSCFDialect
-  MLIRSPIRVDialect
   MLIRSideEffectInterfaces
   MLIRSubsetOpInterface
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b63cb502b76e8..b3f558c3bac12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -11,26 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/Block.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
-#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/iterator_range.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include <numeric>
 #include <optional>
@@ -79,32 +65,6 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                         resultTypes, op->getAttrs());
 }
 
-static std::optional<SmallVector<int64_t>>
-getTargetShape(const vector::UnrollVectorOptions &options, func::FuncOp funcOp,
-               VectorType vecType) {
-  assert(options.nativeShape &&
-         "vector unrolling expects the native shape or native"
-         "shape call back function to be set");
-  llvm::errs() << "Get target shape\n";
-  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
-  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(funcOp);
-  if (!targetShape) {
-    llvm::errs() << "--no unrolling target shape defined\n";
-    return std::nullopt;
-  }
-  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
-  if (!maybeShapeRatio) {
-    llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
-    return std::nullopt;
-  }
-  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
-    llvm::errs() << "--no unrolling needed -> SKIP\n";
-    return std::nullopt;
-  }
-  llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
-  return targetShape;
-}
-
 /// Return the target shape for unrolling for the given `op`. Return
 /// std::nullopt if the op shouldn't be or cannot be unrolled.
 static std::optional<SmallVector<int64_t>>
@@ -657,160 +617,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
-struct UnrollFuncSignaturePattern : OpRewritePattern<func::FuncOp> {
-  UnrollFuncSignaturePattern(MLIRContext *context,
-                             const vector::UnrollVectorOptions &options,
-                             PatternBenefit benefit = 1)
-      : OpRewritePattern<func::FuncOp>(context, benefit), options(options) {}
-
-  LogicalResult matchAndRewrite(func::FuncOp funcOp,
-                                PatternRewriter &rewriter) const override {
-    llvm::errs() << "Run unroll function signature pattern\n";
-
-    auto fnType = funcOp.getFunctionType();
-
-    // Check function inputs.
-    Location loc = funcOp.getFunctionBody()
-                       .getBlocks()
-                       .begin()
-                       ->getOperations()
-                       .begin()
-                       ->getLoc();
-    size_t newArgIndex = 0;
-    std::vector<Type> newSignature;
-    std::vector<std::vector<size_t>> newArgMap(fnType.getNumInputs());
-
-    for (const auto &argType : enumerate(fnType.getInputs())) {
-      size_t index = argType.index();
-      Type type = argType.value();
-      auto vecType = llvm::dyn_cast<VectorType>(type);
-      if (!vecType) {
-        newSignature.push_back(type);
-        newArgMap[index].push_back(newArgIndex);
-        newArgIndex++;
-        continue;
-      }
-      // Try vector unrolling
-      llvm::errs() << "Try vector unrolling\n";
-      SmallVector<int64_t> originalShape =
-          llvm::to_vector<4>(vecType.getShape());
-      auto targetShape = getTargetShape(options, funcOp, vecType);
-      if (!targetShape) {
-        llvm::errs() << "No target shape\n";
-        newSignature.push_back(type);
-        newArgMap[index].push_back(newArgIndex);
-        newArgIndex++;
-        continue;
-      }
-      llvm::errs() << "Got target shape\n";
-      VectorType unrolledType =
-          VectorType::get(*targetShape, vecType.getElementType());
-      llvm::errs() << "Unrolled type is ";
-      unrolledType.dump();
-
-      for (SmallVector<int64_t> offsets :
-           StaticTileOffsetRange(originalShape, *targetShape)) {
-        newSignature.push_back(unrolledType);
-        newArgMap[index].push_back(newArgIndex);
-        newArgIndex++;
-      }
-    }
-
-    // Assume there is a single result for now.
-    Type originalResultType = fnType.getResult(0);
-
-    // TODO: Handle illegal vector types in results as well.
-    // SmallVector<Type> resultTypes;
-    // auto vecType = llvm::dyn_cast<VectorType>(originalResultType);
-
-    // if (vecType) {
-    //   // Try vector unrolling
-    //   SmallVector<int64_t> originalShape =
-    //   llvm::to_vector<4>(vecType.getShape()); auto targetShape =
-    //   getTargetShape(options, funcOp, vecType); VectorType unrolledType =
-    //     VectorType::get(*targetShape, vecType.getElementType());
-    //   if (targetShape)
-    //     for (SmallVector<int64_t> offsets :
-    //          StaticTileOffsetRange(originalShape, *targetShape))
-    //       resultTypes.push_back(unrolledType);
-    // }
-
-    // Create the converted func op
-    auto newFuncOp = rewriter.create<func::FuncOp>(
-        funcOp.getLoc(), funcOp.getName(),
-        FunctionType::get(rewriter.getContext(), TypeRange(newSignature),
-                          TypeRange(originalResultType)));
-
-    newFuncOp.addEntryBlock();
-
-    llvm::errs() << "Created new func op\n";
-    newFuncOp.dump();
-    llvm::errs() << newFuncOp.getArguments().size() << "\n";
-
-    // TODO: Copy over all attributes other than the function name and type
-
-    // Clone operations (assuming one block for now)
-    // TODO: The uses for operands that are SSA values are not cloned properly.
-    loc = newFuncOp.getBody().getLoc();
-    rewriter.setInsertionPointToStart(&newFuncOp.getBody().getBlocks().front());
-
-    for (auto &op : funcOp.getBlocks().front().getOperations()) {
-      op.dump();
-      SmallVector<Value> newOperands(op.getNumOperands());
-      for (size_t i = 0; i < op.getOperands().size(); ++i) {
-        Value operand = op.getOperand(i);
-        auto blockArg = llvm::dyn_cast<BlockArgument>(operand);
-        if (!blockArg) {
-          newOperands[i] = operand;
-          continue;
-        }
-        // Not unrolled
-        unsigned int argNum = blockArg.getArgNumber();
-        if (newArgMap[argNum].size() == 1) {
-          newOperands[i] = newFuncOp.getArgument(newArgMap[argNum][0]);
-          continue;
-        }
-        // Unrolled
-        // TODO: Store previously created vector.insert_strided_slice ops.
-        auto vecType = dyn_cast<VectorType>(blockArg.getType());
-        SmallVector<int64_t> originalShape =
-            llvm::to_vector<4>(vecType.getShape());
-        auto targetShape = getTargetShape(options, funcOp, vecType);
-        VectorType unrolledType =
-            VectorType::get(*targetShape, vecType.getElementType());
-        llvm::errs() << "Unrolled type is ";
-        unrolledType.dump();
-        // Prepare the result vector.
-        Value result = rewriter.create<arith::ConstantOp>(
-            loc, vecType, rewriter.getZeroAttr(vecType));
-        result.dump();
-        SmallVector<int64_t> strides(targetShape->size(), 1);
-        // Create the vector.insert_strided_slice ops.
-        unsigned int j = 0;
-        for (SmallVector<int64_t> offsets :
-             StaticTileOffsetRange(originalShape, *targetShape)) {
-          result = rewriter.create<vector::InsertStridedSliceOp>(
-              loc, newFuncOp.getArgument(newArgMap[argNum][j]), result, offsets,
-              strides);
-          result.dump();
-          j++;
-        }
-        newOperands[i] = result;
-      }
-      Operation *newOp =
-          rewriter.create(loc, op.getName().getIdentifier(), newOperands,
-                          op.getResultTypes(), op.getAttrs());
-      llvm::errs() << "newOp is ";
-      newOp->dump();
-    }
-    rewriter.eraseOp(funcOp);
-    return success();
-  }
-
-private:
-  vector::UnrollVectorOptions options;
-};
-
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -822,10 +628,3 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollTransposePattern, UnrollGatherPattern>(
       patterns.getContext(), options, benefit);
 }
-
-void mlir::vector::populateVectorUnrollFuncSignaturePatterns(
-    RewritePatternSet &patterns, const UnrollVectorOptions &options,
-    PatternBenefit benefit) {
-  patterns.add<UnrollFuncSignaturePattern>(patterns.getContext(), options,
-                                           benefit);
-}
\ No newline at end of file

>From a055178e070d458158808d17a1c50d5c59952030 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 13:07:28 +0000
Subject: [PATCH 08/14] Add an option for testing signature conversion

---
 mlir/include/mlir/Conversion/Passes.td        |  5 ++
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 49 ++++++++++---------
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8d83343f5b736..598bba63a2a82 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -44,6 +44,11 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
     "spirv::SPIRVDialect",
     "vector::VectorDialect",
   ];
+  let options = [
+    Option<"runSignatureConversion", "run-signature-conversion", "bool",
+    /*default=*/"false",
+    "Run function signature conversion to convert vector types">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index ddfbb0a76ad11..21a5a44ece92a 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -37,36 +37,39 @@ using namespace mlir;
 namespace {
 
 /// A pass to perform the SPIR-V conversion.
-struct ConvertToSPIRVPass final
-    : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+class ConvertToSPIRVPass
+    : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+  using impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
-    // Unroll vectors in function inputs to native vector size.
-    llvm::errs() << "Start unrolling function inputs\n";
-    {
-      RewritePatternSet patterns(context);
-      populateFuncOpVectorRewritePatterns(patterns);
-      GreedyRewriteConfig config;
-      config.strictMode = GreedyRewriteStrictness::ExistingOps;
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-        return signalPassFailure();
-    }
-    llvm::errs() << "Finish unrolling function inputs\n";
+    if (runSignatureConversion) {
+      // Unroll vectors in function inputs to native vector size.
+      llvm::errs() << "Start unrolling function inputs\n";
+      {
+        RewritePatternSet patterns(context);
+        populateFuncOpVectorRewritePatterns(patterns);
+        GreedyRewriteConfig config;
+        config.strictMode = GreedyRewriteStrictness::ExistingOps;
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+          return signalPassFailure();
+      }
+      llvm::errs() << "Finish unrolling function inputs\n";
 
-    // Unroll vectors in function outputs to native vector size.
-    llvm::errs() << "Start unrolling function outputs\n";
-    {
-      RewritePatternSet patterns(context);
-      populateReturnOpVectorRewritePatterns(patterns);
-      GreedyRewriteConfig config;
-      config.strictMode = GreedyRewriteStrictness::ExistingOps;
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-        return signalPassFailure();
+      // Unroll vectors in function outputs to native vector size.
+      llvm::errs() << "Start unrolling function outputs\n";
+      {
+        RewritePatternSet patterns(context);
+        populateReturnOpVectorRewritePatterns(patterns);
+        GreedyRewriteConfig config;
+        config.strictMode = GreedyRewriteStrictness::ExistingOps;
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+          return signalPassFailure();
+      }
+      llvm::errs() << "Finish unrolling function outputs\n";
     }
-    llvm::errs() << "Finish unrolling function outputs\n";
 
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
     std::unique_ptr<ConversionTarget> target =

>From 61cf2559a0f7f71b5bf39956b13deb57cf6aadca Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 14:55:25 +0000
Subject: [PATCH 09/14] Add unit tests

---
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     |  11 +-
 .../func-signature-vector-unroll.mlir         | 132 ++++++++++++++++++
 2 files changed, 140 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 21a5a44ece92a..88d7590c1daae 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,7 +39,8 @@ namespace {
 /// A pass to perform the SPIR-V conversion.
 class ConvertToSPIRVPass
     : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
-  using impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+  using impl::ConvertToSPIRVPassBase<
+      ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
@@ -53,7 +54,8 @@ class ConvertToSPIRVPass
         populateFuncOpVectorRewritePatterns(patterns);
         GreedyRewriteConfig config;
         config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        if (failed(
+                applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
           return signalPassFailure();
       }
       llvm::errs() << "Finish unrolling function inputs\n";
@@ -65,10 +67,13 @@ class ConvertToSPIRVPass
         populateReturnOpVectorRewritePatterns(patterns);
         GreedyRewriteConfig config;
         config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        if (failed(
+                applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
           return signalPassFailure();
       }
       llvm::errs() << "Finish unrolling function outputs\n";
+
+      return;
     }
 
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..d5c777908d7e2
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion" -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+  // CHECK: return %[[ARG0]] : i32
+  return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+  // CHECK: return %[[ARG0]] : vector<4xi32>
+  return %arg0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @simple_vector_5(%arg0 : vector<5xi32>) -> vector<5xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<5xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert_strided_slice %[[ARG4]], %[[INSERT3]] {offsets = [4], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [1], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [2], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [3], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: %[[EXTRACT4:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [4], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]], %[[EXTRACT4]] : vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>
+  return %arg0 : vector<5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_6
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
+func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<3xi32>
+  return %arg0 : vector<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
+func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xi32>, vector<4xi32>
+  return %arg0 : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_6and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi32>, vector<3xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<6xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_3and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
+func.func @vector_3and8(%arg0 : vector<3xi32>, %arg1 : vector<8xi32>) -> (vector<3xi32>, vector<8xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[ARG0]], %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<3xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @scalar_vector
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: i32)
+func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i32) -> (vector<8xi32>, vector<3xi32>, i32) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+  // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>, vector<4xi32>, vector<3xi32>, i32
+  return %arg0, %arg1, %arg2 : vector<8xi32>, vector<3xi32>, i32
+}
+
+// -----
+
+// CHECK-LABEL: @reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
+func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[ADDI:.*]] = arith.addi %[[INSERT1]], %[[INSERT3]] : vector<8xi32>
+  // CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, %[[ADDI]] : vector<8xi32> into i32
+  // CHECK: %[[RET:.*]] = arith.addi %[[REDUCTION]], %[[ARG4]] : i32
+  // CHECK: return %[[RET]] : i32
+  %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+  %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+  %2 = arith.addi %1, %arg2 : i32
+  return %2 : i32
+}

>From a37422c99a5ef6e072a44723a546d155b0462eb5 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 15:50:34 +0000
Subject: [PATCH 10/14] Code formatting

---
 mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 112c404527927..9ad3d5fc85dd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -19,8 +19,8 @@
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/SmallSet.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 

>From fc237908b8f6e0dfd3feef2cfcc0a3353cb3ac28 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 10 Jul 2024 13:21:33 -0400
Subject: [PATCH 11/14] Update
 mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 88d7590c1daae..57dc11c434176 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,8 +39,7 @@ namespace {
 /// A pass to perform the SPIR-V conversion.
 class ConvertToSPIRVPass
     : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
-  using impl::ConvertToSPIRVPassBase<
-      ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+  using impl::ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();

>From a194ff067bf5ef4d0f44b6b7052490a4dca96b88 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 10 Jul 2024 13:21:42 -0400
Subject: [PATCH 12/14] Update
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 6e793573f0262..1e2eb336372a1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -909,7 +909,7 @@ FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
     size_t origInputNo = argType.index();
     Type origType = argType.value();
     // Check whether the argument is of vector type.
-    auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    auto origVecType = dyn_cast<VectorType>(origType);
     if (!origVecType) {
       // We need a placeholder for the old argument that will be erased later.
       Value result = rewriter.create<arith::ConstantOp>(

>From 21077fc29a2861626540d9b2a58c5e2694fa3d71 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 18:45:42 +0000
Subject: [PATCH 13/14] Run both patterns at the same time

---
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 20 +++----------------
 1 file changed, 3 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 57dc11c434176..3ffaff76e566d 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,30 +39,18 @@ namespace {
 /// A pass to perform the SPIR-V conversion.
 class ConvertToSPIRVPass
     : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
-  using impl::ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
+  using impl::ConvertToSPIRVPassBase<
+      ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
     if (runSignatureConversion) {
-      // Unroll vectors in function inputs to native vector size.
-      llvm::errs() << "Start unrolling function inputs\n";
+      // Unroll vectors in function signatures to native vector size.
       {
         RewritePatternSet patterns(context);
         populateFuncOpVectorRewritePatterns(patterns);
-        GreedyRewriteConfig config;
-        config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        if (failed(
-                applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-          return signalPassFailure();
-      }
-      llvm::errs() << "Finish unrolling function inputs\n";
-
-      // Unroll vectors in function outputs to native vector size.
-      llvm::errs() << "Start unrolling function outputs\n";
-      {
-        RewritePatternSet patterns(context);
         populateReturnOpVectorRewritePatterns(patterns);
         GreedyRewriteConfig config;
         config.strictMode = GreedyRewriteStrictness::ExistingOps;
@@ -70,8 +58,6 @@ class ConvertToSPIRVPass
                 applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
           return signalPassFailure();
       }
-      llvm::errs() << "Finish unrolling function outputs\n";
-
       return;
     }
 

>From 7dac5b8d8c33fd47fc7f523470e76c3248a66b75 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 19:50:36 +0000
Subject: [PATCH 14/14] Code refactoring and formatting

---
 mlir/include/mlir/Conversion/Passes.td        |   2 +-
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     |   5 +-
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 442 +++++++++---------
 .../test/Conversion/ConvertToSPIRV/arith.mlir |   2 +-
 .../Conversion/ConvertToSPIRV/combined.mlir   |   2 +-
 .../func-signature-vector-unroll.mlir         |   2 +-
 .../test/Conversion/ConvertToSPIRV/index.mlir |   2 +-
 mlir/test/Conversion/ConvertToSPIRV/scf.mlir  |   2 +-
 .../Conversion/ConvertToSPIRV/simple.mlir     |   2 +-
 mlir/test/Conversion/ConvertToSPIRV/ub.mlir   |   2 +-
 .../Conversion/ConvertToSPIRV/vector.mlir     |   2 +-
 11 files changed, 220 insertions(+), 245 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 598bba63a2a82..c4b9ff005919b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -46,7 +46,7 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   ];
   let options = [
     Option<"runSignatureConversion", "run-signature-conversion", "bool",
-    /*default=*/"false",
+    /*default=*/"true",
     "Run function signature conversion to convert vector types">
   ];
 }
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 3ffaff76e566d..d621318ca1d4f 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -37,10 +37,9 @@ using namespace mlir;
 namespace {
 
 /// A pass to perform the SPIR-V conversion.
-class ConvertToSPIRVPass
+struct ConvertToSPIRVPass
     : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
-  using impl::ConvertToSPIRVPassBase<
-      ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+  using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1e2eb336372a1..49bd72d3d55a8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -53,24 +53,31 @@ static int getComputeVectorSize(int64_t size) {
 }
 
 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
-  llvm::errs() << "Get target shape\n";
+  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
+  if (vecType.isScalable()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--scalable vectors are not supported -> BAIL\n");
+    return std::nullopt;
+  }
   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
   std::optional<SmallVector<int64_t>> targetShape =
       SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
   if (!targetShape) {
-    llvm::errs() << "--no unrolling target shape defined\n";
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
     return std::nullopt;
   }
   auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
   if (!maybeShapeRatio) {
-    llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
     return std::nullopt;
   }
   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
-    llvm::errs() << "--no unrolling needed -> SKIP\n";
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
     return std::nullopt;
   }
-  llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
   return targetShape;
 }
 
@@ -862,170 +869,153 @@ namespace {
 /// functions to be of valid types
 class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
 public:
-  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
-                                PatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
-                                    PatternRewriter &rewriter) const {
-  auto fnType = funcOp.getFunctionType();
-
-  // Create a new func op with the original type and copy the function body.
-  auto newFuncOp =
-      rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
-  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
-                              newFuncOp.end());
-
-  llvm::errs() << "After creating new func op and copying the function body\n";
-  newFuncOp.dump();
-
-  Location loc = newFuncOp.getBody().getLoc();
-  Block &entryBlock = newFuncOp.getBlocks().front();
-  rewriter.setInsertionPointToStart(&entryBlock);
-
-  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
-
-  // For arguments that are of illegal types and require unrolling.
-  // `unrolledInputNums` stores the indices of arguments that result from
-  // unrolling in the new function signature. `newInputNo` is a counter.
-  SmallVector<size_t> unrolledInputNums;
-  size_t newInputNo = 0;
-
-  // For arguments that are of legal types and do not require unrolling.
-  // `tmpOps` stores a mapping from temporary operations that serve as
-  // placeholders for new arguments that will be added later. These operations
-  // will be erased once the entry block's argument list is updated.
-  DenseMap<Operation *, size_t> tmpOps;
+                                PatternRewriter &rewriter) const override {
+    // TODO: Handle declarations.
+    if (funcOp.isDeclaration()) {
+      LLVM_DEBUG(llvm::dbgs() << " illegal: declarations are unsupported\n");
+      return failure();
+    }
 
-  // This counts the number of new operations created.
-  size_t newOpCount = 0;
+    FunctionType fnType = funcOp.getFunctionType();
+
+    // Create a new func op with the original type and copy the function body.
+    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+                                                   funcOp.getName(), fnType);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    Location loc = newFuncOp.getBody().getLoc();
+
+    Block &entryBlock = newFuncOp.getBlocks().front();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&entryBlock);
+
+    OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+    // For arguments that are of illegal types and require unrolling.
+    // `unrolledInputNums` stores the indices of arguments that result from
+    // unrolling in the new function signature. `newInputNo` is a counter.
+    SmallVector<size_t> unrolledInputNums;
+    size_t newInputNo = 0;
+
+    // For arguments that are of legal types and do not require unrolling.
+    // `tmpOps` stores a mapping from temporary operations that serve as
+    // placeholders for new arguments that will be added later. These operations
+    // will be erased once the entry block's argument list is updated.
+    llvm::SmallDenseMap<Operation *, size_t> tmpOps;
+
+    // This counts the number of new operations created.
+    size_t newOpCount = 0;
+
+    // Enumerate through the arguments.
+    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
+      // Check whether the argument is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        // We need a placeholder for the old argument that will be erased later.
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        newInputNo++;
+        newOpCount++;
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // We need a placeholder for the old argument that will be erased later.
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        newInputNo++;
+        newOpCount++;
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+      SmallVector<int64_t> originalShape =
+          llvm::to_vector(origVecType.getShape());
 
-  // Enumerate through the arguments.
-  for (const auto &argType : enumerate(fnType.getInputs())) {
-    size_t origInputNo = argType.index();
-    Type origType = argType.value();
-    // Check whether the argument is of vector type.
-    auto origVecType = dyn_cast<VectorType>(origType);
-    if (!origVecType) {
-      // We need a placeholder for the old argument that will be erased later.
+      // Prepare the result vector.
       Value result = rewriter.create<arith::ConstantOp>(
-          loc, origType, rewriter.getZeroAttr(origType));
-      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
-      tmpOps.insert({result.getDefiningOp(), newInputNo});
-      oneToNTypeMapping.addInputs(origInputNo, origType);
-      newInputNo++;
+          loc, origVecType, rewriter.getZeroAttr(origVecType));
       newOpCount++;
-      continue;
-    }
-    // Check whether the vector needs unrolling.
-    auto targetShape = getTargetShape(origVecType);
-    if (!targetShape) {
-      // We need a placeholder for the old argument that will be erased later.
-      Value result = rewriter.create<arith::ConstantOp>(
-          loc, origType, rewriter.getZeroAttr(origType));
-      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
-      tmpOps.insert({result.getDefiningOp(), newInputNo});
-      oneToNTypeMapping.addInputs(origInputNo, origType);
-      newInputNo++;
-      newOpCount++;
-      continue;
-    }
-    llvm::errs() << "Got target shape\n";
-    VectorType unrolledType =
-        VectorType::get(*targetShape, origVecType.getElementType());
-    llvm::errs() << "Unrolled type is ";
-    unrolledType.dump();
-    SmallVector<int64_t> originalShape =
-        llvm::to_vector<4>(origVecType.getShape());
-
-    // Prepare the result vector.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, origVecType, rewriter.getZeroAttr(origVecType));
-    newOpCount++;
-    // Prepare the placeholder for the new arguments that will be added later.
-    Value dummy = rewriter.create<arith::ConstantOp>(
-        loc, unrolledType, rewriter.getZeroAttr(unrolledType));
-    newOpCount++;
-
-    // Create the `vector.insert_strided_slice` ops.
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    SmallVector<Type> newTypes;
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalShape, *targetShape)) {
-      result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
-                                                             offsets, strides);
-      newTypes.push_back(unrolledType);
-      unrolledInputNums.push_back(newInputNo);
-      newInputNo++;
+      // Prepare the placeholder for the new arguments that will be added later.
+      Value dummy = rewriter.create<arith::ConstantOp>(
+          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
       newOpCount++;
+
+      // Create the `vector.insert_strided_slice` ops.
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<Type> newTypes;
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(originalShape, *targetShape)) {
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, dummy, result, offsets, strides);
+        newTypes.push_back(unrolledType);
+        unrolledInputNums.push_back(newInputNo);
+        newInputNo++;
+        newOpCount++;
+      }
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      oneToNTypeMapping.addInputs(origInputNo, newTypes);
     }
-    rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
-    oneToNTypeMapping.addInputs(origInputNo, newTypes);
-  }
-
-  llvm::errs() << "After enumerating through the arguments\n";
-  newFuncOp.dump();
-
-  // Change the function signature.
-  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
-  auto newFnType =
-      FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
-                        TypeRange(fnType.getResults()));
-  rewriter.modifyOpInPlace(newFuncOp,
-                           [&] { newFuncOp.setFunctionType(newFnType); });
-
-  llvm::errs() << "After changing function signature\n";
-  newFuncOp.dump();
-
-  // Update the arguments in the entry block.
-  entryBlock.eraseArguments(0, fnType.getNumInputs());
-  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
-  entryBlock.addArguments(convertedTypes, locs);
-
-  llvm::errs() << "After updating the arguments in the entry block\n";
-  newFuncOp.dump();
-
-  // Replace the placeholder values with the new arguments. We assume there is
-  // only one block for now.
-  size_t idx = 0;
-  for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
-    size_t count = opPair.index();
-    Operation &op = opPair.value();
-    // We first look for operands that are placeholders for initially legal
-    // arguments.
-    for (auto operandPair : llvm::enumerate(op.getOperands())) {
-      Operation *operandOp = operandPair.value().getDefiningOp();
-      if (tmpOps.find(operandOp) != tmpOps.end())
+
+    // Change the function signature.
+    auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+    auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
+    rewriter.modifyOpInPlace(newFuncOp,
+                             [&] { newFuncOp.setFunctionType(newFnType); });
+
+    // Update the arguments in the entry block.
+    entryBlock.eraseArguments(0, fnType.getNumInputs());
+    SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+    entryBlock.addArguments(convertedTypes, locs);
+
+    // Replace the placeholder values with the new arguments. We assume there is
+    // only one block for now.
+    size_t idx = 0;
+    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+      // We first look for operands that are placeholders for initially legal
+      // arguments.
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+        Operation *operandOp = operandVal.getDefiningOp();
+        auto it = tmpOps.find(operandOp);
+        if (it != tmpOps.end())
+          rewriter.modifyOpInPlace(&op, [&] {
+            op.setOperand(operandIdx, newFuncOp.getArgument(it->second));
+          });
+      }
+      // Since all newly created operations are in the beginning, reaching the
+      // end of them means that any later `vector.insert_strided_slice` should
+      // not be touched.
+      if (count >= newOpCount)
+        continue;
+      auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
+      if (vecOp) {
+        size_t unrolledInputNo = unrolledInputNums[idx];
         rewriter.modifyOpInPlace(&op, [&] {
-          op.setOperand(operandPair.index(),
-                        newFuncOp.getArgument(tmpOps[operandOp]));
+          op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
         });
+        idx++;
+      }
+      count++;
     }
-    // Since all newly created operations are in the beginning, reaching the end
-    // of them means that any later `vector.insert_strided_slice` should not be
-    // touched.
-    if (count >= newOpCount)
-      continue;
-    auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
-    if (vecOp) {
-      size_t unrolledInputNo = unrolledInputNums[idx];
-      rewriter.modifyOpInPlace(&op, [&] {
-        op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
-      });
-      idx++;
-    }
-    count++;
-  }
 
-  // Erase the original funcOp. The `tmpOps` do not need to be erased since
-  // they have no uses and will be handled by dead-code elimination.
-  rewriter.eraseOp(funcOp);
-  return success();
-}
+    // Erase the original funcOp. The `tmpOps` do not need to be erased since
+    // they have no uses and will be handled by dead-code elimination.
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+} // namespace
 
 void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
   patterns.add<FuncOpVectorUnroll>(patterns.getContext());
@@ -1040,89 +1030,75 @@ namespace {
 /// vectors to be of valid types.
 class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
 public:
-  using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(func::ReturnOp returnOp,
-                                PatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ReturnOpVectorUnroll::matchAndRewrite(func::ReturnOp returnOp,
-                                      PatternRewriter &rewriter) const {
-
-  // Check whether the parent funcOp is valid.
-  func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
-  if (!funcOp)
-    return failure();
+                                PatternRewriter &rewriter) const override {
+    // Check whether the parent funcOp is valid.
+    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+    if (!funcOp)
+      return failure();
 
-  auto fnType = funcOp.getFunctionType();
-  OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
-  Location loc = returnOp.getLoc();
-
-  // For the new return op.
-  SmallVector<Value> newOperands;
-
-  // Enumerate through the results.
-  for (const auto &argType : enumerate(fnType.getResults())) {
-    size_t origResultNo = argType.index();
-    Type origType = argType.value();
-    auto origVecType = llvm::dyn_cast<VectorType>(origType);
-    // Check whether the argument is of vector type.
-    if (!origVecType) {
-      oneToNTypeMapping.addInputs(origResultNo, origType);
-      newOperands.push_back(returnOp.getOperand(origResultNo));
-      continue;
-    }
-    // Check whether the vector needs unrolling.
-    auto targetShape = getTargetShape(origVecType);
-    if (!targetShape) {
-      // The original argument can be used.
-      oneToNTypeMapping.addInputs(origResultNo, origType);
-      newOperands.push_back(returnOp.getOperand(origResultNo));
-      continue;
-    }
-    llvm::errs() << "Got target shape\n";
-    VectorType unrolledType =
-        VectorType::get(*targetShape, origVecType.getElementType());
-    llvm::errs() << "Unrolled type is ";
-    unrolledType.dump();
-
-    // Create `vector.extract_strided_slice` ops to form legal vectors from the
-    // original operand of illegal type.
-    SmallVector<int64_t> originalShape =
-        llvm::to_vector<4>(origVecType.getShape());
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    SmallVector<Type> newTypes;
-    Value returnValue = returnOp.getOperand(origResultNo);
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalShape, *targetShape)) {
-      Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, returnValue, offsets, *targetShape, strides);
-      newOperands.push_back(result);
-      newTypes.push_back(unrolledType);
+    FunctionType fnType = funcOp.getFunctionType();
+    OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+    Location loc = returnOp.getLoc();
+
+    // For the new return op.
+    SmallVector<Value> newOperands;
+
+    // Enumerate through the results.
+    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
+      // Check whether the argument is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // The original argument can be used.
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+
+      // Create `vector.extract_strided_slice` ops to form legal vectors from
+      // the original operand of illegal type.
+      SmallVector<int64_t> originalShape =
+          llvm::to_vector<4>(origVecType.getShape());
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<Type> newTypes;
+      Value returnValue = returnOp.getOperand(origResultNo);
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(originalShape, *targetShape)) {
+        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, returnValue, offsets, *targetShape, strides);
+        newOperands.push_back(result);
+        newTypes.push_back(unrolledType);
+      }
+      oneToNTypeMapping.addInputs(origResultNo, newTypes);
     }
-    oneToNTypeMapping.addInputs(origResultNo, newTypes);
-  }
 
-  llvm::errs() << "After enumerating through the arguments\n";
-  funcOp.dump();
+    // Change the function signature.
+    auto newFnType =
+        FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+                          TypeRange(oneToNTypeMapping.getConvertedTypes()));
+    rewriter.modifyOpInPlace(funcOp,
+                             [&] { funcOp.setFunctionType(newFnType); });
 
-  // Change the function signature.
-  auto newFnType =
-      FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
-                        TypeRange(oneToNTypeMapping.getConvertedTypes()));
-  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setFunctionType(newFnType); });
-  llvm::errs() << "After changing function signature\n";
-  funcOp.dump();
+    // Replace the return op using the new operands. This will automatically
+    // update the entry block as well.
+    rewriter.replaceOp(returnOp,
+                       rewriter.create<func::ReturnOp>(loc, newOperands));
 
-  // Replace the return op using the new operands. This will automatically
-  // update the entry block as well.
-  rewriter.replaceOp(returnOp,
-                     rewriter.create<func::ReturnOp>(loc, newOperands));
-
-  return success();
-}
+    return success();
+  }
+};
+} // namespace
 
 void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
   patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index a2adc0ad9c7a5..1a844a7cd018b 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 9e908465cb142..02b938be775a3 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @combined
 // CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index d5c777908d7e2..e248b0be18bc1 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @simple_scalar
 // CHECK-SAME: (%[[ARG0:.+]]: i32)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index db747625bc7b3..e1cb18aac5d01 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
 
 // CHECK-LABEL: @basic
 func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index f619ca5771824..58ec6ac61f6ac 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @if_yield
 // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index 20b2a42bc3975..c5e0e6603d94a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @return_scalar
 // CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index 66528b68f58cf..a83bfb6f405a0 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @ub
 // CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 336f0fe10c27e..c63dd030f4747 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>



More information about the Mlir-commits mailing list