[Mlir-commits] [mlir] [mlir][spirv][cf] Check destination block argument types (PR #70889)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Oct 31 20:13:49 PDT 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/70889
>From f54d81de121f9623ef3b088d619fd209e16b7022 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 31 Oct 2023 23:06:53 -0400
Subject: [PATCH 1/2] [mlir][spirv][cf] Check destination block argument types
Emit errors when a destination block has arguments whose types are not legal.
Region/block argument type conversions should be added to properly address
this.
Issue: https://github.com/llvm/llvm-project/issues/70813
---
.../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp | 39 +++++++++++++++----
.../ControlFlowToSPIRVPass.cpp | 4 +-
.../ControlFlowToSPIRV/cf-ops-to-spirv.mlir | 19 ++++++++-
3 files changed, 52 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index 6787a5ccd3a4a87..6ba4388cb3b5825 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -19,25 +19,41 @@
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "cf-to-spirv-pattern"
using namespace mlir;
+/// Emits an error on `branchOp` and fails if any `block` arg type is illegal.
+static LogicalResult checkBlockArguments(Block &block, Operation *branchOp,
+ const TypeConverter &converter) {
+ for (BlockArgument arg : block.getArguments()) {
+ if (!converter.isLegal(arg.getType())) {
+ return branchOp->emitOpError(
+ "failed to match, destination argument not legalized (found ")
+ << arg << ")";
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
-
/// Converts cf.br to spirv.Branch.
-struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
- using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (failed(checkBlockArguments(*op.getDest(), op, *getTypeConverter())))
+ return failure();
+
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
@@ -45,16 +61,23 @@ struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
};
/// Converts cf.cond_br to spirv.BranchConditional.
-struct CondBranchOpPattern final
- : public OpConversionPattern<cf::CondBranchOp> {
- using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (failed(checkBlockArguments(*op.getTrueDest(), op, *getTypeConverter())))
+ return failure();
+ // Same for the false destination: both blocks are reused by the new op.
+ if (failed(
+ checkBlockArguments(*op.getFalseDest(), op, *getTypeConverter())))
+ return failure();
+ // Take the condition from the adaptor (remapped operand), not from `op`.
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
- op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
- op.getFalseDest(), adaptor.getFalseDestOperands());
+ op, adaptor.getCondition(), op.getTrueDest(),
+ adaptor.getTrueDestOperands(), op.getFalseDest(),
+ adaptor.getFalseDestOperands());
return success();
}
};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index d8aecae257b461e..a752b82eac7c343 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -25,7 +25,7 @@ using namespace mlir;
namespace {
/// A pass converting MLIR ControlFlow operations into the SPIR-V dialect.
-class ConvertControlFlowToSPIRVPass
+class ConvertControlFlowToSPIRVPass final
: public impl::ConvertControlFlowToSPIRVBase<
ConvertControlFlowToSPIRVPass> {
void runOnOperation() override;
@@ -44,6 +44,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
+ // TODO: We should also take care of block argument type conversion.
+
RewritePatternSet patterns(context);
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 98ee6891453496d..0e078dc43208c7b 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-cf-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-cf-to-spirv --verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// cf.br, cf.cond_br
@@ -36,3 +36,20 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
}
}
+
+// -----
+
+// TODO: We should handle blocks whose arguments require type conversion.
+
+func.func @main_graph(%arg0: index) {
+ %c3 = arith.constant 1 : index
+ cf.br ^bb1(%arg0 : index)
+// expected-error@+1 {{'cf.br' op failed to match, destination argument not legalized}}
+^bb1(%0: index): // 2 preds: ^bb0, ^bb2
+ %1 = arith.cmpi slt, %0, %c3 : index
+ cf.cond_br %1, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+ cf.br ^bb1(%c3 : index)
+^bb3: // pred: ^bb1
+ return
+}
>From 54cab9910deafe85f2468928b54b77f923e4c0dc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 31 Oct 2023 23:13:37 -0400
Subject: [PATCH 2/2] Fix expected-error locations in cf-to-spirv test
---
mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 0e078dc43208c7b..7303ea6e9ca6242 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -43,12 +43,13 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
func.func @main_graph(%arg0: index) {
%c3 = arith.constant 1 : index
- cf.br ^bb1(%arg0 : index)
// expected-error@+1 {{'cf.br' op failed to match, destination argument not legalized}}
+ cf.br ^bb1(%arg0 : index)
^bb1(%0: index): // 2 preds: ^bb0, ^bb2
%1 = arith.cmpi slt, %0, %c3 : index
cf.cond_br %1, ^bb2, ^bb3
^bb2: // pred: ^bb1
+// expected-error@+1 {{'cf.br' op failed to match, destination argument not legalized}}
cf.br ^bb1(%c3 : index)
^bb3: // pred: ^bb1
return
More information about the Mlir-commits
mailing list