[Mlir-commits] [mlir] 1e9799e - [mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps

llvmlistbot at llvm.org
Fri Jul 30 02:51:22 PDT 2021


Author: Butygin
Date: 2021-07-30T12:46:13+03:00
New Revision: 1e9799e204ff9eaa2160304e6a139c2faa850d33

URL: https://github.com/llvm/llvm-project/commit/1e9799e204ff9eaa2160304e6a139c2faa850d33
DIFF: https://github.com/llvm/llvm-project/commit/1e9799e204ff9eaa2160304e6a139c2faa850d33.diff

LOG: [mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps

spirv::getElementPtr can return null (for memrefs with affine maps), but the memref load/store lowering patterns did not check for this and used the null result.

Differential Revision: https://reviews.llvm.org/D106988
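A minimal C++ sketch of the contract described in the log message and in the new doc comment ("Returns null if index computation cannot be performed"). It assumes the index computation is the usual strided linearization, offset + sum(index_i * stride_i), and that it is only possible when every stride and the offset are static, as the TODO in SPIRVConversion.h suggests. Plain integers stand in for MLIR Values, std::nullopt stands in for the null op, and the names linearOffset and kDynamic are hypothetical, not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical sentinel marking a stride or offset unknown at compile time.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for the index computation behind spirv::getElementPtr:
// returns offset + sum(indices[i] * strides[i]), or std::nullopt (the "null"
// result) when the offset or any stride is dynamic.
std::optional<int64_t> linearOffset(const std::vector<int64_t> &indices,
                                    const std::vector<int64_t> &strides,
                                    int64_t offset) {
  assert(indices.size() == strides.size() && "one stride per index");
  if (offset == kDynamic)
    return std::nullopt;
  int64_t linear = offset;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (strides[i] == kDynamic)
      return std::nullopt; // Cannot be computed statically; caller must bail out.
    linear += indices[i] * strides[i];
  }
  return linear;
}
```

For example, linearOffset({2, 3}, {8, 1}, 0) yields 19 for a row-major 4x8 buffer, while linearOffset({2, 3}, {kDynamic, 1}, 0) yields std::nullopt, which is exactly the case every caller has to handle.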

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 2186107851d91..a11214673cbac 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -147,6 +147,7 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
 
 /// Performs the index computation to get to the element at `indices` of the
 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
+/// Returns null if index computation cannot be performed.
 
 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
 // that has static strides. Extend to handle dynamic strides.

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index c6be486e2ea6a..ddc312eb09145 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -268,6 +268,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                            loadOperands.indices(), loc, rewriter);
 
+  if (!accessChainOp)
+    return failure();
+
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
   bool isBool = srcBits == 1;
   if (isBool)
@@ -358,6 +361,10 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
   auto loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType,
       loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+
+  if (!loadPtr)
+    return failure();
+
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
   return success();
 }
@@ -376,6 +383,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   spirv::AccessChainOp accessChainOp =
       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                            storeOperands.indices(), loc, rewriter);
+
+  if (!accessChainOp)
+    return failure();
+
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
 
   bool isBool = srcBits == 1;
@@ -467,6 +478,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
       spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
                            storeOperands.memref(), storeOperands.indices(),
                            storeOp.getLoc(), rewriter);
+
+  if (!storePtr)
+    return failure();
+
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               storeOperands.value());
   return success();

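In MLIR's dialect conversion, returning failure() from matchAndRewrite tells the driver that the pattern did not apply, so the pass can try other patterns or report that the op failed to legalize, instead of building an op from a null pointer and crashing. The self-contained C++ sketch below models the guard added to the four load/store patterns. MemRefLoad, AccessChain, getElementPtr, loadOpPattern and the LogicalResult enum here are hypothetical stand-ins, not the real MLIR types or signatures.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for mlir::LogicalResult.
enum class LogicalResult { Success, Failure };

// Hypothetical stand-in for the spirv::AccessChainOp that getElementPtr builds.
struct AccessChain {
  long linearIndex;
};

// Hypothetical stand-in for memref::LoadOp.
struct MemRefLoad {
  bool hasStaticStrides; // false models a layout the helper cannot handle
  long index;
  std::string loweredTo; // set only when the pattern rewrites the op
};

// Hypothetical stand-in for spirv::getElementPtr: returns null when the
// index computation cannot be performed for this layout.
std::unique_ptr<AccessChain> getElementPtr(const MemRefLoad &op) {
  assert(op.index >= 0 && "index must be non-negative");
  if (!op.hasStaticStrides)
    return nullptr;
  return std::make_unique<AccessChain>(AccessChain{op.index});
}

// Mirrors the guarded matchAndRewrite: check the helper's result first and
// report failure, instead of using a null pointer as the pre-fix code did.
LogicalResult loadOpPattern(MemRefLoad &op) {
  std::unique_ptr<AccessChain> ptr = getElementPtr(op);
  if (!ptr)
    return LogicalResult::Failure;
  op.loweredTo = "load " + std::to_string(ptr->linearIndex);
  return LogicalResult::Success;
}
```

In this sketch the failing path leaves the op untouched (loweredTo stays empty), which is what allows a conversion driver to fall back to other patterns or to report the unconverted op cleanly.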