[Mlir-commits] [mlir] 179978d - [mlir][arith][spirv] Hard fail in `-convert-arith-to-spirv`
Jakub Kuderski
llvmlistbot at llvm.org
Mon Oct 31 14:01:49 PDT 2022
Author: Jakub Kuderski
Date: 2022-10-31T17:01:21-04:00
New Revision: 179978d7b8ec00291401d2ec49fc0a55e7f7bfb3
URL: https://github.com/llvm/llvm-project/commit/179978d7b8ec00291401d2ec49fc0a55e7f7bfb3
DIFF: https://github.com/llvm/llvm-project/commit/179978d7b8ec00291401d2ec49fc0a55e7f7bfb3.diff
LOG: [mlir][arith][spirv] Hard fail in `-convert-arith-to-spirv`
Turn legalization failures into hard failures to make sure that we do
not miss conversion pattern application failures.
Add a message on type conversion failure.
Move unsupported cases into a separate test file.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D137102
Added:
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Conversion/SPIRVCommon/Pattern.h
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 24bead82c18bb..2452928dd4503 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -1046,6 +1046,9 @@ struct ConvertArithToSPIRVPass
typeConverter.addTargetMaterialization(addUnrealizedCast);
target->addLegalOp<UnrealizedConversionCastOp>();
+ // Fail hard when there are any remaining 'arith' ops.
+ target->addIllegalDialect<arith::ArithDialect>();
+
RewritePatternSet patterns(&getContext());
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
index 5d32fa81cc6b2..ed859a86b64dc 100644
--- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h
+++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace spirv {
@@ -26,9 +27,13 @@ class ElementwiseOpPattern final : public OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() <= 3);
- auto dstType = this->getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return failure();
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+ }
+
if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
!op.getType().isIndex() && dstType != op.getType()) {
return op.emitError(
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
new file mode 100644
index 0000000000000..967adbc84a3bb
--- /dev/null
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s
+
+///===----------------------------------------------------------------------===//
+// Binary ops
+//===----------------------------------------------------------------------===//
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+ // expected-error@+1 {{failed to legalize operation 'arith.subi'}}
+ %1 = arith.subi %arg0, %arg0: vector<5xi32>
+ return
+}
+
+} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+
+func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+ // expected-error@+1 {{failed to legalize operation 'arith.muli'}}
+ %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+ return
+}
+
+} // end module
+
+// -----
+
+func.func @int_vector4_invalid(%arg0: vector<2xi16>) {
+ // expected-error @+2 {{failed to legalize operation 'arith.divui'}}
+ // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
+ %0 = arith.divui %arg0, %arg0: vector<2xi16>
+ return
+}
+
+///===----------------------------------------------------------------------===//
+// Constant ops
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @unsupported_constant_0() {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+ %0 = arith.constant 4294967296 : i64 // 2^32
+ return
+}
+
+// -----
+
+func.func @unsupported_constant_1() {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+ %1 = arith.constant -2147483649 : i64 // -2^31 - 1
+ return
+}
+
+// -----
+
+func.func @unsupported_constant_2() {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+ %2 = arith.constant -2147483649 : i64 // -2^31 - 1
+ return
+}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index bd5238f3629f6..df6806a0e4bd1 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -163,63 +163,6 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
return
}
-// CHECK-LABEL: @unsupported_5elem_vector
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
- // CHECK: arith.subi
- %1 = arith.subi %arg0, %arg0: vector<5xi32>
- return
-}
-
-// CHECK-LABEL: @unsupported_2x2elem_vector
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
- // CHECK: arith.muli
- %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
- return
-}
-
-} // end module
-
-// -----
-
-// Check that types are converted to 32-bit when no special capabilities.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-LABEL: @int_vector23
-func.func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
- // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
- %0 = arith.divsi %arg0, %arg0: vector<2xi8>
- // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: vector<3xi32>
- %1 = arith.divsi %arg1, %arg1: vector<3xi16>
- return
-}
-
-// CHECK-LABEL: @float_scalar
-func.func @float_scalar(%arg0: f16, %arg1: f64) {
- // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: f32
- %0 = arith.addf %arg0, %arg0: f16
- // CHECK: spirv.FMul %{{.*}}, %{{.*}}: f32
- %1 = arith.mulf %arg1, %arg1: f64
- return
-}
-
-} // end module
-
-// -----
-
-// Check that types are converted to 32-bit when no special capabilities that
-// are not supported.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
- // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
- %0 = arith.divui %arg0, %arg0: vector<4xi64>
- return
-}
-
} // end module
// -----
@@ -643,17 +586,6 @@ func.func @corner_cases() {
return
}
-// CHECK-LABEL: @unsupported_cases
-func.func @unsupported_cases() {
- // CHECK: %{{.*}} = arith.constant 4294967296 : i64
- %0 = arith.constant 4294967296 : i64 // 2^32
- // CHECK: %{{.*}} = arith.constant -2147483649 : i64
- %1 = arith.constant -2147483649 : i64 // -2^31 - 1
- // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64
- %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1
- return
-}
-
} // end module
// -----
@@ -1258,20 +1190,6 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
return
}
-// CHECK-LABEL: @unsupported_5elem_vector
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
- // CHECK: subi
- %1 = arith.subi %arg0, %arg0: vector<5xi32>
- return
-}
-
-// CHECK-LABEL: @unsupported_2x2elem_vector
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
- // CHECK: muli
- %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
- return
-}
-
} // end module
// -----
@@ -1303,22 +1221,6 @@ func.func @float_scalar(%arg0: f16, %arg1: f64) {
// -----
-// Check that types are converted to 32-bit when no special capabilities that
-// are not supported.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
- // expected-error@+1 {{bitwidth emulation is not implemented yet on unsigned op}}
- %0 = arith.divui %arg0, %arg0: vector<4xi64>
- return
-}
-
-} // end module
-
-// -----
-
//===----------------------------------------------------------------------===//
// std bit ops
//===----------------------------------------------------------------------===//
@@ -1675,17 +1577,6 @@ func.func @corner_cases() {
return
}
-// CHECK-LABEL: @unsupported_cases
-func.func @unsupported_cases() {
- // CHECK: %{{.*}} = arith.constant 4294967296 : i64
- %0 = arith.constant 4294967296 : i64 // 2^32
- // CHECK: %{{.*}} = arith.constant -2147483649 : i64
- %1 = arith.constant -2147483649 : i64 // -2^31 - 1
- // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64
- %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1
- return
-}
-
} // end module
// -----
More information about the Mlir-commits
mailing list