[Mlir-commits] [mlir] 1e9799e - [mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 30 02:51:22 PDT 2021
Author: Butygin
Date: 2021-07-30T12:46:13+03:00
New Revision: 1e9799e204ff9eaa2160304e6a139c2faa850d33
URL: https://github.com/llvm/llvm-project/commit/1e9799e204ff9eaa2160304e6a139c2faa850d33
DIFF: https://github.com/llvm/llvm-project/commit/1e9799e204ff9eaa2160304e6a139c2faa850d33.diff
LOG: [mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps
spirv::getElementPtr can return null (for memrefs with affine maps), but the patterns did not handle this case.
Differential Revision: https://reviews.llvm.org/D106988
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 2186107851d91..a11214673cbac 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -147,6 +147,7 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
+/// Returns null if index computation cannot be performed.
// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
// that has static strides. Extend to handle dynamic strides.
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index c6be486e2ea6a..ddc312eb09145 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -268,6 +268,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
loadOperands.indices(), loc, rewriter);
+ if (!accessChainOp)
+ return failure();
+
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
@@ -358,6 +361,10 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
auto loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+
+ if (!loadPtr)
+ return failure();
+
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
return success();
}
@@ -376,6 +383,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
spirv::AccessChainOp accessChainOp =
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
storeOperands.indices(), loc, rewriter);
+
+ if (!accessChainOp)
+ return failure();
+
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
@@ -467,6 +478,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
storeOperands.memref(), storeOperands.indices(),
storeOp.getLoc(), rewriter);
+
+ if (!storePtr)
+ return failure();
+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
return success();
More information about the Mlir-commits mailing list