[Mlir-commits] [mlir] [mlir][spirv][cf] Check destination block argument types (PR #70889)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Oct 31 21:49:45 PDT 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/70889
>From f4b1f53f82b553dbac57cd43fd68aec5a0a41ea1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 31 Oct 2023 23:06:53 -0400
Subject: [PATCH] [mlir][spirv][cf] Check destination block argument types
Do not match branches whose destination blocks have arguments of illegal
types. We should add region/block argument type conversions to properly
address this.
Issue: https://github.com/llvm/llvm-project/issues/70813
---
.../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp | 46 +++++++++++++++----
.../ControlFlowToSPIRVPass.cpp | 4 +-
.../ControlFlowToSPIRV/cf-ops-to-spirv.mlir | 19 +++++++-
3 files changed, 59 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index 6787a5ccd3a4a87..5dba79016120b38 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -18,26 +18,48 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "cf-to-spirv-pattern"
using namespace mlir;
+/// Reports a match failure on `op` if any `block` argument type is illegal.
+static LogicalResult checkBlockArguments(Block &block, Operation *op,
+ PatternRewriter &rewriter,
+ const TypeConverter &converter) {
+ for (BlockArgument arg : block.getArguments()) {
+ if (!converter.isLegal(arg.getType())) {
+ return rewriter.notifyMatchFailure(
+ op,
+ llvm::formatv(
+ "failed to match, destination argument not legalized (found {0})",
+ arg));
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
-
/// Converts cf.br to spirv.Branch.
-struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
- using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (failed(checkBlockArguments(*op.getDest(), op, rewriter,
+ *getTypeConverter())))
+ return failure();
+ // Same destination block; the operands are taken from the adaptor.
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
@@ -45,16 +67,24 @@ struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
};
/// Converts cf.cond_br to spirv.BranchConditional.
-struct CondBranchOpPattern final
- : public OpConversionPattern<cf::CondBranchOp> {
- using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (failed(checkBlockArguments(*op.getTrueDest(), op, rewriter,
+ *getTypeConverter())))
+ return failure();
+ // Both successors are checked: either may take illegally-typed arguments.
+ if (failed(checkBlockArguments(*op.getFalseDest(), op, rewriter,
+ *getTypeConverter())))
+ return failure();
+ // The condition, like the branch operands, is taken from the adaptor.
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
- op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
- op.getFalseDest(), adaptor.getFalseDestOperands());
+ op, adaptor.getCondition(), op.getTrueDest(),
+ adaptor.getTrueDestOperands(), op.getFalseDest(),
+ adaptor.getFalseDestOperands());
return success();
}
};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index d8aecae257b461e..a752b82eac7c343 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -25,7 +25,7 @@ using namespace mlir;
namespace {
/// A pass converting MLIR ControlFlow operations into the SPIR-V dialect.
-class ConvertControlFlowToSPIRVPass
+class ConvertControlFlowToSPIRVPass final
: public impl::ConvertControlFlowToSPIRVBase<
ConvertControlFlowToSPIRVPass> {
void runOnOperation() override;
@@ -44,6 +44,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
+ // TODO: We should also take care of block argument type conversion.
+
RewritePatternSet patterns(context);
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 98ee6891453496d..a10a1468e582499 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-cf-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-cf-to-spirv --verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// cf.br, cf.cond_br
@@ -36,3 +36,20 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
}
}
+
+// -----
+
+// TODO: We should handle blocks whose arguments require type conversion.
+// Until then, the cf patterns fail to match branches to such blocks.
+// CHECK-LABEL: func.func @main_graph
+func.func @main_graph(%arg0: index) {
+ %c3 = arith.constant 1 : index
+ cf.br ^bb1(%arg0 : index)
+^bb1(%0: index): // 2 preds: ^bb0, ^bb2
+ %1 = arith.cmpi slt, %0, %c3 : index
+ cf.cond_br %1, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+ cf.br ^bb1(%c3 : index)
+^bb3: // pred: ^bb1
+ return
+}
More information about the Mlir-commits
mailing list