[Mlir-commits] [mlir] 4f47677 - [mlir][arith][spirv] Account for possible type conversion failures

Jakub Kuderski llvmlistbot at llvm.org
Wed Dec 14 16:33:42 PST 2022


Author: Jakub Kuderski
Date: 2022-12-14T19:32:40-05:00
New Revision: 4f47677dee24b78548b49f3abca2c1ea65a79c8a

URL: https://github.com/llvm/llvm-project/commit/4f47677dee24b78548b49f3abca2c1ea65a79c8a
DIFF: https://github.com/llvm/llvm-project/commit/4f47677dee24b78548b49f3abca2c1ea65a79c8a.diff

LOG: [mlir][arith][spirv] Account for possible type conversion failures

Check results of all type conversions in `--convert-arith-to-spirv`.

Fixes: https://github.com/llvm/llvm-project/issues/59496

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D140033

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index ed5d044f82ed4..9533494c5456c 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -320,10 +320,13 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
 
 /// Returns true if the given `type` is a boolean scalar or vector type.
 static bool isBoolScalarOrVector(Type type) {
+  assert(type && "Not a valid type");
   if (type.isInteger(1))
     return true;
+  // Vectors qualify only when their element type is i1.
   if (auto vecType = type.dyn_cast<VectorType>())
     return vecType.getElementType().isInteger(1);
+
   return false;
 }
 
@@ -343,6 +346,22 @@ static bool hasSameBitwidth(Type a, Type b) {
   return aBW != 0 && bBW != 0 && aBW == bBW;
 }
 
+/// Reports via `rewriter` that `srcType` could not be converted for `op`.
+static LogicalResult
+getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
+                         Type srcType) {
+  return rewriter.notifyMatchFailure(
+      op->getLoc(),
+      llvm::formatv("failed to convert source type '{0}'", srcType));
+}
+
+/// Returns a source type conversion failure for the sole result type of `op`.
+static LogicalResult
+getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
+  assert(op->getNumResults() == 1 && "expected a single-result operation");
+  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type
 //===----------------------------------------------------------------------===//
@@ -562,10 +581,10 @@ BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
     Op op, typename Op::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   assert(adaptor.getOperands().size() == 2);
-  auto dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = this->getTypeConverter()->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
+
   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
                                                          adaptor.getOperands());
@@ -590,7 +609,8 @@ LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
 
   Type dstType = getTypeConverter()->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
+
   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
                                                    adaptor.getOperands());
 
@@ -611,7 +631,8 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
 
   Type dstType = getTypeConverter()->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
+
   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
                                                         adaptor.getOperands());
   return success();
@@ -628,7 +649,10 @@ UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
   if (!isBoolScalarOrVector(srcType))
     return failure();
 
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   Location loc = op.getLoc();
   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -649,7 +673,9 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
     return failure();
 
   Location loc = op.getLoc();
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
 
   Value allOnes;
   if (auto intTy = dstType.dyn_cast<IntegerType>()) {
@@ -684,7 +710,10 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
   if (!isBoolScalarOrVector(srcType))
     return failure();
 
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   Location loc = op.getLoc();
   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -700,7 +729,10 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
 LogicalResult
 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   if (!isBoolScalarOrVector(dstType))
     return failure();
 
@@ -728,10 +760,13 @@ LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   assert(adaptor.getOperands().size() == 1);
   Type srcType = adaptor.getOperands().front().getType();
-  Type dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = this->getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return getTypeConversionFailure(rewriter, op);
+
   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
     return failure();
+
   if (dstType == srcType) {
     // Due to type conversion, we are seeing the same source and target type.
     // Then we can just erase this operation by forwarding its operand.
@@ -755,7 +790,7 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
     return failure();
   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op, srcType);
 
   switch (op.getPredicate()) {
   case arith::CmpIPredicate::eq: {
@@ -804,7 +839,7 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
     return failure();
   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op, srcType);
 
   switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
@@ -999,7 +1034,7 @@ LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
   auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
   Type dstType = converter->convertType(op.getType());
   if (!dstType)
-    return failure();
+    return getTypeConversionFailure(rewriter, op);
 
   // arith.maxf/minf:
   //   "if one of the arguments is NaN, then the result is also NaN."

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index f6e84e80bbf51..0d92a8e676d85 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -130,3 +130,19 @@ func.func @unsupported_f64(%arg0: f64) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// i64 is not a valid result type in this target env.
+func.func @type_conversion_failure(%arg0: i32) {
+  // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+  %2 = arith.extsi %arg0 : i32 to i64
+  return
+}
+
+} // end module


        


More information about the Mlir-commits mailing list