[Mlir-commits] [mlir] [mlir][spirv] Fix a crash in typeConverter with an unsupported type (PR #79955)

Kohei Yamaguchi llvmlistbot at llvm.org
Tue Jan 30 17:58:25 PST 2024


https://github.com/sott0n updated https://github.com/llvm/llvm-project/pull/79955

>From 4abb2ac407ff9c89d951def478349f924fb7139d Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Tue, 30 Jan 2024 15:43:57 +0000
Subject: [PATCH 1/3] [mlir][spirv] Fix a crash in typeConverter with an
 unsupported type

---
 mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index c62e676efc159..c7c67c04c8919 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -367,12 +367,13 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
     Value zero = rewriter.create<LLVM::ConstantOp>(
         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
     indices.insert(indices.begin(), zero);
-    rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
-        op, dstType,
-        typeConverter.convertType(
-            cast<spirv::PointerType>(op.getBasePtr().getType())
-                .getPointeeType()),
-        adaptor.getBasePtr(), indices);
+
+    auto elementType = typeConverter.convertType(
+        cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
+    if (!elementType)
+      return failure();
+    rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
+                                             adaptor.getBasePtr(), indices);
     return success();
   }
 };

>From c0beb3feae88e878fb7db9647563430ea05f353f Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Wed, 31 Jan 2024 09:40:39 +0000
Subject: [PATCH 2/3] Address review comment: report type conversion failure via notifyMatchFailure in AccessChainPattern

---
 mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index c7c67c04c8919..607f4c595169f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -371,7 +371,7 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
     auto elementType = typeConverter.convertType(
         cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
     if (!elementType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
                                              adaptor.getBasePtr(), indices);
     return success();

>From ea81352d770087f789d610fc5ad1f45d7147def5 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Wed, 31 Jan 2024 10:54:24 +0000
Subject: [PATCH 3/3] Use notifyMatchFailure and check type conversion results in more SPIR-V to LLVM patterns

---
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 106 ++++++++++--------
 1 file changed, 57 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 607f4c595169f..11d2312b9492f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -240,7 +240,7 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
     auto dstType = typeConverter.convertType(loadOp.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
         loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
         isVolatile, isNonTemporal);
@@ -357,13 +357,13 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     // To use GEP we need to add a first 0 index to go through the pointer.
     auto indices = llvm::to_vector<4>(adaptor.getIndices());
     Type indexType = op.getIndices().front().getType();
     auto llvmIndexType = typeConverter.convertType(indexType);
     if (!llvmIndexType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     Value zero = rewriter.create<LLVM::ConstantOp>(
         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
     indices.insert(indices.begin(), zero);
@@ -387,7 +387,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(op.getPointer().getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
                                                    op.getVariable());
     return success();
@@ -405,7 +405,7 @@ class BitFieldInsertPattern
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     Location loc = op.getLoc();
 
     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -452,7 +452,7 @@ class ConstantScalarAndVectorPattern
 
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(constOp, "type conversion failed");
 
     // SPIR-V constant can be a signed/unsigned integer, which has to be
     // casted to signless integer when converting to LLVM dialect. Removing the
@@ -493,7 +493,7 @@ class BitFieldSExtractPattern
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     Location loc = op.getLoc();
 
     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -546,7 +546,7 @@ class BitFieldUExtractPattern
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     Location loc = op.getLoc();
 
     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -622,7 +622,7 @@ class CompositeExtractPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     Type containerType = op.getComposite().getType();
     if (isa<VectorType>(containerType)) {
@@ -654,7 +654,7 @@ class CompositeInsertPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     Type containerType = op.getComposite().getType();
     if (isa<VectorType>(containerType)) {
@@ -681,13 +681,13 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = this->typeConverter.convertType(operation.getType());
+    auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     rewriter.template replaceOpWithNewOp<LLVMOp>(
-        operation, dstType, adaptor.getOperands(), operation->getAttrs());
+        op, dstType, adaptor.getOperands(), op->getAttrs());
     return success();
   }
 };
@@ -791,7 +791,7 @@ class GlobalVariablePattern
     auto srcType = cast<spirv::PointerType>(op.getType());
     auto dstType = typeConverter.convertType(srcType.getPointeeType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     // Limit conversion to the current invocation only or `StorageBuffer`
     // required by SPIR-V runner.
@@ -844,23 +844,23 @@ class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    Type fromType = operation.getOperand().getType();
-    Type toType = operation.getType();
+    Type fromType = op.getOperand().getType();
+    Type toType = op.getType();
 
     auto dstType = this->typeConverter.convertType(toType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     if (getBitWidth(fromType) < getBitWidth(toType)) {
-      rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
+      rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
                                                       adaptor.getOperands());
       return success();
     }
     if (getBitWidth(fromType) > getBitWidth(toType)) {
-      rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
+      rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
                                                         adaptor.getOperands());
       return success();
     }
@@ -884,6 +884,8 @@ class FunctionCallPattern
 
     // Function returns a single result.
     auto dstType = typeConverter.convertType(callOp.getType(0));
+    if (!dstType)
+      return rewriter.notifyMatchFailure(callOp, "type conversion failed");
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
     return success();
@@ -897,16 +899,15 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto dstType = this->typeConverter.convertType(operation.getType());
+    auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
-        operation, dstType, predicate, operation.getOperand1(),
-        operation.getOperand2());
+        op, dstType, predicate, op.getOperand1(), op.getOperand2());
     return success();
   }
 };
@@ -918,16 +919,15 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto dstType = this->typeConverter.convertType(operation.getType());
+    auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
-        operation, dstType, predicate, operation.getOperand1(),
-        operation.getOperand2());
+        op, dstType, predicate, op.getOperand1(), op.getOperand2());
     return success();
   }
 };
@@ -943,7 +943,7 @@ class InverseSqrtPattern
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     Location loc = op.getLoc();
     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
@@ -1001,7 +1001,7 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
     auto srcType = notOp.getType();
     auto dstType = this->typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(notOp, "type conversion failed");
 
     Location loc = notOp.getLoc();
     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
@@ -1227,18 +1227,18 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto dstType = this->typeConverter.convertType(operation.getType());
+    auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
-    Type op1Type = operation.getOperand1().getType();
-    Type op2Type = operation.getOperand2().getType();
+    Type op1Type = op.getOperand1().getType();
+    Type op2Type = op.getOperand2().getType();
 
     if (op1Type == op2Type) {
-      rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
+      rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
                                                    adaptor.getOperands());
       return success();
     }
@@ -1251,7 +1251,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
     if (!dstTypeWidth || !op2TypeWidth)
       return failure();
 
-    Location loc = operation.getLoc();
+    Location loc = op.getLoc();
     Value extended;
     if (op2TypeWidth < dstTypeWidth) {
       if (isUnsignedIntegerOrVector(op2Type)) {
@@ -1269,7 +1269,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
 
     Value result = rewriter.template create<LLVMOp>(
         loc, dstType, adaptor.getOperand1(), extended);
-    rewriter.replaceOp(operation, result);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -1283,7 +1283,7 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(tanOp.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
 
     Location loc = tanOp.getLoc();
     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
@@ -1309,7 +1309,7 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
     auto srcType = tanhOp.getType();
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
 
     Location loc = tanhOp.getLoc();
     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
@@ -1343,17 +1343,23 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
 
     auto dstType = typeConverter.convertType(srcType);
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(varOp, "type conversion failed");
 
     Location loc = varOp.getLoc();
     Value size = createI32ConstantOf(loc, rewriter, 1);
     if (!init) {
-      rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
-          varOp, dstType, typeConverter.convertType(pointerTo), size);
+      auto elementType = typeConverter.convertType(pointerTo);
+      if (!elementType)
+        return rewriter.notifyMatchFailure(varOp, "type conversion failed");
+      rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
+                                                  size);
       return success();
     }
-    Value allocated = rewriter.create<LLVM::AllocaOp>(
-        loc, dstType, typeConverter.convertType(pointerTo), size);
+    auto elementType = typeConverter.convertType(pointerTo);
+    if (!elementType)
+      return rewriter.notifyMatchFailure(varOp, "type conversion failed");
+    Value allocated =
+        rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
     rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
     rewriter.replaceOp(varOp, allocated);
     return success();
@@ -1374,7 +1380,7 @@ class BitcastConversionPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(bitcastOp.getType());
     if (!dstType)
-      return failure();
+      return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
 
     // LLVM's opaque pointers do not require bitcasts.
     if (isa<LLVM::LLVMPointerType>(dstType)) {
@@ -1500,6 +1506,8 @@ class VectorShufflePattern
     }
 
     auto dstType = typeConverter.convertType(op.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
     auto scalarType = cast<VectorType>(dstType).getElementType();
     auto componentsArray = components.getValue();
     auto *context = rewriter.getContext();



More information about the Mlir-commits mailing list