[Mlir-commits] [mlir] [mlir][spirv] Implement vector unrolling for `convert-to-spirv` pass (PR #100138)
Angel Zhang
llvmlistbot at llvm.org
Tue Jul 23 12:49:03 PDT 2024
================
@@ -56,6 +58,78 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ if (runVectorUnrolling) {
+
+ // Fold transpose ops if possible, as we cannot unroll them later.
+ {
+ RewritePatternSet patterns(context);
+ vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Unroll vectors to native vector size.
+ {
+ RewritePatternSet patterns(context);
+ auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+ [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+ populateVectorUnrollPatterns(patterns, options);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Convert transpose ops into extract and insert pairs, in preparation
+ // of further transformations to canonicalize/cancel.
+ {
+ RewritePatternSet patterns(context);
+ auto options =
+ vector::VectorTransformsOptions().setVectorTransposeLowering(
+ vector::VectorTransposeLowering::EltWise);
+ vector::populateVectorTransposeLoweringPatterns(patterns, options);
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Run canonicalization to cast away leading size-1 dimensions.
+ {
+ RewritePatternSet patterns(context);
+
+ // Pull in casting away leading one dims to allow cancelling some
+ // read/write ops.
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+ // Decompose different rank insert_strided_slice and n-D
+ // extract_strided_slice.
+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+ patterns);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+ // Trimming leading unit dims may generate broadcast/shape_cast ops.
+ // Clean them up.
+ vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+ vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Run all sorts of canonicalization patterns to clean up again.
+ {
+ RewritePatternSet patterns(context);
----------------
angelz913 wrote:
These patterns were found to be redundant and have been removed.
https://github.com/llvm/llvm-project/pull/100138
More information about the Mlir-commits
mailing list