[Mlir-commits] [mlir] [mlir][LLVM] Verify too many indices in GEP verifier (PR #70174)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 25 00:58:22 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Markus Böck (zero9178)
<details>
<summary>Changes</summary>
The current verifier stops verification and reports success as soon as it encounters a type that cannot be indexed into. The correct behaviour in this case is to emit an error, since there are more indices than the element type can accept. Not doing so leads to a poor user experience, as an invalid GEP is then likely to fail only later, during translation to LLVM IR.
This PR implements the correct verification behaviour. Some upstream tests also had to be fixed because they were creating invalid GEPs.
Fixes https://github.com/llvm/llvm-project/issues/70168
---
Full diff: https://github.com/llvm/llvm-project/pull/70174.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+44-25)
- (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+8)
- (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+4-5)
- (modified) mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir (+2-2)
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+2-2)
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 95c04098d05fc2f..70045d028cc3214 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -704,50 +704,60 @@ struct GEPStaticIndexError
<< "to be constant";
}
};
} // end anonymous namespace
char GEPIndexError::ID = 0;
char GEPIndexOutOfBoundError::ID = 0;
char GEPStaticIndexError::ID = 0;
-
-/// For the given `structIndices` and `indices`, check if they're complied
-/// with `baseGEPType`, especially check against LLVMStructTypes nested within.
-static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos,
- GEPIndicesAdaptor<ValueRange> indices) {
+
+/// For the given `indices`, check if they comply with `baseGEPType`,
+/// especially check against LLVMStructTypes nested within. Emits an error via
+/// `emitOpError` if an index is applied to a type that cannot be indexed into.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, unsigned indexPos,
+ GEPIndicesAdaptor<ValueRange> indices,
+ function_ref<InFlightDiagnostic()> emitOpError) {
if (indexPos >= indices.size())
// Stop searching
- return llvm::Error::success();
+ return success();
- return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
- .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
+ return llvm::TypeSwitch<Type, LogicalResult>(baseGEPType)
+ .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
if (!indices[indexPos].is<IntegerAttr>())
- return llvm::make_error<GEPStaticIndexError>(indexPos);
+ return emitOpError() << "expected index " << indexPos
+ << " indexing a struct to be constant";
int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
ArrayRef<Type> elementTypes = structType.getBody();
if (gepIndex < 0 ||
static_cast<size_t>(gepIndex) >= elementTypes.size())
- return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
+ return emitOpError() << "index " << indexPos
+ << " indexing a struct is out of bounds";
// Instead of recursively going into every children types, we only
// dive into the one indexed by gepIndex.
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
- indices);
+ indices, emitOpError);
})
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
- LLVMArrayType>([&](auto containerType) -> llvm::Error {
+ LLVMArrayType>([&](auto containerType) -> LogicalResult {
return verifyStructIndices(containerType.getElementType(), indexPos + 1,
- indices);
+ indices, emitOpError);
})
- .Default(
- [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
+ .Default([&](auto otherType) -> LogicalResult {
+ return emitOpError()
+ << "type " << otherType << " cannot be indexed (index #"
+ << indexPos << ")";
+ });
}
-/// Driver function around `recordStructIndices`. Note that we always check
-/// from the second GEP index since the first one is always dynamic.
-static llvm::Error verifyStructIndices(Type baseGEPType,
- GEPIndicesAdaptor<ValueRange> indices) {
- return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices);
+/// Driver function around `verifyStructIndices`. Checking starts at the
+/// second GEP index because the first one only offsets the base pointer and
+/// does not index into `baseGEPType`.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
+ function_ref<InFlightDiagnostic()> emitOpError) {
+ return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
}
LogicalResult LLVM::GEPOp::verify() {
@@ -763,11 +773,8 @@ LogicalResult LLVM::GEPOp::verify() {
return emitOpError("expected as many dynamic indices as specified in '")
<< getRawConstantIndicesAttrName().getValue() << "'";
- if (llvm::Error err =
- verifyStructIndices(getSourceElementType(), getIndices()))
- return emitOpError() << llvm::toString(std::move(err));
-
- return success();
+ return verifyStructIndices(getSourceElementType(), getIndices(),
+ [&] { return emitOpError(); });
}
Type LLVM::GEPOp::getSourceElementType() {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 2d0a68b8b6c942c..fe2f94454561a08 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1431,3 +1431,11 @@ llvm.func @invalid_variadic_call(%arg: i32) {
"llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
llvm.return
}
+
+// -----
+
+llvm.func @foo(%arg: !llvm.ptr) {
+ // expected-error@+1 {{type '!llvm.ptr' cannot be indexed (index #1)}}
+ %0 = llvm.getelementptr %arg[0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 32e3fed7e5485df..16eb28c629ee232 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -549,7 +549,7 @@ llvm.func @trivial_get_element_ptr() {
%1 = llvm.mlir.constant(2 : i64) : i64
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
- %4 = llvm.getelementptr %3[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, i8
+ %4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
llvm.return
@@ -563,9 +563,8 @@ llvm.func @nontrivial_get_element_ptr() {
%1 = llvm.mlir.constant(2 : i64) : i64
// CHECK: = llvm.alloca
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
- %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
- %4 = llvm.getelementptr %3[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, i8
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
+ %4 = llvm.getelementptr %2[1] : (!llvm.ptr) -> !llvm.ptr, i8
+ llvm.intr.lifetime.start 2, %2 : !llvm.ptr
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
llvm.return
}
@@ -579,7 +578,7 @@ llvm.func @dynamic_get_element_ptr() {
// CHECK: = llvm.alloca
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
- %4 = llvm.getelementptr %3[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
+ %4 = llvm.getelementptr %3[%0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
llvm.return
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
index b1d72b690595c31..f974bcd2e02aff6 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
@@ -6,12 +6,12 @@ func.func @ops(%arg0: i32) {
// Memory-related operations.
//
// CHECK-NEXT: %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr<f64>
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
// CHECK-NEXT: %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr<f64>
// CHECK-NEXT: llvm.store %[[VALUE]], %[[ALLOCA]] : !llvm.ptr<f64>
// CHECK-NEXT: %{{.*}} = llvm.bitcast %[[ALLOCA]] : !llvm.ptr<f64> to !llvm.ptr<i64>
%13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr<f64>
- %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+ %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
%15 = llvm.load %14 : !llvm.ptr<f64>
llvm.store %15, %13 : !llvm.ptr<f64>
%16 = llvm.bitcast %13 : !llvm.ptr<f64> to !llvm.ptr<i64>
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 1134027c6b6570e..ee724a482cfb514 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -50,11 +50,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
// Memory-related operations.
//
// CHECK-NEXT: %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr, i32) -> !llvm.ptr, f64
// CHECK-NEXT: %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f64
// CHECK-NEXT: llvm.store %[[VALUE]], %[[ALLOCA]] : f64, !llvm.ptr
%13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr
- %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+ %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr, i32) -> !llvm.ptr, f64
%15 = llvm.load %14 : !llvm.ptr -> f64
llvm.store %15, %13 : f64, !llvm.ptr
``````````
</details>
https://github.com/llvm/llvm-project/pull/70174
More information about the Mlir-commits
mailing list