[Mlir-commits] [mlir] [mlir][spirv] Implement vector type legalization in function signatures (PR #98337)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 10 09:13:01 PDT 2024
================
@@ -813,6 +853,281 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites a `func::FuncOp` signature so that vector arguments of illegal
+/// types are unrolled into arguments of valid types (defined out of line).
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
+public:
+ using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::FuncOp funcOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
+ PatternRewriter &rewriter) const {
+ auto fnType = funcOp.getFunctionType();
+
+ // Create a new func op with the original type and copy the function body.
+ auto newFuncOp =
+ rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+
+ llvm::errs() << "After creating new func op and copying the function body\n";
+ newFuncOp.dump();
+
+ Location loc = newFuncOp.getBody().getLoc();
+ Block &entryBlock = newFuncOp.getBlocks().front();
+ rewriter.setInsertionPointToStart(&entryBlock);
+
+ OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+ // For arguments that are of illegal types and require unrolling.
+ // `unrolledInputNums` stores the indices of arguments that result from
+ // unrolling in the new function signature. `newInputNo` is a counter.
+ SmallVector<size_t> unrolledInputNums;
+ size_t newInputNo = 0;
+
+ // For arguments that are of legal types and do not require unrolling.
+ // `tmpOps` stores a mapping from temporary operations that serve as
+ // placeholders for new arguments that will be added later. These operations
+ // will be erased once the entry block's argument list is updated.
+ DenseMap<Operation *, size_t> tmpOps;
+
+ // This counts the number of new operations created.
+ size_t newOpCount = 0;
+
+ // Enumerate through the arguments.
+ for (const auto &argType : enumerate(fnType.getInputs())) {
+ size_t origInputNo = argType.index();
+ Type origType = argType.value();
+ // Check whether the argument is of vector type.
+ auto origVecType = llvm::dyn_cast<VectorType>(origType);
----------------
kuhar wrote:
nit:
```suggestion
auto origVecType = dyn_cast<VectorType>(origType);
```
https://github.com/llvm/llvm-project/pull/98337
More information about the Mlir-commits mailing list