[Mlir-commits] [mlir] 7f7e33c - [mlir][spirv][math] Fix crash on unsupported types in math-to-spirv
Jakub Kuderski
llvmlistbot at llvm.org
Thu Nov 17 10:46:29 PST 2022
Author: Jakub Kuderski
Date: 2022-11-17T13:45:36-05:00
New Revision: 7f7e33c2481dc10bd883129c8998b8d925375213
URL: https://github.com/llvm/llvm-project/commit/7f7e33c2481dc10bd883129c8998b8d925375213
DIFF: https://github.com/llvm/llvm-project/commit/7f7e33c2481dc10bd883129c8998b8d925375213.diff
LOG: [mlir][spirv][math] Fix crash on unsupported types in math-to-spirv
Fail to match conversion patterns when source op has unsupported types.
Fixes: https://github.com/llvm/llvm-project/issues/58749
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D138178
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/lib/Conversion/SPIRVCommon/Pattern.h
mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 242000485d554..5bd06c947e49c 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -18,7 +18,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
@@ -46,6 +48,48 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
return nullptr;
}
+/// Returns true when `originalType` can be handled by the math-to-SPIR-V
+/// patterns: an integer/index/float scalar, or a non-scalable vector of rank
+/// at most 1 whose element type is such a scalar. Higher-level types (e.g.
+/// tensors or multi-dimensional vectors) are expected to have been lowered
+/// before this conversion runs, so they are rejected here.
+static bool isSupportedSourceType(Type originalType) {
+  // Plain scalars need no further inspection.
+  if (originalType.isIntOrIndexOrFloat())
+    return true;
+
+  // Beyond scalars, only vectors are considered at all.
+  auto vecTy = originalType.dyn_cast<VectorType>();
+  if (!vecTy)
+    return false;
+
+  // Scalar elements, fixed length, and no more than one dimension.
+  return vecTy.getElementType().isIntOrIndexOrFloat() &&
+         !vecTy.isScalable() && vecTy.getRank() <= 1;
+}
+
+/// Checks that every operand and result type of `sourceOp` is supported by
+/// the math-to-SPIR-V conversion (see `isSupportedSourceType`).
+/// On the first unsupported type, reports a match failure naming that type
+/// through `rewriter.notifyMatchFailure` and returns its (failed) result;
+/// otherwise returns `success()`.
+/// This is intended to simplify type checks in `OpConversionPattern`s.
+static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
+                                        Operation *sourceOp) {
+  // Collect operand types followed by result types so both are validated by
+  // the same loop.
+  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
+  llvm::append_range(allTypes, sourceOp->getResultTypes());
+
+  for (Type ty : allTypes) {
+    if (!isSupportedSourceType(ty)) {
+      return rewriter.notifyMatchFailure(
+          sourceOp,
+          llvm::formatv(
+              "unsupported source type for Math to SPIR-V conversion: {0}",
+              ty));
+    }
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -55,14 +99,36 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
// normal RewritePattern.
namespace {
+/// Conversion pattern for elementwise (unary, binary, or ternary) ops that
+/// first validates every operand and result type of the source `Op` and only
+/// then delegates the actual rewrite to `spirv::ElementwiseOpPattern`.
+template <typename Op, typename SPIRVOp>
+struct CheckedElementwiseOpPattern final
+    : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
+  using BasePattern = spirv::ElementwiseOpPattern<Op, SPIRVOp>;
+  using BasePattern::BasePattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Unsupported types: the helper has already reported the match-failure
+    // reason, so just bail out without touching the op.
+    if (failed(checkSourceOpTypes(rewriter, op)))
+      return failure();
+    return BasePattern::matchAndRewrite(op, adaptor, rewriter);
+  }
+};
+
/// Converts math.copysign to SPIR-V ops.
-class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
+struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = getTypeConverter()->convertType(copySignOp.getType());
+ if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
+ failed(res))
+ return res;
+
+ Type type = getTypeConverter()->convertType(copySignOp.getType());
if (!type)
return failure();
@@ -121,14 +187,17 @@ class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
/// SPIR-V does not have a direct operations for counting leading zeros. If
/// Shader capability is supported, we can leverage GL FindUMsb to calculate
/// it.
-class CountLeadingZerosPattern final
+struct CountLeadingZerosPattern final
: public OpConversionPattern<math::CountLeadingZerosOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = getTypeConverter()->convertType(countOp.getType());
+ if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
+ return res;
+
+ Type type = getTypeConverter()->convertType(countOp.getType());
if (!type)
return failure();
@@ -177,9 +246,16 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
+ if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+ failed(res))
+ return res;
+
Location loc = operation.getLoc();
- auto type = this->getTypeConverter()->convertType(operation.getType());
- auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+ Type type = this->getTypeConverter()->convertType(operation.getType());
+ if (!type)
+ return failure();
+
+ Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
return success();
@@ -198,10 +274,17 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
+ if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+ failed(res))
+ return res;
+
Location loc = operation.getLoc();
- auto type = this->getTypeConverter()->convertType(operation.getType());
+ Type type = this->getTypeConverter()->convertType(operation.getType());
+ if (!type)
+ return failure();
+
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
- auto onePlus =
+ Value onePlus =
rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
return success();
@@ -215,7 +298,10 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
LogicalResult
matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = getTypeConverter()->convertType(powfOp.getType());
+ if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
+ return res;
+
+ Type dstType = getTypeConverter()->convertType(powfOp.getType());
if (!dstType)
return failure();
@@ -241,10 +327,13 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
LogicalResult
matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
+ return res;
+
Location loc = roundOp.getLoc();
- auto operand = roundOp.getOperand();
- auto ty = operand.getType();
- auto ety = getElementTypeOrSelf(ty);
+ Value operand = roundOp.getOperand();
+ Type ty = operand.getType();
+ Type ety = getElementTypeOrSelf(ty);
auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
@@ -287,38 +376,38 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
- spirv::ElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
- spirv::ElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
- spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
- spirv::ElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
- spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
- spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
- spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
- spirv::ElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
- spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
- spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
- spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
- spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
- spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
+ CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
+ CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
+ CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
+ CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
+ CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
+ CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
+ CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
+ CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
+ CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
+ CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
+ CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
+ CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
typeConverter, patterns.getContext());
// OpenCL patterns
patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
- spirv::ElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
- spirv::ElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
- spirv::ElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
- spirv::ElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
- spirv::ElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
- spirv::ElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
- spirv::ElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
- spirv::ElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
- spirv::ElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
- spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
- spirv::ElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
- spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
- spirv::ElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
- spirv::ElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
- spirv::ElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
+ CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
+ CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
+ CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
+ CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
+ CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
+ CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
+ CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
+ CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
+ CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
+ CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
+ CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
+ CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
+ CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
+ CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
index ed859a86b64dc..4da3e197ca3ed 100644
--- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h
+++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
@@ -19,8 +19,7 @@ namespace spirv {
/// Converts elementwise unary, binary and ternary standard operations to SPIR-V
/// operations.
template <typename Op, typename SPIRVOp>
-class ElementwiseOpPattern final : public OpConversionPattern<Op> {
-public:
+struct ElementwiseOpPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
index a9ea026d86b86..e84b9b0f97717 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -41,3 +41,27 @@ func.func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vect
// CHECK: %[[OR:.+]] = spirv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
// CHECK: %[[RESULT:.+]] = spirv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
// CHECK: return %[[RESULT]]
+
+// -----
+
+// 2-D vectors are not supported: the conversion must leave math.copysign
+// unchanged here instead of crashing.
+func.func @copy_sign_2d_vector(%value: vector<3x3xf32>, %sign: vector<3x3xf32>) -> vector<3x3xf32> {
+ %0 = math.copysign %value, %sign : vector<3x3xf32>
+ return %0: vector<3x3xf32>
+}
+
+// CHECK-LABEL: func @copy_sign_2d_vector
+// CHECK-NEXT: math.copysign {{%.+}}, {{%.+}} : vector<3x3xf32>
+// CHECK-NEXT: return
+
+// -----
+
+// Tensors are not supported: the conversion must leave math.copysign
+// unchanged here instead of crashing.
+func.func @copy_sign_tensor(%value: tensor<3x3xf32>, %sign: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ %0 = math.copysign %value, %sign : tensor<3x3xf32>
+ return %0: tensor<3x3xf32>
+}
+
+// CHECK-LABEL: func @copy_sign_tensor
+// CHECK-NEXT: math.copysign {{%.+}}, {{%.+}} : tensor<3x3xf32>
+// CHECK-NEXT: return
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index a29b18b6812b9..125478e2cb214 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -211,3 +211,51 @@ func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
}
} // end module
+
+// -----
+
+// Regression tests for unsupported source types under a Shader-capable
+// target environment: the math-to-SPIR-V patterns must fail to match, so
+// every math op below is expected to be left unchanged.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+// 2-D vectors are not supported; each op below must stay as-is.
+
+// CHECK-LABEL: @vector_2d
+func.func @vector_2d(%arg0: vector<2x2xf32>) {
+ // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
+ %0 = math.cos %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
+ %1 = math.exp %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
+ %2 = math.absf %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
+ %3 = math.ceil %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
+ %4 = math.floor %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
+ %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: return
+ return
+}
+
+// Tensors are not supported; each op below must stay as-is.
+
+// CHECK-LABEL: @tensor_1d
+func.func @tensor_1d(%arg0: tensor<2xf32>) {
+ // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
+ %0 = math.cos %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
+ %1 = math.exp %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
+ %2 = math.absf %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
+ %3 = math.ceil %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
+ %4 = math.floor %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
+ %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+ // CHECK-NEXT: return
+ return
+}
+
+} // end module
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 6897cfd9f2f5a..da02e6287961d 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -100,3 +100,51 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
}
} // end module
+
+// -----
+
+// Regression tests for unsupported source types under an OpenCL (Kernel)
+// target environment, so this OpenCL test file exercises the OpenCL
+// configuration rather than the Shader one covered by the GL tests. The
+// math-to-SPIR-V patterns must fail to match, leaving every math op unchanged.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>
+} {
+
+// 2-D vectors are not supported; each op below must stay as-is.
+
+// CHECK-LABEL: @vector_2d
+func.func @vector_2d(%arg0: vector<2x2xf32>) {
+ // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
+ %0 = math.cos %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
+ %1 = math.exp %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
+ %2 = math.absf %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
+ %3 = math.ceil %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
+ %4 = math.floor %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
+ %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+ // CHECK-NEXT: return
+ return
+}
+
+// Tensors are not supported; each op below must stay as-is.
+
+// CHECK-LABEL: @tensor_1d
+func.func @tensor_1d(%arg0: tensor<2xf32>) {
+ // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
+ %0 = math.cos %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
+ %1 = math.exp %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
+ %2 = math.absf %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
+ %3 = math.ceil %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
+ %4 = math.floor %arg0 : tensor<2xf32>
+ // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
+ %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+ // CHECK-NEXT: return
+ return
+}
+
+} // end module
More information about the Mlir-commits
mailing list