[Mlir-commits] [mlir] 419c6da - [mlir][LLVM] Verify too many indices in GEP verifier (#70174)
llvmlistbot at llvm.org
Wed Oct 25 02:38:30 PDT 2023
Author: Markus Böck
Date: 2023-10-25T11:38:26+02:00
New Revision: 419c6da3d763cc34b15aa66d4b6e7fed6031cedd
URL: https://github.com/llvm/llvm-project/commit/419c6da3d763cc34b15aa66d4b6e7fed6031cedd
DIFF: https://github.com/llvm/llvm-project/commit/419c6da3d763cc34b15aa66d4b6e7fed6031cedd.diff
LOG: [mlir][LLVM] Verify too many indices in GEP verifier (#70174)
Previously, the verifier stopped and reported success as soon as it
encountered a type that cannot be indexed into. The correct behaviour
in this case is to emit an error, because there are more indices than
the element type can accept. Otherwise an invalid GEP passes
verification and typically fails only later, during LLVM IR
translation, which is a poor user experience.
This PR implements the correct verification behaviour. Some upstream
tests also had to be fixed because they created invalid GEPs.
Fixes https://github.com/llvm/llvm-project/issues/70168
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/mem2reg.mlir
mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
mlir/test/Dialect/LLVMIR/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 95c04098d05fc2f..7f5681e7bdc0592 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -664,90 +664,51 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
});
}
-namespace {
-/// Base class for llvm::Error related to GEP index.
-class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> {
-protected:
- unsigned indexPos;
-
-public:
- static char ID;
-
- std::error_code convertToErrorCode() const override {
- return llvm::inconvertibleErrorCode();
- }
-
- explicit GEPIndexError(unsigned pos) : indexPos(pos) {}
-};
-
-/// llvm::Error for out-of-bound GEP index.
-struct GEPIndexOutOfBoundError
- : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> {
- static char ID;
-
- using ErrorInfo::ErrorInfo;
-
- void log(llvm::raw_ostream &os) const override {
- os << "index " << indexPos << " indexing a struct is out of bounds";
- }
-};
-
-/// llvm::Error for non-static GEP index indexing a struct.
-struct GEPStaticIndexError
- : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> {
- static char ID;
-
- using ErrorInfo::ErrorInfo;
-
- void log(llvm::raw_ostream &os) const override {
- os << "expected index " << indexPos << " indexing a struct "
- << "to be constant";
- }
-};
-} // end anonymous namespace
-
-char GEPIndexError::ID = 0;
-char GEPIndexOutOfBoundError::ID = 0;
-char GEPStaticIndexError::ID = 0;
-
-/// For the given `structIndices` and `indices`, check if they're complied
-/// with `baseGEPType`, especially check against LLVMStructTypes nested within.
-static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos,
- GEPIndicesAdaptor<ValueRange> indices) {
+/// For the given `indices`, check if they comply with `baseGEPType`,
+/// especially check against LLVMStructTypes nested within.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, unsigned indexPos,
+ GEPIndicesAdaptor<ValueRange> indices,
+ function_ref<InFlightDiagnostic()> emitOpError) {
if (indexPos >= indices.size())
// Stop searching
- return llvm::Error::success();
+ return success();
- return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
- .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
+ return TypeSwitch<Type, LogicalResult>(baseGEPType)
+ .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
if (!indices[indexPos].is<IntegerAttr>())
- return llvm::make_error<GEPStaticIndexError>(indexPos);
+ return emitOpError() << "expected index " << indexPos
+ << " indexing a struct to be constant";
int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
ArrayRef<Type> elementTypes = structType.getBody();
if (gepIndex < 0 ||
static_cast<size_t>(gepIndex) >= elementTypes.size())
- return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
+ return emitOpError() << "index " << indexPos
+ << " indexing a struct is out of bounds";
// Instead of recursively going into every children types, we only
// dive into the one indexed by gepIndex.
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
- indices);
+ indices, emitOpError);
})
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
- LLVMArrayType>([&](auto containerType) -> llvm::Error {
+ LLVMArrayType>([&](auto containerType) -> LogicalResult {
return verifyStructIndices(containerType.getElementType(), indexPos + 1,
- indices);
+ indices, emitOpError);
})
- .Default(
- [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
+ .Default([&](auto otherType) -> LogicalResult {
+ return emitOpError()
+ << "type " << otherType << " cannot be indexed (index #"
+ << indexPos << ")";
+ });
}
-/// Driver function around `recordStructIndices`. Note that we always check
-/// from the second GEP index since the first one is always dynamic.
-static llvm::Error verifyStructIndices(Type baseGEPType,
- GEPIndicesAdaptor<ValueRange> indices) {
- return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices);
+/// Driver function around `verifyStructIndices`.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
+ function_ref<InFlightDiagnostic()> emitOpError) {
+ return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
}
LogicalResult LLVM::GEPOp::verify() {
@@ -763,11 +724,8 @@ LogicalResult LLVM::GEPOp::verify() {
return emitOpError("expected as many dynamic indices as specified in '")
<< getRawConstantIndicesAttrName().getValue() << "'";
- if (llvm::Error err =
- verifyStructIndices(getSourceElementType(), getIndices()))
- return emitOpError() << llvm::toString(std::move(err));
-
- return success();
+ return verifyStructIndices(getSourceElementType(), getIndices(),
+ [&] { return emitOpError(); });
}
Type LLVM::GEPOp::getSourceElementType() {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 2d0a68b8b6c942c..fe2f94454561a08 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1431,3 +1431,11 @@ llvm.func @invalid_variadic_call(%arg: i32) {
"llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
llvm.return
}
+
+// -----
+
+llvm.func @foo(%arg: !llvm.ptr) {
+ // expected-error at +1 {{type '!llvm.ptr' cannot be indexed (index #1)}}
+ %0 = llvm.getelementptr %arg[0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 32e3fed7e5485df..16eb28c629ee232 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -549,7 +549,7 @@ llvm.func @trivial_get_element_ptr() {
%1 = llvm.mlir.constant(2 : i64) : i64
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
- %4 = llvm.getelementptr %3[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, i8
+ %4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
llvm.return
@@ -563,9 +563,8 @@ llvm.func @nontrivial_get_element_ptr() {
%1 = llvm.mlir.constant(2 : i64) : i64
// CHECK: = llvm.alloca
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
- %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
- %4 = llvm.getelementptr %3[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, i8
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
+ %4 = llvm.getelementptr %2[1] : (!llvm.ptr) -> !llvm.ptr, i8
+ llvm.intr.lifetime.start 2, %2 : !llvm.ptr
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
llvm.return
}
@@ -579,7 +578,7 @@ llvm.func @dynamic_get_element_ptr() {
// CHECK: = llvm.alloca
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
- %4 = llvm.getelementptr %3[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
+ %4 = llvm.getelementptr %3[%0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
llvm.return
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
index b1d72b690595c31..f974bcd2e02aff6 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
@@ -6,12 +6,12 @@ func.func @ops(%arg0: i32) {
// Memory-related operations.
//
// CHECK-NEXT: %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr<f64>
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
// CHECK-NEXT: %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr<f64>
// CHECK-NEXT: llvm.store %[[VALUE]], %[[ALLOCA]] : !llvm.ptr<f64>
// CHECK-NEXT: %{{.*}} = llvm.bitcast %[[ALLOCA]] : !llvm.ptr<f64> to !llvm.ptr<i64>
%13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr<f64>
- %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+ %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
%15 = llvm.load %14 : !llvm.ptr<f64>
llvm.store %15, %13 : !llvm.ptr<f64>
%16 = llvm.bitcast %13 : !llvm.ptr<f64> to !llvm.ptr<i64>
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 1134027c6b6570e..ee724a482cfb514 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -50,11 +50,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
// Memory-related operations.
//
// CHECK-NEXT: %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr, i32) -> !llvm.ptr, f64
// CHECK-NEXT: %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f64
// CHECK-NEXT: llvm.store %[[VALUE]], %[[ALLOCA]] : f64, !llvm.ptr
%13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr
- %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+ %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr, i32) -> !llvm.ptr, f64
%15 = llvm.load %14 : !llvm.ptr -> f64
llvm.store %15, %13 : f64, !llvm.ptr
More information about the Mlir-commits mailing list