[Mlir-commits] [mlir] 4f47677 - [mlir][arith][spirv] Account for possible type conversion failures
Jakub Kuderski
llvmlistbot at llvm.org
Wed Dec 14 16:33:42 PST 2022
Author: Jakub Kuderski
Date: 2022-12-14T19:32:40-05:00
New Revision: 4f47677dee24b78548b49f3abca2c1ea65a79c8a
URL: https://github.com/llvm/llvm-project/commit/4f47677dee24b78548b49f3abca2c1ea65a79c8a
DIFF: https://github.com/llvm/llvm-project/commit/4f47677dee24b78548b49f3abca2c1ea65a79c8a.diff
LOG: [mlir][arith][spirv] Account for possible type conversion failures
Check results of all type conversions in `--convert-arith-to-spirv`.
Fixes: https://github.com/llvm/llvm-project/issues/59496
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D140033
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index ed5d044f82ed4..9533494c5456c 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -320,10 +320,13 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
/// Returns true if the given `type` is a boolean scalar or vector type.
+/// "Boolean" here means i1: either the scalar i1 type itself or a VectorType
+/// whose element type is i1. `type` must be non-null.
static bool isBoolScalarOrVector(Type type) {
+ // A null type here most likely means a caller forgot to check the result of
+ // a type conversion before querying it.
+ assert(type && "Not a valid type");
if (type.isInteger(1))
return true;
+
if (auto vecType = type.dyn_cast<VectorType>())
return vecType.getElementType().isInteger(1);
+
return false;
}
@@ -343,6 +346,22 @@ static bool hasSameBitwidth(Type a, Type b) {
return aBW != 0 && bBW != 0 && aBW == bBW;
}
+/// Returns a source type conversion failure for `srcType` and operation `op`.
+/// The failure is reported through `rewriter.notifyMatchFailure` at `op`'s
+/// location, with a message naming `srcType`. The failed LogicalResult is
+/// returned, so a pattern can `return` this call directly.
+static LogicalResult
+getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
+ Type srcType) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert source type '{0}'", srcType));
+}
+
+/// Returns a source type conversion failure for the result type of `op`.
+/// Convenience overload for single-result ops. `op` must have exactly one
+/// result (asserted). The reported type is that result's type before
+/// conversion.
+static LogicalResult
+getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
+ assert(op->getNumResults() == 1);
+ return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp with composite type
//===----------------------------------------------------------------------===//
@@ -562,10 +581,10 @@ BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
assert(adaptor.getOperands().size() == 2);
- auto dstType =
- this->getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
- return failure();
+ return getTypeConversionFailure(rewriter, op);
+
if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
adaptor.getOperands());
@@ -590,7 +609,8 @@ LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
- return failure();
+ return getTypeConversionFailure(rewriter, op);
+
rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
adaptor.getOperands());
@@ -611,7 +631,8 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
- return failure();
+ return getTypeConversionFailure(rewriter, op);
+
rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
adaptor.getOperands());
return success();
@@ -628,7 +649,10 @@ UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
if (!isBoolScalarOrVector(srcType))
return failure();
- Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
Location loc = op.getLoc();
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -649,7 +673,9 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
return failure();
Location loc = op.getLoc();
- Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
Value allOnes;
if (auto intTy = dstType.dyn_cast<IntegerType>()) {
@@ -684,7 +710,10 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
if (!isBoolScalarOrVector(srcType))
return failure();
- Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
Location loc = op.getLoc();
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -700,7 +729,10 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
LogicalResult
TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
if (!isBoolScalarOrVector(dstType))
return failure();
@@ -728,10 +760,13 @@ LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
assert(adaptor.getOperands().size() == 1);
Type srcType = adaptor.getOperands().front().getType();
- Type dstType =
- this->getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
return failure();
+
if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target type.
// Then we can just erase this operation by forwarding its operand.
@@ -755,7 +790,7 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
return failure();
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
- return failure();
+ return getTypeConversionFailure(rewriter, op, srcType);
switch (op.getPredicate()) {
case arith::CmpIPredicate::eq: {
@@ -804,7 +839,7 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
return failure();
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
- return failure();
+ return getTypeConversionFailure(rewriter, op, srcType);
switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
@@ -999,7 +1034,7 @@ LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType)
- return failure();
+ return getTypeConversionFailure(rewriter, op);
// arith.maxf/minf:
// "if one of the arguments is NaN, then the result is also NaN."
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index f6e84e80bbf51..0d92a8e676d85 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -130,3 +130,19 @@ func.func @unsupported_f64(%arg0: f64) {
}
} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// i64 is not a valid result type in this target env, which declares no
+// capabilities or extensions. The result type conversion therefore fails, and
+// the conversion must report the op as illegal instead of using a null type.
+func.func @type_conversion_failure(%arg0: i32) {
+ // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+ %2 = arith.extsi %arg0 : i32 to i64
+ return
+}
+
+} // end module
More information about the Mlir-commits mailing list