[Mlir-commits] [mlir] [mlir][spirv] Initial version of vector unrolling for `convert-to-spirv` pass (PR #100138)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 23 07:59:24 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv
Author: Angel Zhang (angelz913)
<details>
<summary>Changes</summary>
### Description
This PR depends on #98337. It implements a minimal version of function-body vector unrolling that converts vector types into 1-D vectors whose size is supported by SPIR-V (2, 3, or 4, depending on the original dimension). The ops currently supported include those with elementwise traits (e.g. `arith.addi`), `vector.reduction`, and `vector.transpose`. This PR also includes new LIT tests that check only the vector unrolling.
### Future Plans
- Support more ops
---
Patch is 32.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100138.diff
16 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+4-1)
- (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (+11)
- (modified) mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp (+74)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+57-12)
- (modified) mlir/test/Conversion/ConvertToSPIRV/arith.mlir (+1-1)
- (modified) mlir/test/Conversion/ConvertToSPIRV/combined.mlir (+1-1)
- (modified) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+44)
- (modified) mlir/test/Conversion/ConvertToSPIRV/index.mlir (+1-1)
- (modified) mlir/test/Conversion/ConvertToSPIRV/scf.mlir (+1-1)
- (modified) mlir/test/Conversion/ConvertToSPIRV/simple.mlir (+1-1)
- (modified) mlir/test/Conversion/ConvertToSPIRV/ub.mlir (+1-1)
- (added) mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir (+102)
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1)
- (modified) mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt (+2)
- (added) mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp (+119)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 748646e605827..b5bb2f42f2961 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -47,7 +47,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let options = [
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
- "Run function signature conversion to convert vector types">
+ "Run function signature conversion to convert vector types">,
+ Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
+ /*default=*/"true",
+ "Run vector unrolling to convert vector types in function bodies">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9ad3d5fc85dd3..195fbd0d0cd58 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -189,6 +189,17 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc, OpBuilder &builder);
+int getComputeVectorSize(int64_t size);
+
+// GetNativeVectorShape implementation for reduction ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
+
+// GetNativeVectorShape implementation for transpose ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
+
+// GetNativeVectorShape implementation for general ops.
+std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 003a5feea9e9b..b82a244cfc973 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -17,6 +17,8 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -56,6 +58,78 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ if (runVectorUnrolling) {
+
+ // Fold transpose ops if possible, as we cannot unroll them later.
+ {
+ RewritePatternSet patterns(context);
+ vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Unroll vectors to native vector size.
+ {
+ RewritePatternSet patterns(context);
+ auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+ [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+ populateVectorUnrollPatterns(patterns, options);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Convert transpose ops into extract and insert pairs, in preparation
+ // for further transformations to canonicalize/cancel.
+ {
+ RewritePatternSet patterns(context);
+ auto options =
+ vector::VectorTransformsOptions().setVectorTransposeLowering(
+ vector::VectorTransposeLowering::EltWise);
+ vector::populateVectorTransposeLoweringPatterns(patterns, options);
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Run canonicalization to cast away leading size-1 dimensions.
+ {
+ RewritePatternSet patterns(context);
+
+ // Pull in casting away leading one dims to allow cancelling some
+ // read/write ops.
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+ // Decompose different rank insert_strided_slice and n-D
+ // extract_strided_slice.
+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+ patterns);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+ // Trimming leading unit dims may generate broadcast/shape_cast ops.
+ // Clean them up.
+ vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+ vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Run all sorts of canonicalization patterns to clean up again.
+ {
+ RewritePatternSet patterns(context);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+ }
+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..8470c7642e716 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -46,14 +46,6 @@ namespace {
// Utility functions
//===----------------------------------------------------------------------===//
-static int getComputeVectorSize(int64_t size) {
- for (int i : {4, 3, 2}) {
- if (size % i == 0)
- return i;
- }
- return 1;
-}
-
static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
if (vecType.isScalable()) {
@@ -62,8 +54,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
return std::nullopt;
}
SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
- std::optional<SmallVector<int64_t>> targetShape =
- SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+ std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
+ 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
if (!targetShape) {
LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
return std::nullopt;
@@ -1098,13 +1090,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
// the original operand of illegal type.
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
- SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(originalShape.size(), 1);
+ SmallVector<int64_t> extractShape(originalShape.size(), 1);
+ extractShape.back() = targetShape->back();
SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value result = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, returnValue, offsets, *targetShape, strides);
+ loc, returnValue, offsets, extractShape, strides);
+ SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+ if (originalShape.size() > 1)
+ result =
+ rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
@@ -1285,6 +1283,53 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
builder);
}
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+ for (int i : {4, 3, 2}) {
+ if (size % i == 0)
+ return i;
+ }
+ return 1;
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+ VectorType srcVectorType = op.getSourceVectorType();
+ assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+ int64_t vectorSize =
+ mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+ return {vectorSize};
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+ VectorType vectorType = op.getResultVectorType();
+ SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+ nativeSize.back() =
+ mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+ return nativeSize;
+}
+
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+ if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+ if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
+ SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+ nativeSize.back() =
+ mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+ return nativeSize;
+ }
+ }
+
+ return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+ .Case<vector::ReductionOp, vector::TransposeOp>(
+ [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+ .Default([](Operation *) { return std::nullopt; });
+}
+
//===----------------------------------------------------------------------===//
// SPIR-V TypeConverter
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index 1a844a7cd018b..6418e931f7460 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
//===----------------------------------------------------------------------===//
// arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 02b938be775a3..311174bef15ed 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @combined
// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
// -----
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @vector_6and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
// -----
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index e1cb18aac5d01..f4b116849fa93 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @basic
func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index 58ec6ac61f6ac..246464928b81c 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @if_yield
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index c5e0e6603d94a..00556140c3018 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @return_scalar
// CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index a83bfb6f405a0..f34ca01c94f00 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @ub
// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
new file mode 100644
index 0000000000000..54d9875002cb5
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-spirv-vector-unrolling -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @vaddi
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>)
+func.func @vaddi(%arg0 : vector<6xi32>, %arg1 : vector<6xi32>) -> (vector<6xi32>) {
+ // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<3xi32>
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<3xi32>
+ // CHECK: return %[[ADD0]], %[[ADD1]] : vector<3xi32>, vector<3xi32>
+ %0 = arith.addi %arg0, %arg1 : vector<6xi32>
+ return %0 : vector<6xi32>
+}
+
+// CHECK-LABEL: @vaddi_2d
+// CHECK-SAME: (%[[ARG0:...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/100138
More information about the Mlir-commits mailing list