[Mlir-commits] [mlir] [mlir][spirv] Implement vector unrolling for `convert-to-spirv` pass (PR #100138)
Ivan Butygin
llvmlistbot at llvm.org
Wed Jul 24 05:55:57 PDT 2024
================
@@ -1285,6 +1287,115 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
builder);
}
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+/// Picks the chunk width used when unrolling a vector dimension of extent
+/// `size`: the first of 4, 3 and 2 (tried in that order) that divides `size`
+/// exactly, or 1 when none of them does.
+/// NOTE(review): the candidates presumably mirror the vector widths SPIR-V
+/// supports without extra capabilities -- confirm against the SPIR-V spec.
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+  constexpr int kCandidateWidths[] = {4, 3, 2};
+  for (const int width : kCandidateWidths)
+    if (size % width == 0)
+      return width;
+  // No candidate divides the extent evenly; fall back to scalar chunks.
+  return 1;
+}
+
+/// Native unroll shape for vector.reduction: a single dimension whose size is
+/// the chunk width chosen by getComputeVectorSize for the source extent.
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType sourceType = op.getSourceVectorType();
+  // The reduction source is always 1-D (guaranteed by the op's semantics).
+  assert(sourceType.getRank() == 1);
+  const int64_t chunk =
+      mlir::spirv::getComputeVectorSize(sourceType.getDimSize(0));
+  return {chunk};
+}
+
+/// Native unroll shape for vector.transpose: 1 in every dimension of the
+/// result except the innermost, which is chunked by getComputeVectorSize.
+/// A 0-D result yields an empty shape.
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  // A 0-D result has no innermost dimension: calling back() on the empty
+  // shape would trip SmallVector/ArrayRef's non-empty assertion (UB in
+  // release builds), so return the empty shape unchanged.
+  // NOTE(review): relies on vector.transpose accepting 0-D vectors; confirm.
+  if (nativeSize.empty())
+    return nativeSize;
+  nativeSize.back() =
+      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+/// Returns the native vector shape to unroll `op` to, or std::nullopt when
+/// `op` is not one of the handled cases (single-result elementwise-mappable
+/// ops producing a vector of rank >= 1, vector.reduction, vector.transpose).
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
+      // 0-D vectors (e.g. vector<f32>) have no innermost dimension to chunk,
+      // so there is nothing to unroll; bail out instead of calling back() on
+      // an empty shape.
+      if (vecType.getRank() == 0)
+        return std::nullopt;
+      // Keep every outer dimension at 1 and chunk only the innermost one.
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() =
+          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::ReductionOp, vector::TransposeOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
+LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
+ MLIRContext *context = op->getContext();
+ RewritePatternSet patterns(context);
+ populateFuncOpVectorRewritePatterns(patterns);
+ populateReturnOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
----------------
Hardcode84 wrote:
Please add a comment explaining why we need `config.strictMode`.
https://github.com/llvm/llvm-project/pull/100138
More information about the Mlir-commits
mailing list