[Mlir-commits] [llvm] [mlir] [mlir][spirv] Implement vector type legalization for function signatures (PR #98337)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jul 17 08:36:15 PDT 2024


================
@@ -813,6 +860,250 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signatures to convert vector arguments of
+/// functions to be of valid types.
+///
+/// Every vector argument for which `getTargetShape` returns a shape is replaced
+/// by one argument of the unrolled vector type per tile; at the start of the
+/// entry block the original vector is reassembled from those arguments with
+/// `vector.insert_strided_slice` ops. All other arguments keep their type.
+/// Result types are left unchanged by this pattern.
+struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    FunctionType fnType = funcOp.getFunctionType();
+
+    // TODO: Handle declarations.
+    if (funcOp.isDeclaration()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << fnType << " illegal: declarations are unsupported\n");
+      return failure();
+    }
+
+    // Create a new func op with the original type and move the function body
+    // into it.
+    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+                                                   funcOp.getName(), fnType);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    Location loc = newFuncOp.getBody().getLoc();
+    Location argLoc = newFuncOp.getLoc();
+
+    Block &entryBlock = newFuncOp.getBlocks().front();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&entryBlock);
+
+    // New arguments are appended after the original ones while the original
+    // arguments still exist. Each original argument then has all of its uses
+    // redirected (wherever they are: nested regions and other blocks
+    // included), and the original arguments are erased at the end. This needs
+    // no placeholder ops, so it does not require a zero attribute for every
+    // argument type and does not have to find placeholder uses again later.
+    const unsigned numOrigInputs = fnType.getNumInputs();
+    SmallVector<Type> newInputTypes;
+
+    // Passes an argument through unchanged: appends a new argument of the same
+    // type and redirects every use of the original argument to it.
+    auto keepArgument = [&](BlockArgument origArg, Type type) {
+      BlockArgument newArg = entryBlock.addArgument(type, argLoc);
+      rewriter.replaceAllUsesWith(origArg, newArg);
+      newInputTypes.push_back(type);
+    };
+
+    // Enumerate through the arguments.
+    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
+      BlockArgument origArg = entryBlock.getArgument(origInputNo);
+
+      // Non-vector arguments are kept as-is.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        keepArgument(origArg, origType);
+        continue;
+      }
+      // Vector arguments that do not need unrolling are kept as-is.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        keepArgument(origArg, origType);
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+
+      // Reassemble the original vector from one new argument per tile, using
+      // `vector.insert_strided_slice` into a zero-initialized vector.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origVecType, rewriter.getZeroAttr(origVecType));
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(origVecType.getShape(), *targetShape)) {
+        BlockArgument tileArg = entryBlock.addArgument(unrolledType, argLoc);
+        newInputTypes.push_back(unrolledType);
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, tileArg, result, offsets, strides);
+      }
+      rewriter.replaceAllUsesWith(origArg, result);
+    }
+
+    // Every original argument is now unused; drop them so that only the new
+    // arguments remain, in signature order.
+    entryBlock.eraseArguments(0, numOrigInputs);
+
+    // Change the function signature to match the new entry block arguments.
+    auto newFnType = fnType.clone(newInputTypes, fnType.getResults());
+    rewriter.modifyOpInPlace(newFuncOp,
+                             [&] { newFuncOp.setFunctionType(newFnType); });
+
+    // Erase the original funcOp; its body has been moved into `newFuncOp`.
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the `FuncOpVectorUnroll` pattern, which unrolls vector arguments of
+/// `func.func` signatures, to `patterns`.
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FuncOpVectorUnroll>(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+                                PatternRewriter &rewriter) const override {
+    // Check whether the parent funcOp is valid.
+    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+    if (!funcOp)
+      return failure();
+
+    FunctionType fnType = funcOp.getFunctionType();
+    OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+    Location loc = returnOp.getLoc();
+
+    // For the new return op.
+    SmallVector<Value> newOperands;
+
+    // Enumerate through the results.
+    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
+      // Check whether the argument is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // The original argument can be used.
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+
+      // Create `vector.extract_strided_slice` ops to form legal vectors from
+      // the original operand of illegal type.
+      SmallVector<int64_t> originalShape =
+          llvm::to_vector<4>(origVecType.getShape());
----------------
kuhar wrote:

```suggestion
     auto originalShape = llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
```

Otherwise, you are relying on the build target's default inline-size calculation for `SmallVector` happening to produce 4, so that the declared `SmallVector<int64_t>` matches the `SmallVector<int64_t, 4>` returned by `to_vector<4>`.


https://github.com/llvm/llvm-project/pull/98337


More information about the Mlir-commits mailing list