[Mlir-commits] [mlir] [mlir][spirv] Implement vector unrolling for `convert-to-spirv` pass (PR #100138)

Ivan Butygin llvmlistbot at llvm.org
Wed Jul 24 05:55:57 PDT 2024


================
@@ -1285,6 +1287,115 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                              builder);
 }
 
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+// Returns the largest of 4, 3 and 2 that evenly divides `size`, or 1 when
+// none of those candidates divides it.
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+  // Walk the candidate widths from widest to narrowest; the first divisor
+  // found is therefore the largest one.
+  for (int width = 4; width >= 2; --width)
+    if (size % width == 0)
+      return width;
+  return 1;
+}
+
+// Native unroll shape for vector.reduction: the single source dimension is
+// split into getComputeVectorSize(dim) sized chunks.
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  // NOTE(review): the vector.reduction verifier admits 0-D as well as 1-D
+  // sources; asserting rank == 1 (and then reading dim 0) would crash on a
+  // 0-D source.
+  assert(srcVectorType.getRank() <= 1 && "expected a 0-D or 1-D source");
+  // A 0-D source has no dimension to split: report an empty (rank-0) shape.
+  if (srcVectorType.getRank() == 0)
+    return {};
+  int64_t vectorSize =
+      mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+  return {vectorSize};
+}
+
+// Native unroll shape for vector.transpose: every dimension of the result is
+// 1 except the innermost, which is getComputeVectorSize(innermost dim).
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  // Guard against a 0-D result: its shape is empty, so there is no innermost
+  // dimension and calling back() would be invalid. Return the empty shape.
+  if (nativeSize.empty())
+    return nativeSize;
+  nativeSize.back() =
+      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+// Returns the native vector shape used to unroll `op`:
+//  - single-result elementwise-mappable ops with a vector result: all
+//    dimensions 1 except the innermost, which is getComputeVectorSize of it;
+//  - vector.reduction / vector.transpose: their getNativeVectorShapeImpl;
+//  - anything else (including 0-D elementwise results): std::nullopt.
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
+      // A 0-D vector has no innermost dimension to split, and back() on its
+      // empty shape would be invalid, so report no native shape for it.
+      if (vecType.getRank() == 0)
+        return std::nullopt;
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() =
+          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::ReductionOp, vector::TransposeOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
+// Applies the func-op and return-op vector rewrite patterns (from
+// populateFuncOpVectorRewritePatterns / populateReturnOpVectorRewritePatterns)
+// to `op` with the greedy driver restricted to
+// GreedyRewriteStrictness::ExistingOps. Returns the driver's result.
+LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
+  GreedyRewriteConfig rewriteConfig;
+  rewriteConfig.strictMode = GreedyRewriteStrictness::ExistingOps;
+
+  RewritePatternSet signaturePatterns(op->getContext());
+  populateFuncOpVectorRewritePatterns(signaturePatterns);
+  populateReturnOpVectorRewritePatterns(signaturePatterns);
+
+  return applyPatternsAndFoldGreedily(op, std::move(signaturePatterns),
+                                      rewriteConfig);
+}
+
+LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
+  MLIRContext *context = op->getContext();
+
+  // Unroll vectors in function bodies to native vector size.
+  {
+    RewritePatternSet patterns(context);
+    auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+        [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+    populateVectorUnrollPatterns(patterns, options);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return failure();
+  }
+
+  // Convert transpose ops into extract and insert pairs, in preparation of
+  // further transformations to canonicalize/cancel.
+  {
+    RewritePatternSet patterns(context);
+    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
+        vector::VectorTransposeLowering::EltWise);
+    vector::populateVectorTransposeLoweringPatterns(patterns, options);
+    vector::populateVectorShapeCastLoweringPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return llvm::failure();
+  }
+
+  // Run canonicalization to cast away leading size-1 dimensions.
+  {
+    RewritePatternSet patterns(context);
+
+    // We need to pull in casting way leading one dims.
+    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+    vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+    vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+
+    // Decompose different rank insert_strided_slice and n-D
+    // extract_slided_slice.
+    vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+        patterns);
+    vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+    vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+    // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
+    // them up.
+    vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return llvm::failure();
----------------
Hardcode84 wrote:

drop `llvm::`

https://github.com/llvm/llvm-project/pull/100138


More information about the Mlir-commits mailing list