[Mlir-commits] [mlir] [mlir][spirv] Implement vector type legalization in function signatures (PR #98337)

Angel Zhang llvmlistbot at llvm.org
Thu Jul 11 06:44:42 PDT 2024

@@ -813,6 +855,243 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
+// func::FuncOp Conversion Patterns
+namespace {
+/// A pattern that rewrites a function's signature so that its vector
+/// arguments are of valid types.
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    FunctionType fnType = funcOp.getFunctionType();
+    // Create a new func op with the original type and copy the function body.
+    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+                                                   funcOp.getName(), fnType);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    Location loc = newFuncOp.getBody().getLoc();
+    Block &entryBlock = newFuncOp.getBlocks().front();
+    rewriter.setInsertionPointToStart(&entryBlock);
+    OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+    // For arguments that are of illegal types and require unrolling.
+    // `unrolledInputNums` stores the indices of arguments that result from
+    // unrolling in the new function signature. `newInputNo` is a counter.
+    SmallVector<size_t> unrolledInputNums;
+    size_t newInputNo = 0;
+    // For arguments that are of legal types and do not require unrolling.
+    // `tmpOps` maps temporary operations, which serve as placeholders for new
+    // arguments that will be added later, to the indices of those arguments in
+    // the new function signature. These operations will be erased once the
+    // entry block's argument list is updated.
+    DenseMap<Operation *, size_t> tmpOps;
+    // This counts the number of new operations created.
+    size_t newOpCount = 0;
+    // Enumerate through the arguments.
+    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
+      // Check whether the argument is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        // We need a placeholder for the old argument that will be erased later.
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        newInputNo++;
+        newOpCount++;
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // We need a placeholder for the old argument that will be erased later.
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        newInputNo++;
+        newOpCount++;
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
angelz913 wrote:

I think this will cause a compile error.


More information about the Mlir-commits mailing list