[Mlir-commits] [mlir] c19436e - [mlir][spirv] Fix a crash of typeConverter with non supported type (#79955)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 19:37:00 PST 2024
Author: Kohei Yamaguchi
Date: 2024-01-30T22:36:56-05:00
New Revision: c19436eec1c236cbe622c04e33f35f1f9478fa15
URL: https://github.com/llvm/llvm-project/commit/c19436eec1c236cbe622c04e33f35f1f9478fa15
DIFF: https://github.com/llvm/llvm-project/commit/c19436eec1c236cbe622c04e33f35f1f9478fa15.diff
LOG: [mlir][spirv] Fix a crash of typeConverter with non supported type (#79955)
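Fixes a crash in the `convert-spirv-to-llvm` pass caused by unsupported
types (e.g. `spirv.matrix`). This PR fixes it by checking whether each type conversion succeeded and, if it did not, reporting a match failure instead of using the null converted type.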
Fixes #60017
Added:
Modified:
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index c62e676efc159..11d2312b9492f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -240,7 +240,7 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
auto dstType = typeConverter.convertType(loadOp.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
isVolatile, isNonTemporal);
@@ -357,22 +357,23 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
// To use GEP we need to add a first 0 index to go through the pointer.
auto indices = llvm::to_vector<4>(adaptor.getIndices());
Type indexType = op.getIndices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType);
if (!llvmIndexType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
- rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
- op, dstType,
- typeConverter.convertType(
- cast<spirv::PointerType>(op.getBasePtr().getType())
- .getPointeeType()),
- adaptor.getBasePtr(), indices);
+
+ auto elementType = typeConverter.convertType(
+ cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
+ if (!elementType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+ rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
+ adaptor.getBasePtr(), indices);
return success();
}
};
@@ -386,7 +387,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.getPointer().getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
op.getVariable());
return success();
@@ -404,7 +405,7 @@ class BitFieldInsertPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -451,7 +452,7 @@ class ConstantScalarAndVectorPattern
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(constOp, "type conversion failed");
// SPIR-V constant can be a signed/unsigned integer, which has to be
// casted to signless integer when converting to LLVM dialect. Removing the
@@ -492,7 +493,7 @@ class BitFieldSExtractPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -545,7 +546,7 @@ class BitFieldUExtractPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -621,7 +622,7 @@ class CompositeExtractPattern
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Type containerType = op.getComposite().getType();
if (isa<VectorType>(containerType)) {
@@ -653,7 +654,7 @@ class CompositeInsertPattern
ConversionPatternRewriter &rewriter) const override {
auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Type containerType = op.getComposite().getType();
if (isa<VectorType>(containerType)) {
@@ -680,13 +681,13 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+ matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = this->typeConverter.convertType(operation.getType());
+ auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.template replaceOpWithNewOp<LLVMOp>(
- operation, dstType, adaptor.getOperands(), operation->getAttrs());
+ op, dstType, adaptor.getOperands(), op->getAttrs());
return success();
}
};
@@ -790,7 +791,7 @@ class GlobalVariablePattern
auto srcType = cast<spirv::PointerType>(op.getType());
auto dstType = typeConverter.convertType(srcType.getPointeeType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
// Limit conversion to the current invocation only or `StorageBuffer`
// required by SPIR-V runner.
@@ -843,23 +844,23 @@ class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+ matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type fromType = operation.getOperand().getType();
- Type toType = operation.getType();
+ Type fromType = op.getOperand().getType();
+ Type toType = op.getType();
auto dstType = this->typeConverter.convertType(toType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
if (getBitWidth(fromType) < getBitWidth(toType)) {
- rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
+ rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
adaptor.getOperands());
return success();
}
if (getBitWidth(fromType) > getBitWidth(toType)) {
- rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
+ rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
adaptor.getOperands());
return success();
}
@@ -883,6 +884,8 @@ class FunctionCallPattern
// Function returns a single result.
auto dstType = typeConverter.convertType(callOp.getType(0));
+ if (!dstType)
+ return rewriter.notifyMatchFailure(callOp, "type conversion failed");
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
return success();
@@ -896,16 +899,15 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+ matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = this->typeConverter.convertType(operation.getType());
+ auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
- operation, dstType, predicate, operation.getOperand1(),
- operation.getOperand2());
+ op, dstType, predicate, op.getOperand1(), op.getOperand2());
return success();
}
};
@@ -917,16 +919,15 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+ matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = this->typeConverter.convertType(operation.getType());
+ auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
- operation, dstType, predicate, operation.getOperand1(),
- operation.getOperand2());
+ op, dstType, predicate, op.getOperand1(), op.getOperand2());
return success();
}
};
@@ -942,7 +943,7 @@ class InverseSqrtPattern
auto srcType = op.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
Location loc = op.getLoc();
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
@@ -1000,7 +1001,7 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
auto srcType = notOp.getType();
auto dstType = this->typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(notOp, "type conversion failed");
Location loc = notOp.getLoc();
IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
@@ -1226,18 +1227,18 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+ matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = this->typeConverter.convertType(operation.getType());
+ auto dstType = this->typeConverter.convertType(op.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
- Type op1Type = operation.getOperand1().getType();
- Type op2Type = operation.getOperand2().getType();
+ Type op1Type = op.getOperand1().getType();
+ Type op2Type = op.getOperand2().getType();
if (op1Type == op2Type) {
- rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
+ rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
adaptor.getOperands());
return success();
}
@@ -1250,7 +1251,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
if (!dstTypeWidth || !op2TypeWidth)
return failure();
- Location loc = operation.getLoc();
+ Location loc = op.getLoc();
Value extended;
if (op2TypeWidth < dstTypeWidth) {
if (isUnsignedIntegerOrVector(op2Type)) {
@@ -1268,7 +1269,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
Value result = rewriter.template create<LLVMOp>(
loc, dstType, adaptor.getOperand1(), extended);
- rewriter.replaceOp(operation, result);
+ rewriter.replaceOp(op, result);
return success();
}
};
@@ -1282,7 +1283,7 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(tanOp.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
Location loc = tanOp.getLoc();
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
@@ -1308,7 +1309,7 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
auto srcType = tanhOp.getType();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
Location loc = tanhOp.getLoc();
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
@@ -1342,17 +1343,23 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(varOp, "type conversion failed");
Location loc = varOp.getLoc();
Value size = createI32ConstantOf(loc, rewriter, 1);
if (!init) {
- rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
- varOp, dstType, typeConverter.convertType(pointerTo), size);
+ auto elementType = typeConverter.convertType(pointerTo);
+ if (!elementType)
+ return rewriter.notifyMatchFailure(varOp, "type conversion failed");
+ rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
+ size);
return success();
}
- Value allocated = rewriter.create<LLVM::AllocaOp>(
- loc, dstType, typeConverter.convertType(pointerTo), size);
+ auto elementType = typeConverter.convertType(pointerTo);
+ if (!elementType)
+ return rewriter.notifyMatchFailure(varOp, "type conversion failed");
+ Value allocated =
+ rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
rewriter.replaceOp(varOp, allocated);
return success();
@@ -1373,7 +1380,7 @@ class BitcastConversionPattern
ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(bitcastOp.getType());
if (!dstType)
- return failure();
+ return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
// LLVM's opaque pointers do not require bitcasts.
if (isa<LLVM::LLVMPointerType>(dstType)) {
@@ -1499,6 +1506,8 @@ class VectorShufflePattern
}
auto dstType = typeConverter.convertType(op.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
auto scalarType = cast<VectorType>(dstType).getElementType();
auto componentsArray = components.getValue();
auto *context = rewriter.getContext();
More information about the Mlir-commits
mailing list