[Mlir-commits] [mlir] [mlir][spirv] Implement vector type legalization for function signatures (PR #98337)
Ivan Butygin
llvmlistbot at llvm.org
Thu Jul 11 13:56:32 PDT 2024
================
@@ -813,6 +855,243 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature to convert vector arguments of
+/// functions to be of valid types.
+///
+/// Each vector argument for which `getTargetShape` yields a target shape is
+/// replaced by one argument per tile of that shape (tiles enumerated by
+/// `StaticTileOffsetRange`). At the top of the entry block the original vector
+/// value is rebuilt from those arguments with a chain of
+/// `vector.insert_strided_slice` ops applied to an all-zero vector, and every
+/// former use of the original argument is redirected to the rebuilt value.
+/// All other arguments are passed through unchanged; results are not touched.
+///
+/// The pattern fails (leaving the IR untouched) for declarations and for
+/// functions where no argument needs unrolling.
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
+public:
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    FunctionType fnType = funcOp.getFunctionType();
+
+    // A declaration has no entry block whose arguments could be rewritten.
+    // TODO: Handle declarations.
+    if (funcOp.isDeclaration())
+      return rewriter.notifyMatchFailure(funcOp,
+                                         "declarations are unsupported");
+
+    // Only rewrite when at least one argument actually needs unrolling.
+    // Otherwise the pattern would rebuild an identical function and still
+    // report success, which keeps a greedy driver from converging.
+    if (llvm::none_of(fnType.getInputs(), [](Type type) {
+          auto vecType = dyn_cast<VectorType>(type);
+          return vecType && getTargetShape(vecType);
+        }))
+      return rewriter.notifyMatchFailure(funcOp,
+                                         "no argument requires unrolling");
+
+    // Create a new func op with the original type and move the function body
+    // into it; its signature is updated once the new arguments are known.
+    // NOTE(review): only the symbol name and type are carried over, so other
+    // attributes of `funcOp` (e.g. visibility, argument/result attributes) are
+    // not preserved.
+    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+                                                   funcOp.getName(), fnType);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    Location loc = newFuncOp.getBody().getLoc();
+    Location argLoc = newFuncOp.getLoc();
+    Block &entryBlock = newFuncOp.getBlocks().front();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&entryBlock);
+
+    // New arguments are appended after the original ones, so the original
+    // arguments keep their indices until they are erased below. Each original
+    // argument is redirected with RAUW, which also covers uses in non-entry
+    // blocks and nested regions. No placeholder ops are needed, so argument
+    // types that `getZeroAttr` cannot build a constant for (e.g. memrefs)
+    // are handled too.
+    unsigned numOrigInputs = fnType.getNumInputs();
+    SmallVector<Type> newInputTypes;
+    for (unsigned argNo = 0; argNo < numOrigInputs; ++argNo) {
+      Type origType = fnType.getInput(argNo);
+      BlockArgument origArg = entryBlock.getArgument(argNo);
+
+      // Vector arguments with a target shape are split into tiles.
+      if (auto origVecType = dyn_cast<VectorType>(origType)) {
+        if (auto targetShape = getTargetShape(origVecType)) {
+          VectorType unrolledType =
+              VectorType::get(*targetShape, origVecType.getElementType());
+
+          // Rebuild the original vector by inserting every new argument as
+          // one tile into an all-zero vector.
+          Value result = rewriter.create<arith::ConstantOp>(
+              loc, origVecType, rewriter.getZeroAttr(origVecType));
+          SmallVector<int64_t> strides(targetShape->size(), 1);
+          for (SmallVector<int64_t> offsets :
+               StaticTileOffsetRange(origVecType.getShape(), *targetShape)) {
+            BlockArgument tileArg =
+                entryBlock.addArgument(unrolledType, argLoc);
+            newInputTypes.push_back(unrolledType);
+            result = rewriter.create<vector::InsertStridedSliceOp>(
+                loc, tileArg, result, offsets, strides);
+          }
+          rewriter.replaceAllUsesWith(origArg, result);
+          continue;
+        }
+      }
+
+      // Any other argument is forwarded unchanged to a fresh argument of the
+      // same type.
+      BlockArgument newArg = entryBlock.addArgument(origType, argLoc);
+      newInputTypes.push_back(origType);
+      rewriter.replaceAllUsesWith(origArg, newArg);
+    }
+
+    // The original arguments no longer have uses. Drop them so that only the
+    // new arguments remain, then make the signature match the entry block.
+    entryBlock.eraseArguments(0, numOrigInputs);
+    auto newFnType =
+        FunctionType::get(rewriter.getContext(), TypeRange(newInputTypes),
+                          TypeRange(fnType.getResults()));
+    rewriter.modifyOpInPlace(newFuncOp,
+                             [&] { newFuncOp.setFunctionType(newFnType); });
+
+    // The body has already been moved out of `funcOp`, so it can be erased.
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the pattern that unrolls vector arguments of `func.func` ops
+/// into argument types the target shape allows.
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FuncOpVectorUnroll>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
+public:
+ using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const override {
+ // Check whether the parent funcOp is valid.
+ auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
----------------
Hardcode84 wrote:
Ok. It requires you to call `setFunctionType` twice, but this is a minor thing.
https://github.com/llvm/llvm-project/pull/98337
More information about the Mlir-commits
mailing list