[flang-commits] [flang] [mlir][LLVM] Verify too many indices in GEP verifier (PR #70174)

Markus Böck via flang-commits flang-commits at lists.llvm.org
Wed Oct 25 01:40:15 PDT 2023


https://github.com/zero9178 updated https://github.com/llvm/llvm-project/pull/70174

>From 03fb7aceae81227c3a64cf4b9ba1e53a69e46511 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Wed, 25 Oct 2023 09:54:16 +0200
Subject: [PATCH 1/3] [mlir][LLVM] Verify too many indices in GEP verifier

The current verifier stopped verification and returned success as soon as it encountered a type that cannot be indexed into. The correct behaviour in this case is to emit an error, as there are too many indices for the element type. Not doing so leads to a bad user experience, since an invalid GEP is then likely to fail only later, during LLVM IR translation.

This PR implements the correct verification behaviour. Some upstream tests also had to be fixed because they were creating invalid GEPs.

Fixes https://github.com/llvm/llvm-project/issues/70168
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 69 ++++++++++++-------
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  8 +++
 mlir/test/Dialect/LLVMIR/mem2reg.mlir         |  9 ++-
 .../LLVMIR/roundtrip-typed-pointers.mlir      |  4 +-
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       |  4 +-
 5 files changed, 60 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 95c04098d05fc2f..70045d028cc3214 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -704,50 +704,72 @@ struct GEPStaticIndexError
        << "to be constant";
   }
 };
+
+/// llvm::Error for non-static GEP index indexing a struct.
+struct GEPCannotIndexError
+    : public llvm::ErrorInfo<GEPCannotIndexError, GEPIndexError> {
+  static char ID;
+
+  using ErrorInfo::ErrorInfo;
+
+  void log(llvm::raw_ostream &os) const override {
+    os << "expected index " << indexPos << " indexing a struct "
+       << "to be constant";
+  }
+};
+
 } // end anonymous namespace
 
 char GEPIndexError::ID = 0;
 char GEPIndexOutOfBoundError::ID = 0;
 char GEPStaticIndexError::ID = 0;
-
-/// For the given `structIndices` and `indices`, check if they're complied
-/// with `baseGEPType`, especially check against LLVMStructTypes nested within.
-static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos,
-                                       GEPIndicesAdaptor<ValueRange> indices) {
+char GEPCannotIndexError::ID = 0;
+
+/// For the given `indices`, check if they comply with `baseGEPType`,
+// especially check against LLVMStructTypes nested within.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, unsigned indexPos,
+                    GEPIndicesAdaptor<ValueRange> indices,
+                    function_ref<InFlightDiagnostic()> emitOpError) {
   if (indexPos >= indices.size())
     // Stop searching
-    return llvm::Error::success();
+    return success();
 
-  return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
-      .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
+  return llvm::TypeSwitch<Type, LogicalResult>(baseGEPType)
+      .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
         if (!indices[indexPos].is<IntegerAttr>())
-          return llvm::make_error<GEPStaticIndexError>(indexPos);
+          return emitOpError() << "expected index " << indexPos
+                               << " indexing a struct to be constant";
 
         int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
         ArrayRef<Type> elementTypes = structType.getBody();
         if (gepIndex < 0 ||
             static_cast<size_t>(gepIndex) >= elementTypes.size())
-          return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
+          return emitOpError() << "index " << indexPos
+                               << " indexing a struct is out of bounds";
 
         // Instead of recursively going into every children types, we only
         // dive into the one indexed by gepIndex.
         return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
-                                   indices);
+                                   indices, emitOpError);
       })
       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
-            LLVMArrayType>([&](auto containerType) -> llvm::Error {
+            LLVMArrayType>([&](auto containerType) -> LogicalResult {
         return verifyStructIndices(containerType.getElementType(), indexPos + 1,
-                                   indices);
+                                   indices, emitOpError);
       })
-      .Default(
-          [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
+      .Default([&](auto otherType) -> LogicalResult {
+        return emitOpError()
+               << "type " << otherType << " cannot be indexed (index #"
+               << indexPos << ")";
+      });
 }
 
-/// Driver function around `recordStructIndices`. Note that we always check
-/// from the second GEP index since the first one is always dynamic.
-static llvm::Error verifyStructIndices(Type baseGEPType,
-                                       GEPIndicesAdaptor<ValueRange> indices) {
-  return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices);
+/// Driver function around `verifyStructIndices`.
+static LogicalResult
+verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
+                    function_ref<InFlightDiagnostic()> emitOpError) {
+  return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
 }
 
 LogicalResult LLVM::GEPOp::verify() {
@@ -763,11 +785,8 @@ LogicalResult LLVM::GEPOp::verify() {
     return emitOpError("expected as many dynamic indices as specified in '")
            << getRawConstantIndicesAttrName().getValue() << "'";
 
-  if (llvm::Error err =
-          verifyStructIndices(getSourceElementType(), getIndices()))
-    return emitOpError() << llvm::toString(std::move(err));
-
-  return success();
+  return verifyStructIndices(getSourceElementType(), getIndices(),
+                             [&] { return emitOpError(); });
 }
 
 Type LLVM::GEPOp::getSourceElementType() {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 2d0a68b8b6c942c..fe2f94454561a08 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1431,3 +1431,11 @@ llvm.func @invalid_variadic_call(%arg: i32)  {
   "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @foo(%arg: !llvm.ptr) {
+  // expected-error at +1 {{type '!llvm.ptr' cannot be indexed (index #1)}}
+  %0 = llvm.getelementptr %arg[0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+  llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 32e3fed7e5485df..16eb28c629ee232 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -549,7 +549,7 @@ llvm.func @trivial_get_element_ptr() {
   %1 = llvm.mlir.constant(2 : i64) : i64
   %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
-  %4 = llvm.getelementptr %3[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, i8
+  %4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8
   llvm.intr.lifetime.start 2, %3 : !llvm.ptr
   llvm.intr.lifetime.start 2, %4 : !llvm.ptr
   llvm.return
@@ -563,9 +563,8 @@ llvm.func @nontrivial_get_element_ptr() {
   %1 = llvm.mlir.constant(2 : i64) : i64
   // CHECK: = llvm.alloca
   %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
-  %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
-  %4 = llvm.getelementptr %3[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, i8
-  llvm.intr.lifetime.start 2, %3 : !llvm.ptr
+  %4 = llvm.getelementptr %2[1] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.intr.lifetime.start 2, %2 : !llvm.ptr
   llvm.intr.lifetime.start 2, %4 : !llvm.ptr
   llvm.return
 }
@@ -579,7 +578,7 @@ llvm.func @dynamic_get_element_ptr() {
   // CHECK: = llvm.alloca
   %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
-  %4 = llvm.getelementptr %3[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
+  %4 = llvm.getelementptr %3[%0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
   llvm.intr.lifetime.start 2, %3 : !llvm.ptr
   llvm.intr.lifetime.start 2, %4 : !llvm.ptr
   llvm.return
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
index b1d72b690595c31..f974bcd2e02aff6 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir
@@ -6,12 +6,12 @@ func.func @ops(%arg0: i32) {
 // Memory-related operations.
 //
 // CHECK-NEXT:  %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr<f64>
-// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
 // CHECK-NEXT:  %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr<f64>
 // CHECK-NEXT:  llvm.store %[[VALUE]], %[[ALLOCA]] : !llvm.ptr<f64>
 // CHECK-NEXT:  %{{.*}} = llvm.bitcast %[[ALLOCA]] : !llvm.ptr<f64> to !llvm.ptr<i64>
   %13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr<f64>
-  %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
+  %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
   %15 = llvm.load %14 : !llvm.ptr<f64>
   llvm.store %15, %13 : !llvm.ptr<f64>
   %16 = llvm.bitcast %13 : !llvm.ptr<f64> to !llvm.ptr<i64>
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 1134027c6b6570e..ee724a482cfb514 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -50,11 +50,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
 // Memory-related operations.
 //
 // CHECK-NEXT:  %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr
-// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+// CHECK-NEXT:  %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr, i32) -> !llvm.ptr, f64
 // CHECK-NEXT:  %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f64
 // CHECK-NEXT:  llvm.store %[[VALUE]], %[[ALLOCA]] : f64, !llvm.ptr
   %13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr
-  %14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
+  %14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr, i32) -> !llvm.ptr, f64
   %15 = llvm.load %14 : !llvm.ptr -> f64
   llvm.store %15, %13 : f64, !llvm.ptr
 

>From 1b52eaaa9a1fe51fd169b4e1502173539690486c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Wed, 25 Oct 2023 09:59:00 +0200
Subject: [PATCH 2/3] forgot to remove dead code

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 61 ----------------------
 1 file changed, 61 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 70045d028cc3214..06a3bd4561f7eb5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -664,67 +664,6 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
       });
 }
 
-namespace {
-/// Base class for llvm::Error related to GEP index.
-class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> {
-protected:
-  unsigned indexPos;
-
-public:
-  static char ID;
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-
-  explicit GEPIndexError(unsigned pos) : indexPos(pos) {}
-};
-
-/// llvm::Error for out-of-bound GEP index.
-struct GEPIndexOutOfBoundError
-    : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> {
-  static char ID;
-
-  using ErrorInfo::ErrorInfo;
-
-  void log(llvm::raw_ostream &os) const override {
-    os << "index " << indexPos << " indexing a struct is out of bounds";
-  }
-};
-
-/// llvm::Error for non-static GEP index indexing a struct.
-struct GEPStaticIndexError
-    : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> {
-  static char ID;
-
-  using ErrorInfo::ErrorInfo;
-
-  void log(llvm::raw_ostream &os) const override {
-    os << "expected index " << indexPos << " indexing a struct "
-       << "to be constant";
-  }
-};
-
-/// llvm::Error for non-static GEP index indexing a struct.
-struct GEPCannotIndexError
-    : public llvm::ErrorInfo<GEPCannotIndexError, GEPIndexError> {
-  static char ID;
-
-  using ErrorInfo::ErrorInfo;
-
-  void log(llvm::raw_ostream &os) const override {
-    os << "expected index " << indexPos << " indexing a struct "
-       << "to be constant";
-  }
-};
-
-} // end anonymous namespace
-
-char GEPIndexError::ID = 0;
-char GEPIndexOutOfBoundError::ID = 0;
-char GEPStaticIndexError::ID = 0;
-char GEPCannotIndexError::ID = 0;
-
 /// For the given `indices`, check if they comply with `baseGEPType`,
 // especially check against LLVMStructTypes nested within.
 static LogicalResult

>From 3a5d92491519450a82659b5ddabd88fb244c6ea7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Wed, 25 Oct 2023 10:32:47 +0200
Subject: [PATCH 3/3] address review comments

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 06a3bd4561f7eb5..7f5681e7bdc0592 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -665,7 +665,7 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
 }
 
 /// For the given `indices`, check if they comply with `baseGEPType`,
-// especially check against LLVMStructTypes nested within.
+/// especially check against LLVMStructTypes nested within.
 static LogicalResult
 verifyStructIndices(Type baseGEPType, unsigned indexPos,
                     GEPIndicesAdaptor<ValueRange> indices,
@@ -674,7 +674,7 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos,
     // Stop searching
     return success();
 
-  return llvm::TypeSwitch<Type, LogicalResult>(baseGEPType)
+  return TypeSwitch<Type, LogicalResult>(baseGEPType)
       .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
         if (!indices[indexPos].is<IntegerAttr>())
           return emitOpError() << "expected index " << indexPos



More information about the flang-commits mailing list