[Mlir-commits] [mlir] [mlir][spirv] Implement vector unrolling for `convert-to-spirv` pass (PR #100138)

Angel Zhang llvmlistbot at llvm.org
Tue Jul 23 13:09:39 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/100138

>From b14e3fe50ef08979ad248382a34b10025649bffd Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Sun, 21 Jul 2024 15:26:32 +0000
Subject: [PATCH 1/4] [mlir][spirv] Fix function signature legalization for n-D
 vectors

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 10 ++++-
 .../func-signature-vector-unroll.mlir         | 44 +++++++++++++++++++
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..c146589612b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1098,13 +1098,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
       // the original operand of illegal type.
       auto originalShape =
           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
-      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<int64_t> strides(originalShape.size(), 1);
+      SmallVector<int64_t> extractShape(originalShape.size(), 1);
+      extractShape.back() = targetShape->back();
       SmallVector<Type> newTypes;
       Value returnValue = returnOp.getOperand(origResultNo);
       for (SmallVector<int64_t> offsets :
            StaticTileOffsetRange(originalShape, *targetShape)) {
         Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, returnValue, offsets, *targetShape, strides);
+            loc, returnValue, offsets, extractShape, strides);
+        SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+        if (originalShape.size() > 1)
+          result =
+              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
         newOperands.push_back(result);
         newTypes.push_back(unrolledType);
       }
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
 
 // -----
 
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @vector_6and8
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
 func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
 
 // -----
 
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[EXTRACT0:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT0_1:.*]]  = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT1:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT1_1:.*]]  = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT2:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT2_1:.*]]  = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT3:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT3_1:.*]]  = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduction
 // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
 func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {

>From c80e380068ccca536ec7d06f9e5062b97379f5df Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 22 Jul 2024 19:23:05 +0000
Subject: [PATCH 2/4] [mlir][spirv] Initial version of vector unrolling for
 convert-to-spirv pass

---
 mlir/include/mlir/Conversion/Passes.td        |   5 +-
 .../SPIRV/Transforms/SPIRVConversion.h        |  11 ++
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     |  74 +++++++++++
 .../SPIRV/Transforms/SPIRVConversion.cpp      |  59 +++++++--
 .../test/Conversion/ConvertToSPIRV/arith.mlir |   2 +-
 .../Conversion/ConvertToSPIRV/combined.mlir   |   2 +-
 .../test/Conversion/ConvertToSPIRV/index.mlir |   2 +-
 mlir/test/Conversion/ConvertToSPIRV/scf.mlir  |   2 +-
 .../Conversion/ConvertToSPIRV/simple.mlir     |   2 +-
 mlir/test/Conversion/ConvertToSPIRV/ub.mlir   |   2 +-
 .../ConvertToSPIRV/vector-unroll.mlir         | 102 +++++++++++++++
 .../Conversion/ConvertToSPIRV/vector.mlir     |   2 +-
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |   2 +
 .../TestSPIRVVectorUnrolling.cpp              | 119 ++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 15 files changed, 370 insertions(+), 18 deletions(-)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
 create mode 100644 mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 748646e605827..b5bb2f42f2961 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -47,7 +47,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let options = [
     Option<"runSignatureConversion", "run-signature-conversion", "bool",
     /*default=*/"true",
-    "Run function signature conversion to convert vector types">
+    "Run function signature conversion to convert vector types">,
+    Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
+    /*default=*/"true",
+    "Run vector unrolling to convert vector types in function bodies">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9ad3d5fc85dd3..195fbd0d0cd58 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -189,6 +189,17 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                           MemRefType baseType, Value basePtr,
                           ValueRange indices, Location loc, OpBuilder &builder);
 
+int getComputeVectorSize(int64_t size);
+
+// GetNativeVectorShape implementation for reduction ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
+
+// GetNativeVectorShape implementation for transpose ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
+
+// For general ops.
+std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 003a5feea9e9b..b82a244cfc973 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -56,6 +58,78 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    if (runVectorUnrolling) {
+
+      // Fold transpose ops if possible as we cannot unroll them later.
+      {
+        RewritePatternSet patterns(context);
+        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Unroll vectors to native vector size.
+      {
+        RewritePatternSet patterns(context);
+        auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+            [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+        populateVectorUnrollPatterns(patterns, options);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+
+      // Convert transpose ops into extract and insert pairs, in preparation
+      // of further transformations to canonicalize/cancel.
+      {
+        RewritePatternSet patterns(context);
+        auto options =
+            vector::VectorTransformsOptions().setVectorTransposeLowering(
+                vector::VectorTransposeLowering::EltWise);
+        vector::populateVectorTransposeLoweringPatterns(patterns, options);
+        vector::populateVectorShapeCastLoweringPatterns(patterns);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Run canonicalization to cast away leading size-1 dimensions.
+      {
+        RewritePatternSet patterns(context);
+
+        // Pull in casting away leading one dims to allow cancelling some
+        // read/write ops.
+        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+        // Decompose different rank insert_strided_slice and n-D
+        // extract_strided_slice.
+        vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+            patterns);
+        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+        // Trimming leading unit dims may generate broadcast/shape_cast ops.
+        // Clean them up.
+        vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+        vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+
+      // Run all sorts of canonicalization patterns to clean up again.
+      {
+        RewritePatternSet patterns(context);
+        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+        vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+    }
+
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
     std::unique_ptr<ConversionTarget> target =
         SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c146589612b5e..8470c7642e716 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -46,14 +46,6 @@ namespace {
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static int getComputeVectorSize(int64_t size) {
-  for (int i : {4, 3, 2}) {
-    if (size % i == 0)
-      return i;
-  }
-  return 1;
-}
-
 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
   if (vecType.isScalable()) {
@@ -62,8 +54,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
     return std::nullopt;
   }
   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
-  std::optional<SmallVector<int64_t>> targetShape =
-      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
+      1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
   if (!targetShape) {
     LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
     return std::nullopt;
@@ -1291,6 +1283,53 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                              builder);
 }
 
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+  int64_t vectorSize =
+      mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+  return {vectorSize};
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() =
+      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() =
+          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::ReductionOp, vector::TransposeOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V TypeConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index 1a844a7cd018b..6418e931f7460 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 02b938be775a3..311174bef15ed 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @combined
 // CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index e1cb18aac5d01..f4b116849fa93 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @basic
 func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index 58ec6ac61f6ac..246464928b81c 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @if_yield
 // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index c5e0e6603d94a..00556140c3018 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @return_scalar
 // CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index a83bfb6f405a0..f34ca01c94f00 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @ub
 // CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
new file mode 100644
index 0000000000000..54d9875002cb5
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-spirv-vector-unrolling -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @vaddi
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>)
+func.func @vaddi(%arg0 : vector<6xi32>, %arg1 : vector<6xi32>) -> (vector<6xi32>) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<3xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<3xi32>
+  // CHECK: return %[[ADD0]], %[[ADD1]] : vector<3xi32>, vector<3xi32>
+  %0 = arith.addi %arg0, %arg1 : vector<6xi32>
+  return %0 : vector<6xi32>
+}
+
+// CHECK-LABEL: @vaddi_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<2xi32>, %[[ARG3:.+]]: vector<2xi32>)
+func.func @vaddi_2d(%arg0 : vector<2x2xi32>, %arg1 : vector<2x2xi32>) -> (vector<2x2xi32>) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<2xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<2xi32>
+  // CHECK: return %[[ADD0]], %[[ADD1]] : vector<2xi32>, vector<2xi32>
+  %0 = arith.addi %arg0, %arg1 : vector<2x2xi32>
+  return %0 : vector<2x2xi32>
+}
+
+// CHECK-LABEL: @vaddi_2d_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: vector<4xi32>, %[[ARG5:.+]]: vector<4xi32>, %[[ARG6:.+]]: vector<4xi32>, %[[ARG7:.+]]: vector<4xi32>)
+func.func @vaddi_2d_8(%arg0 : vector<2x8xi32>, %arg1 : vector<2x8xi32>) -> (vector<2x8xi32>) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG4]] : vector<4xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG5]] : vector<4xi32>
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[ARG2]], %[[ARG6]] : vector<4xi32>
+  // CHECK: %[[ADD3:.*]] = arith.addi %[[ARG3]], %[[ARG7]] : vector<4xi32>
+  // CHECK: return %[[ADD0]], %[[ADD1]], %[[ADD2]], %[[ADD3]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+  %0 = arith.addi %arg0, %arg1 : vector<2x8xi32>
+  return %0 : vector<2x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduction_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @reduction_5(%arg0 : vector<5xi32>) -> (i32) {
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[EXTRACT0]], %[[EXTRACT1]] : i32
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG2]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[EXTRACT2]] : i32
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG3]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[EXTRACT3]] : i32
+  // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG4]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[EXTRACT4]] : i32
+  // CHECK: return %[[ADD3]] : i32
+  %0 = vector.reduction <add>, %arg0 : vector<5xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @reduction_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
+func.func @reduction_8(%arg0 : vector<8xi32>) -> (i32) {
+  // CHECK: %[[REDUCTION0:.*]] = vector.reduction <add>, %[[ARG0]] : vector<4xi32> into i32
+  // CHECK: %[[REDUCTION1:.*]] = vector.reduction <add>, %[[ARG1]] : vector<4xi32> into i32
+  // CHECK: %[[ADD:.*]] = arith.addi %[[REDUCTION0]], %[[REDUCTION1]] : i32
+  // CHECK: return %[[ADD]] : i32
+  %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @vaddi_reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<4xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<4xi32>
+  // CHECK: %[[REDUCTION0:.*]] = vector.reduction <add>, %[[ADD0]] : vector<4xi32> into i32
+  // CHECK: %[[REDUCTION1:.*]] = vector.reduction <add>, %[[ADD1]] : vector<4xi32> into i32
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[REDUCTION0]], %[[REDUCTION1]] : i32
+  // CHECK: return %[[ADD2]] : i32
+  %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+  %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @transpose
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
+func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
+  // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %0 : vector<3x2xi32>
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index c63dd030f4747..e369eadca5730 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index 69b5787f7e851..aeade52c7ade5 100644
--- a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestConvertToSPIRV
   TestSPIRVFuncSignatureConversion.cpp
+  TestSPIRVVectorUnrolling.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -13,4 +14,5 @@ add_mlir_library(MLIRTestConvertToSPIRV
   MLIRTransformUtils
   MLIRTransforms
   MLIRVectorDialect
+  MLIRVectorTransforms
   )
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
new file mode 100644
index 0000000000000..fbf7933b244fa
--- /dev/null
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
@@ -0,0 +1,119 @@
+//===- TestSPIRVVectorUnrolling.cpp - Test vector unrolling -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+struct TestSPIRVVectorUnrolling final
+    : PassWrapper<TestSPIRVVectorUnrolling, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling)
+
+  StringRef getArgument() const final { return "test-spirv-vector-unrolling"; }
+
+  StringRef getDescription() const final {
+    return "Test patterns that unroll vectors to types supported by SPIR-V";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<spirv::SPIRVDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    Operation *op = getOperation();
+
+    // Unroll vectors in function signatures to native vector size.
+    {
+      RewritePatternSet patterns(context);
+      populateFuncOpVectorRewritePatterns(patterns);
+      populateReturnOpVectorRewritePatterns(patterns);
+      GreedyRewriteConfig config;
+      config.strictMode = GreedyRewriteStrictness::ExistingOps;
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        return signalPassFailure();
+    }
+
+    // Unroll vectors to native vector size.
+    {
+      RewritePatternSet patterns(context);
+      auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+          [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+      populateVectorUnrollPatterns(patterns, options);
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Convert transpose ops into extract and insert pairs, in preparation of
+    // further transformations to canonicalize/cancel.
+    {
+      RewritePatternSet patterns(context);
+      auto options =
+          vector::VectorTransformsOptions().setVectorTransposeLowering(
+              vector::VectorTransposeLowering::EltWise);
+      vector::populateVectorTransposeLoweringPatterns(patterns, options);
+      vector::populateVectorShapeCastLoweringPatterns(patterns);
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+
+    // Run canonicalization to cast away leading size-1 dimensions.
+    {
+      RewritePatternSet patterns(context);
+
+      // We need to pull in casting way leading one dims.
+      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+      // Decompose different rank insert_strided_slice and n-D
+      // extract_slided_slice.
+      vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+          patterns);
+      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+      // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
+      // them up.
+      vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+      vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Run all sorts of canonicalization patterns to clean up again.
+    {
+      RewritePatternSet patterns(context);
+      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+      vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+        return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+namespace test {
+void registerTestSPIRVVectorUnrolling() {
+  PassRegistration<TestSPIRVVectorUnrolling>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 149f9d59961b8..0f29963da39bb 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -142,6 +142,7 @@ void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestSPIRVFuncSignatureConversion();
+void registerTestSPIRVVectorUnrolling();
 void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
@@ -275,6 +276,7 @@ void registerTestPasses() {
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestSPIRVFuncSignatureConversion();
+  mlir::test::registerTestSPIRVVectorUnrolling();
   mlir::test::registerTestTensorCopyInsertionPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();

>From 8915f1e2954b58d5ce5bc4b6dba9e19fa828f574 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Tue, 23 Jul 2024 18:10:30 +0000
Subject: [PATCH 3/4] Formatting and documentation

---
 mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h | 2 ++
 mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp    | 1 -
 mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir       | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 195fbd0d0cd58..1236afe552a9c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -189,6 +189,8 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                           MemRefType baseType, Value basePtr,
                           ValueRange indices, Location loc, OpBuilder &builder);
 
+// Find the largest factor of size among {2,3,4} for the lowest dimension of
+// the target shape.
 int getComputeVectorSize(int64_t size);
 
 // GetNativeVectorShape implementation for reduction ops.
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b82a244cfc973..c4a11793b1b82 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -59,7 +59,6 @@ struct ConvertToSPIRVPass final
     }
 
     if (runVectorUnrolling) {
-
       // Fold transpose ops if possible as we cannot unroll it later.
       {
         RewritePatternSet patterns(context);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index 54d9875002cb5..043f9422d8790 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -99,4 +99,4 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
   // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
   %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
   return %0 : vector<3x2xi32>
-}
\ No newline at end of file
+}

>From 9c312f1dea3521606b33c2837deaf0ae1120ce1f Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Tue, 23 Jul 2024 19:46:28 +0000
Subject: [PATCH 4/4] Remove redundant patterns and refactor code

---
 .../SPIRV/Transforms/SPIRVConversion.h        |  8 ++
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 88 ++-----------------
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 66 ++++++++++++++
 .../TestSPIRVFuncSignatureConversion.cpp      |  9 +-
 .../TestSPIRVVectorUnrolling.cpp              | 73 +--------------
 5 files changed, 85 insertions(+), 159 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 1236afe552a9c..f54c93a09e727 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -19,8 +19,10 @@
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 
@@ -202,6 +204,12 @@ SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
 // For general ops.
 std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
 
+// Unroll vectors in function signatures to native size.
+LogicalResult unrollVectorsInSignatures(Operation *op);
+
+// Unroll vectors in function bodies to native size.
+LogicalResult unrollVectorsInFuncBodies(Operation *op);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index c4a11793b1b82..4694a147e1e94 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -44,90 +44,16 @@ struct ConvertToSPIRVPass final
   using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
-    MLIRContext *context = &getContext();
     Operation *op = getOperation();
+    MLIRContext *context = &getContext();
 
-    if (runSignatureConversion) {
-      // Unroll vectors in function signatures to native vector size.
-      RewritePatternSet patterns(context);
-      populateFuncOpVectorRewritePatterns(patterns);
-      populateReturnOpVectorRewritePatterns(patterns);
-      GreedyRewriteConfig config;
-      config.strictMode = GreedyRewriteStrictness::ExistingOps;
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-        return signalPassFailure();
-    }
-
-    if (runVectorUnrolling) {
-      // Fold transpose ops if possible as we cannot unroll it later.
-      {
-        RewritePatternSet patterns(context);
-        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
-          return signalPassFailure();
-        }
-      }
-
-      // Unroll vectors to native vector size.
-      {
-        RewritePatternSet patterns(context);
-        auto options = vector::UnrollVectorOptions().setNativeShapeFn(
-            [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
-        populateVectorUnrollPatterns(patterns, options);
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-          return signalPassFailure();
-      }
-
-      // Convert transpose ops into extract and insert pairs, in preparation
-      // of further transformations to canonicalize/cancel.
-      {
-        RewritePatternSet patterns(context);
-        auto options =
-            vector::VectorTransformsOptions().setVectorTransposeLowering(
-                vector::VectorTransposeLowering::EltWise);
-        vector::populateVectorTransposeLoweringPatterns(patterns, options);
-        vector::populateVectorShapeCastLoweringPatterns(patterns);
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
-          return signalPassFailure();
-        }
-      }
-
-      // Run canonicalization to cast away leading size-1 dimensions.
-      {
-        RewritePatternSet patterns(context);
-
-        // Pull in casting way leading one dims to allow cancelling some
-        // read/write ops.
-        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
-
-        // Decompose different rank insert_strided_slice and n-D
-        // extract_slided_slice.
-        vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
-            patterns);
-        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-
-        // Trimming leading unit dims may generate broadcast/shape_cast ops.
-        // Clean them up.
-        vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
-        vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
-
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-          return signalPassFailure();
-      }
+    // Unroll vectors in function signatures to native size.
+    if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
+      return signalPassFailure();
 
-      // Run all sorts of canonicalization patterns to clean up again.
-      {
-        RewritePatternSet patterns(context);
-        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-        vector::InsertOp::getCanonicalizationPatterns(patterns, context);
-        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
-        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
-        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-          return signalPassFailure();
-      }
-    }
+    // Unroll vectors in function bodies to native size.
+    if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
+      return signalPassFailure();
 
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
     std::unique_ptr<ConversionTarget> target =
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 8470c7642e716..9f2cab30d56f3 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -20,17 +20,21 @@
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 
 #include <functional>
@@ -1330,6 +1334,68 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
       .Default([](Operation *) { return std::nullopt; });
 }
 
+LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
+  MLIRContext *context = op->getContext();
+  RewritePatternSet patterns(context);
+  populateFuncOpVectorRewritePatterns(patterns);
+  populateReturnOpVectorRewritePatterns(patterns);
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingOps;
+  return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
+}
+
+LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
+  MLIRContext *context = op->getContext();
+
+  // Unroll vectors in function bodies to native vector size.
+  {
+    RewritePatternSet patterns(context);
+    auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+        [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+    populateVectorUnrollPatterns(patterns, options);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return llvm::failure();
+  }
+
+  // Convert transpose ops into extract and insert pairs, in preparation of
+  // further transformations to canonicalize/cancel.
+  {
+    RewritePatternSet patterns(context);
+    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
+        vector::VectorTransposeLowering::EltWise);
+    vector::populateVectorTransposeLoweringPatterns(patterns, options);
+    vector::populateVectorShapeCastLoweringPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return llvm::failure();
+  }
+
+  // Run canonicalization to cast away leading size-1 dimensions.
+  {
+    RewritePatternSet patterns(context);
+
+    // We need to pull in casting away leading one dims.
+    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+    vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+    vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+
+    // Decompose different rank insert_strided_slice and n-D
+    // extract_strided_slice.
+    vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+        patterns);
+    vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+    vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+    // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
+    // them up.
+    vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return llvm::failure();
+  }
+  return llvm::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V TypeConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
index ec67f85f6f27b..4a792336caba4 100644
--- a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
@@ -37,13 +37,8 @@ struct TestSPIRVFuncSignatureConversion final
   }
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateFuncOpVectorRewritePatterns(patterns);
-    populateReturnOpVectorRewritePatterns(patterns);
-    GreedyRewriteConfig config;
-    config.strictMode = GreedyRewriteStrictness::ExistingOps;
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                       config);
+    Operation *op = getOperation();
+    (void)spirv::unrollVectorsInSignatures(op);
   }
 };
 
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
index fbf7933b244fa..0bad43d5214b1 100644
--- a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
@@ -34,78 +34,9 @@ struct TestSPIRVVectorUnrolling final
   }
 
   void runOnOperation() override {
-    MLIRContext *context = &getContext();
     Operation *op = getOperation();
-
-    // Unroll vectors in function signatures to native vector size.
-    {
-      RewritePatternSet patterns(context);
-      populateFuncOpVectorRewritePatterns(patterns);
-      populateReturnOpVectorRewritePatterns(patterns);
-      GreedyRewriteConfig config;
-      config.strictMode = GreedyRewriteStrictness::ExistingOps;
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-        return signalPassFailure();
-    }
-
-    // Unroll vectors to native vector size.
-    {
-      RewritePatternSet patterns(context);
-      auto options = vector::UnrollVectorOptions().setNativeShapeFn(
-          [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
-      populateVectorUnrollPatterns(patterns, options);
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
-
-    // Convert transpose ops into extract and insert pairs, in preparation of
-    // further transformations to canonicalize/cancel.
-    {
-      RewritePatternSet patterns(context);
-      auto options =
-          vector::VectorTransformsOptions().setVectorTransposeLowering(
-              vector::VectorTransposeLowering::EltWise);
-      vector::populateVectorTransposeLoweringPatterns(patterns, options);
-      vector::populateVectorShapeCastLoweringPatterns(patterns);
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
-
-    // Run canonicalization to cast away leading size-1 dimensions.
-    {
-      RewritePatternSet patterns(context);
-
-      // We need to pull in casting way leading one dims.
-      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
-
-      // Decompose different rank insert_strided_slice and n-D
-      // extract_slided_slice.
-      vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
-          patterns);
-      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-
-      // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
-      // them up.
-      vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
-      vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
-
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
-
-    // Run all sorts of canonicalization patterns to clean up again.
-    {
-      RewritePatternSet patterns(context);
-      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
-      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
-      vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-        return signalPassFailure();
-    }
+    (void)spirv::unrollVectorsInSignatures(op);
+    (void)spirv::unrollVectorsInFuncBodies(op);
   }
 };
 



More information about the Mlir-commits mailing list