[Mlir-commits] [mlir] [mlir][spirv] Implement vector type legalization in function signatures (PR #98337)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jul 10 09:13:01 PDT 2024


================
@@ -813,6 +853,281 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A rewrite pattern for func::FuncOp that legalizes vector arguments in the
+/// function signature. Each vector argument for which `getTargetShape` yields
+/// a target shape is unrolled into several arguments of that smaller vector
+/// type; all other arguments keep their original type.
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
+public:
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+
+  /// Builds a replacement function whose signature uses the legalized
+  /// argument types and moves the body of `funcOp` into it. It then
+  /// reassembles each unrolled vector argument with
+  /// `vector.insert_strided_slice` at the start of the entry block.
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
+                                    PatternRewriter &rewriter) const {
+  auto fnType = funcOp.getFunctionType();
+
+  // Create a new func op with the original type and copy the function body.
+  auto newFuncOp =
+      rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+
+  llvm::errs() << "After creating new func op and copying the function body\n";
+  newFuncOp.dump();
+
+  Location loc = newFuncOp.getBody().getLoc();
+  Block &entryBlock = newFuncOp.getBlocks().front();
+  rewriter.setInsertionPointToStart(&entryBlock);
+
+  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+  // For arguments that are of illegal types and require unrolling.
+  // `unrolledInputNums` stores the indices of arguments that result from
+  // unrolling in the new function signature. `newInputNo` is a counter.
+  SmallVector<size_t> unrolledInputNums;
+  size_t newInputNo = 0;
+
+  // For arguments that are of legal types and do not require unrolling.
+  // `tmpOps` stores a mapping from temporary operations that serve as
+  // placeholders for new arguments that will be added later. These operations
+  // will be erased once the entry block's argument list is updated.
+  DenseMap<Operation *, size_t> tmpOps;
+
+  // This counts the number of new operations created.
+  size_t newOpCount = 0;
+
+  // Enumerate through the arguments.
+  for (const auto &argType : enumerate(fnType.getInputs())) {
+    size_t origInputNo = argType.index();
+    Type origType = argType.value();
+    // Check whether the argument is of vector type.
+    auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    if (!origVecType) {
+      // We need a placeholder for the old argument that will be erased later.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origType, rewriter.getZeroAttr(origType));
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      tmpOps.insert({result.getDefiningOp(), newInputNo});
+      oneToNTypeMapping.addInputs(origInputNo, origType);
+      newInputNo++;
+      newOpCount++;
+      continue;
+    }
+    // Check whether the vector needs unrolling.
+    auto targetShape = getTargetShape(origVecType);
+    if (!targetShape) {
+      // We need a placeholder for the old argument that will be erased later.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origType, rewriter.getZeroAttr(origType));
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      tmpOps.insert({result.getDefiningOp(), newInputNo});
+      oneToNTypeMapping.addInputs(origInputNo, origType);
+      newInputNo++;
+      newOpCount++;
+      continue;
+    }
+    llvm::errs() << "Got target shape\n";
+    VectorType unrolledType =
+        VectorType::get(*targetShape, origVecType.getElementType());
+    llvm::errs() << "Unrolled type is ";
+    unrolledType.dump();
+    SmallVector<int64_t> originalShape =
+        llvm::to_vector<4>(origVecType.getShape());
+
+    // Prepare the result vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, origVecType, rewriter.getZeroAttr(origVecType));
+    newOpCount++;
+    // Prepare the placeholder for the new arguments that will be added later.
+    Value dummy = rewriter.create<arith::ConstantOp>(
+        loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+    newOpCount++;
+
+    // Create the `vector.insert_strided_slice` ops.
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<Type> newTypes;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
+                                                             offsets, strides);
+      newTypes.push_back(unrolledType);
+      unrolledInputNums.push_back(newInputNo);
+      newInputNo++;
+      newOpCount++;
+    }
+    rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+    oneToNTypeMapping.addInputs(origInputNo, newTypes);
+  }
+
+  llvm::errs() << "After enumerating through the arguments\n";
+  newFuncOp.dump();
+
+  // Change the function signature.
+  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+  auto newFnType =
+      FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
+                        TypeRange(fnType.getResults()));
+  rewriter.modifyOpInPlace(newFuncOp,
+                           [&] { newFuncOp.setFunctionType(newFnType); });
+
+  llvm::errs() << "After changing function signature\n";
+  newFuncOp.dump();
+
+  // Update the arguments in the entry block.
+  entryBlock.eraseArguments(0, fnType.getNumInputs());
+  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+  entryBlock.addArguments(convertedTypes, locs);
+
+  llvm::errs() << "After updating the arguments in the entry block\n";
+  newFuncOp.dump();
+
+  // Replace the placeholder values with the new arguments. We assume there is
+  // only one block for now.
+  size_t idx = 0;
+  for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
----------------
kuhar wrote:

Similar here

https://github.com/llvm/llvm-project/pull/98337


More information about the Mlir-commits mailing list