[Mlir-commits] [mlir] [mlir][LLVM] Verify too many indices in GEP verifier (PR #70174)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 25 00:58:22 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Markus Böck (zero9178)

<details>
<summary>Changes</summary>

The current verifier stops verification and returns success as soon as it encounters a type that cannot be indexed into. The correct behaviour in this case is to emit an error, as there are too many indices for the element type. Not doing so leads to a bad user experience, as an invalid GEP is likely to fail only later, during LLVM IR translation.

This PR implements the correct verification behaviour. Some upstream tests also had to be fixed, as they were creating invalid GEPs.

Fixes https://github.com/llvm/llvm-project/issues/70168

---
Full diff: https://github.com/llvm/llvm-project/pull/70174.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+44-25) 
- (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+8) 
- (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+4-5) 
- (modified) mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir (+2-2) 
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+2-2) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 95c04098d05fc2f..70045d028cc3214 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -704,50 +704,72 @@ struct GEPStaticIndexError
        << "to be constant";
   }
 };
+
+/// llvm::Error for non-static GEP index indexing a struct.
+struct GEPCannotIndexError
+    : public llvm::ErrorInfo<GEPCannotIndexError, GEPIndexError> {
+  static char ID;
+
+  using ErrorInfo::ErrorInfo;
+
+  void log(llvm::raw_ostream &os) const override {
+    os << "expected index " << indexPos << " indexing a struct "
+       << "to be constant";
+  }
+};
+
 } // end anonymous namespace
 
 char GEPIndexError::ID = 0;
 char GEPIndexOutOfBoundError::ID = 0;
 char GEPStaticIndexError::ID = 0;
-
-/// For the given `structIndices` and `indices`, check if they're complied
-/// with `baseGEPType`, especially check against LLVMStructTypes nested within.
-static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos,
-                                       GEPIndicesAdaptor<ValueRange> indices) {
+char GEPCannotIndexError::ID = 0;
+
+/// For the given `indices`, check if they comply with `baseGEPType`,
+// especially check against LLVMStructTypes nested within.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, unsigned indexPos,
+                    GEPIndicesAdaptor<ValueRange> indices,
+                    function_ref<InFlightDiagnostic()> emitOpError) {
   if (indexPos >= indices.size())
     // Stop searching
-    return llvm::Error::success();
+    return success();
 
-  return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
-      .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
+  return llvm::TypeSwitch<Type, LogicalResult>(baseGEPType)
+      .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
         if (!indices[indexPos].is<IntegerAttr>())
-          return llvm::make_error<GEPStaticIndexError>(indexPos);
+          return emitOpError() << "expected index " << indexPos
+                               << " indexing a struct to be constant";
 
         int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
         ArrayRef<Type> elementTypes = structType.getBody();
         if (gepIndex < 0 ||
             static_cast<size_t>(gepIndex) >= elementTypes.size())
-          return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
+          return emitOpError() << "index " << indexPos
+                               << " indexing a struct is out of bounds";
 
         // Instead of recursively going into every children types, we only
         // dive into the one indexed by gepIndex.
         return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
-                                   indices);
+                                   indices, emitOpError);
       })
       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
-            LLVMArrayType>([&](auto containerType) -> llvm::Error {
+            LLVMArrayType>([&](auto containerType) -> LogicalResult {
         return verifyStructIndices(containerType.getElementType(), indexPos + 1,
-                                   indices);
+                                   indices, emitOpError);
       })
-      .Default(
-          [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
+      .Default([&](auto otherType) -> LogicalResult {
+        return emitOpError()
+               << "type " << otherType << " cannot be indexed (index #"
+               << indexPos << ")";
+      });
 }
 
-/// Driver function around `recordStructIndices`. Note that we always check
-/// from the second GEP index since the first one is always dynamic.
-static llvm::Error verifyStructIndices(Type baseGEPType,
-                                       GEPIndicesAdaptor<ValueRange> indices) {
-  return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices);
+/// Driver function around `verifyStructIndices`.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
+                    function_ref<InFlightDiagnostic()> emitOpError) {
+  return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
 }
 
 LogicalResult LLVM::GEPOp::verify() {
@@ -763,11 +785,8 @@ LogicalResult LLVM::GEPOp::verify() {
     return emitOpError("expected as many dynamic indices as specified in '")
            << getRawConstantIndicesAttrName().getValue() << "'";
 
-  if (llvm::Error err =
-          verifyStructIndices(getSourceElementType(), getIndices()))
-    return emitOpError() << llvm::toString(std::move(err));
-
-  return success();
+  return verifyStructIndices(getSourceElementType(), getIndices(),
+                             [&] { return emitOpError(); });
 }
 
 Type LLVM::GEPOp::getSourceElementType() {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 2d0a68b8b6c942c..fe2f94454561a08 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1431,3 +1431,11 @@ llvm.func @invalid_variadic_call(%arg: i32)  {
   "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @foo(%arg: !llvm.ptr) {
+  // expected-error@+1 {{type '!llvm.ptr' cannot be indexed (index #1)}}
+  %0 = llvm.getelementptr %arg[0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+  llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 32e3fed7e5485df..16eb28c629ee232 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -549,7 +549,7 @@ llvm.func @trivial_get_element_ptr() {
   %1 = llvm.mlir.constant(2 : i64) : i64
   %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
-  %4 = llvm.getelementptr %3[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, i8
+  %4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8
   llvm.intr.lifetime.start 2, %3 : !llvm.ptr
   llvm.intr.lifetime.start 2, %4 : !llvm.ptr
   llvm.return
@@ -563,9 +563,8 @@ llvm.func @nontrivial_get_element_ptr() {
   %1 = llvm.mlir.constant(2 : i64) : i64
   // CHECK: = llvm.alloca
   %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
-  %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
-  %4 = llvm.getelementptr %3[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, i8
-  llvm.intr.lifetime.start 2, %3 : !llvm.ptr
+  %4 = llvm.getelementptr %2[1] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.intr.lifetime.start 2, %2 : !llvm.ptr
   llvm.intr.lifetime.start 2, %4 : !llvm.ptr
   llvm.return
 }
@@ -579,7 +578,7 @@ llvm.func @dynamic_get_element_ptr() {
   // CHECK: = llvm.alloca
   %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
-  %4 = llvm.getelementptr %3[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
+  %4 = llvm.getelementptr %3[%0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
   llvm.intr.lifetime.start 2, %3 : !llvm.ptr
   llvm.intr.lifetime.start 2, %4 : !llvm.ptr
   llvm.return
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
index b1d72b690595c31..f974bcd2e02aff6 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
@@ -6,12 +6,12 @@ func.func @ops(%arg0: i32) {
 // Memory-related operations.
 //
 // CHECK-NEXT:  %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr<f64>
-// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
 // CHECK-NEXT:  %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr<f64>
 // CHECK-NEXT:  llvm.store %[[VALUE]], %[[ALLOCA]] : !llvm.ptr<f64>
 // CHECK-NEXT:  %{{.*}} = llvm.bitcast %[[ALLOCA]] : !llvm.ptr<f64> to !llvm.ptr<i64>
   %13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr<f64>
-  %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+  %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
   %15 = llvm.load %14 : !llvm.ptr<f64>
   llvm.store %15, %13 : !llvm.ptr<f64>
   %16 = llvm.bitcast %13 : !llvm.ptr<f64> to !llvm.ptr<i64>
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 1134027c6b6570e..ee724a482cfb514 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -50,11 +50,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
 // Memory-related operations.
 //
 // CHECK-NEXT:  %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr
-// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr, i32) -> !llvm.ptr, f64
 // CHECK-NEXT:  %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f64
 // CHECK-NEXT:  llvm.store %[[VALUE]], %[[ALLOCA]] : f64, !llvm.ptr
   %13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr
-  %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+  %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr, i32) -> !llvm.ptr, f64
   %15 = llvm.load %14 : !llvm.ptr -> f64
   llvm.store %15, %13 : f64, !llvm.ptr
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/70174


More information about the Mlir-commits mailing list