[Mlir-commits] [mlir] [mlir][spirv][cf] legalize block arguments when converting cf to spirv (PR #71288)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Nov 4 09:17:26 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Xiang Li (python3kgae)
<details>
<summary>Changes</summary>
When converting branches, legalize target block arguments first.
Fixes https://github.com/llvm/llvm-project/issues/70813
---
Full diff: https://github.com/llvm/llvm-project/pull/71288.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (+19-23)
- (modified) mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir (+5-1)
``````````diff
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index 5dba79016120b38..c24a5b2e5b05d50 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -28,20 +28,23 @@
using namespace mlir;
-/// Checks that the target block arguments are legal.
-static LogicalResult checkBlockArguments(Block &block, Operation *op,
- PatternRewriter &rewriter,
- const TypeConverter &converter) {
- for (BlockArgument arg : block.getArguments()) {
- if (!converter.isLegal(arg.getType())) {
- return rewriter.notifyMatchFailure(
- op,
- llvm::formatv(
- "failed to match, destination argument not legalized (found {0})",
- arg));
- }
+/// Legalize target block arguments.
+static void legalizeBlockArguments(Block &block,
+ const TypeConverter &converter) {
+ auto builder = OpBuilder::atBlockBegin(&block);
+ for (unsigned i = 0; i < block.getNumArguments(); ++i) {
+ const auto arg = block.getArgument(i);
+ if (converter.isLegal(arg.getType()))
+ continue;
+ unsigned argNum = arg.getArgNumber();
+ Location loc = arg.getLoc();
+ Type ty = arg.getType();
+ Type newTy = converter.convertType(ty);
+ Value newArg = block.insertArgument(argNum, newTy, loc);
+ auto cast = builder.create<UnrealizedConversionCastOp>(loc, ty, newArg);
+ arg.replaceAllUsesWith(cast.getResult(0));
+ block.eraseArgument(argNum + 1);
}
- return success();
}
//===----------------------------------------------------------------------===//
@@ -56,9 +59,7 @@ struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
LogicalResult
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (failed(checkBlockArguments(*op.getDest(), op, rewriter,
- *getTypeConverter())))
- return failure();
+ legalizeBlockArguments(*op.getDest(), *getTypeConverter());
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
@@ -73,13 +74,8 @@ struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
LogicalResult
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (failed(checkBlockArguments(*op.getTrueDest(), op, rewriter,
- *getTypeConverter())))
- return failure();
-
- if (failed(checkBlockArguments(*op.getFalseDest(), op, rewriter,
- *getTypeConverter())))
- return failure();
+ legalizeBlockArguments(*op.getTrueDest(), *getTypeConverter());
+ legalizeBlockArguments(*op.getFalseDest(), *getTypeConverter());
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
op, adaptor.getCondition(), op.getTrueDest(),
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index a10a1468e582499..4e2f6a12c8e7ee4 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -39,16 +39,20 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
// -----
-// TODO: We should handle blocks whose arguments require type conversion.
+// Handle blocks whose arguments require type conversion.
// CHECK-LABEL: func.func @main_graph
func.func @main_graph(%arg0: index) {
%c3 = arith.constant 1 : index
+// CHECK: spirv.Branch ^bb1({{.*}} : i32)
cf.br ^bb1(%arg0 : index)
+// CHECK: ^bb1({{.*}}: i32): // 2 preds: ^bb0, ^bb2
^bb1(%0: index): // 2 preds: ^bb0, ^bb2
%1 = arith.cmpi slt, %0, %c3 : index
+// CHECK: spirv.BranchConditional {{.*}}, ^bb2, ^bb3
cf.cond_br %1, ^bb2, ^bb3
^bb2: // pred: ^bb1
+// CHECK: spirv.Branch ^bb1({{.*}} : i32)
cf.br ^bb1(%c3 : index)
^bb3: // pred: ^bb1
return
``````````
</details>
https://github.com/llvm/llvm-project/pull/71288
More information about the Mlir-commits mailing list