[Mlir-commits] [mlir] [mlir][spirv] Implement vector unrolling for `convert-to-spirv` pass (PR #100138)

Jakub Kuderski llvmlistbot at llvm.org
Tue Jul 23 09:27:44 PDT 2024


================
@@ -56,6 +58,78 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    if (runVectorUnrolling) {
+
+      // Fold transpose ops if possible as we cannot unroll it later.
+      {
+        RewritePatternSet patterns(context);
+        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Unroll vectors to native vector size.
+      {
+        RewritePatternSet patterns(context);
+        auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+            [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+        populateVectorUnrollPatterns(patterns, options);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+
+      // Convert transpose ops into extract and insert pairs, in preparation
+      // of further transformations to canonicalize/cancel.
+      {
+        RewritePatternSet patterns(context);
+        auto options =
+            vector::VectorTransformsOptions().setVectorTransposeLowering(
+                vector::VectorTransposeLowering::EltWise);
+        vector::populateVectorTransposeLoweringPatterns(patterns, options);
+        vector::populateVectorShapeCastLoweringPatterns(patterns);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Run canonicalization to cast away leading size-1 dimensions.
+      {
+        RewritePatternSet patterns(context);
+
+        // Pull in casting away leading unit dims to allow cancelling some
+        // read/write ops.
+        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+        // Decompose different rank insert_strided_slice and n-D
+        // extract_strided_slice.
+        vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+            patterns);
+        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+        // Trimming leading unit dims may generate broadcast/shape_cast ops.
+        // Clean them up.
+        vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+        vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+
+      // Run all sorts of canonicalization patterns to clean up again.
+      {
+        RewritePatternSet patterns(context);
----------------
kuhar wrote:

Why do we run these separately from the patterns above?

https://github.com/llvm/llvm-project/pull/100138


More information about the Mlir-commits mailing list