[Mlir-commits] [mlir] f50cfc4 - [mlir] Require struct indices in LLVM::GEPOp to be constant

Alex Zinenko llvmlistbot at llvm.org
Fri Jan 7 00:56:14 PST 2022


Author: Alex Zinenko
Date: 2022-01-07T09:56:05+01:00
New Revision: f50cfc44d60bb6b12a2ee801a69b35fe7d6dbcf3

URL: https://github.com/llvm/llvm-project/commit/f50cfc44d60bb6b12a2ee801a69b35fe7d6dbcf3
DIFF: https://github.com/llvm/llvm-project/commit/f50cfc44d60bb6b12a2ee801a69b35fe7d6dbcf3.diff

LOG: [mlir] Require struct indices in LLVM::GEPOp to be constant

Recent commits made it possible to supply indices of LLVM dialect GEP
operations directly as constant attributes, ensuring they remain constant until
translation to LLVM IR. Make this mandatory for indices into LLVM struct types,
as LLVM IR requires; otherwise, the translation would assert when constructing
such IR.

For better compatibility with the MLIR-style operation construction interface,
allow GEP operations to be constructed programmatically with struct indices
supplied as Values that are defined by known constant operations.

Depends On D116758

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D116759

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3d6f7f30ec775..328ff9da3639a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -350,6 +350,9 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
     constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
   }];
   let hasFolder = 1;
+  let verifier = [{
+    return ::verify(*this);
+  }];
 }
 
 def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 56b14a77d2c6b..0f6c91ffd3f5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -360,6 +360,58 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
 // Code for LLVM::GEPOp.
 //===----------------------------------------------------------------------===//
 
+/// Appends to `indices` the GEP index positions at which an LLVMStructType
+/// nested in `type` would be indexed, `type` itself being at `currentIndex`;
+/// if `structSizes` is non-null, each such struct's member count is appended
+/// to it. Every struct member and container element type is explored, not
+/// just those on a particular GEP path; `visited` bounds recursive types.
+static void recordStructIndices(Type type, unsigned currentIndex,
+                                SmallVectorImpl<unsigned> &indices,
+                                SmallVectorImpl<unsigned> *structSizes,
+                                SmallPtrSet<Type, 4> &visited) {
+  if (visited.contains(type))
+    return;
+  // A type seen before is skipped even when reached at a different depth.
+  visited.insert(type);
+  // Types other than structs and the containers below are ignored.
+  llvm::TypeSwitch<Type>(type)
+      .Case<LLVMStructType>([&](LLVMStructType structType) {
+        indices.push_back(currentIndex);
+        if (structSizes)
+          structSizes->push_back(structType.getBody().size());
+        for (Type elementType : structType.getBody())
+          recordStructIndices(elementType, currentIndex + 1, indices,
+                              structSizes, visited);
+      })
+      .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
+            LLVMArrayType>([&](auto containerType) {
+        recordStructIndices(containerType.getElementType(), currentIndex + 1,
+                            indices, structSizes, visited);
+      });
+}
+
+/// Appends to `indices` the GEP index positions that may index an
+/// LLVMStructType nested in the pointee of `baseGEPType`, which must be an
+/// LLVMPointerType or a (builtin, LLVM fixed or LLVM scalable) vector thereof.
+/// If `structSizes` is non-null, matching struct member counts are appended.
+/// Positions are path-insensitive, may repeat, and are never below 1.
+static void
+findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices,
+                       SmallVectorImpl<unsigned> *structSizes = nullptr) {
+  Type type = baseGEPType;
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    type = vectorType.getElementType();
+  if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
+    type = scalableVectorType.getElementType();
+  if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
+    type = fixedVectorType.getElementType();
+  // The pointee is indexed by the second GEP index, i.e. position 1.
+  Type pointeeType = type.cast<LLVMPointerType>().getElementType();
+  SmallPtrSet<Type, 4> visited;
+  recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes,
+                      visited);
+}
+
 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                   Value basePtr, ValueRange operands,
                   ArrayRef<NamedAttribute> attributes) {
@@ -372,11 +424,64 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                   Value basePtr, ValueRange indices,
                   ArrayRef<int32_t> structIndices,
                   ArrayRef<NamedAttribute> attributes) {
+  assert(static_cast<size_t>(llvm::count(structIndices, kDynamicIndex)) ==
+             indices.size() &&
+         "expected as many index operands as dynamic index attr elements");
+
+  SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
+                                            structIndices.end());
+  // Positions where *some* struct nested in the base type could be indexed.
+  // This over-approximates the structs on the actual indexing path, so a
+  // listed position may legitimately hold a dynamic array or vector index.
+  SmallVector<unsigned> structRelatedPositions;
+  findKnownStructIndices(basePtr.getType(), structRelatedPositions);
+
+  // Flags the index operands folded into the attribute; a flag vector also
+  // tolerates the same position being listed more than once.
+  SmallVector<bool> foldedOperands(indices.size(), false);
+  for (unsigned pos : structRelatedPositions) {
+    // GEP may not be indexing as deep as some structs are located.
+    if (pos >= structIndices.size())
+      continue;
+
+    // If the index is already static, it's fine.
+    if (structIndices[pos] != kDynamicIndex)
+      continue;
+
+    // Find the corresponding operand: one per dynamic entry before `pos`.
+    unsigned operandPos =
+        std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
+                   kDynamicIndex);
+
+    // Move a constant operand that fits in 32 bits into the attribute. Any
+    // other operand is kept dynamic: the position may not index a struct on
+    // the actual path, and if it does, the verifier reports the non-constant
+    // struct index rather than relying on an assertion to catch it.
+    APInt staticIndexValue;
+    if (!matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)) ||
+        !staticIndexValue.isSignedIntN(/*N=*/32))
+      continue;
+    updatedStructIndices[pos] =
+        static_cast<int32_t>(staticIndexValue.getSExtValue());
+    foldedOperands[operandPos] = true;
+  }
+
+  SmallVector<Value> remainingIndices;
+  for (unsigned i = 0, e = indices.size(); i < e; ++i) {
+    if (!foldedOperands[i])
+      remainingIndices.push_back(indices[i]);
+  }
+
+  assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
+                                        updatedStructIndices, kDynamicIndex)) &&
+         "expected as many index operands as dynamic index attr elements");
+
   result.addTypes(resultType);
   result.addAttributes(attributes);
-  result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
+  result.addAttribute("structIndices",
+                      builder.getI32TensorAttr(updatedStructIndices));
   result.addOperands(basePtr);
-  result.addOperands(indices);
+  result.addOperands(remainingIndices);
 }
 
 static ParseResult
@@ -417,6 +516,56 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
                         });
 }
 
+/// Verifies that every GEP index selecting a member of an LLVM struct is a
+/// constant within the bounds of that struct, as LLVM IR requires.
+///
+/// The indexed type is walked along the path chosen by the indices: a struct
+/// index descends into exactly the member it selects, while array and vector
+/// indices descend into the element type. Checking only the structs on the
+/// actual path avoids rejecting dynamic array indices that merely share a
+/// depth with a struct in another member, and checks bounds against the
+/// struct that is really indexed rather than any struct at the same depth.
+LogicalResult verify(LLVM::GEPOp gepOp) {
+  // The base is a pointer or a vector of pointers; strip the vector first.
+  Type type = gepOp.getBase().getType();
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    type = vectorType.getElementType();
+  if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
+    type = scalableVectorType.getElementType();
+  if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
+    type = fixedVectorType.getElementType();
+
+  // Position 0 indexes through the pointer; the pointee is indexed from 1.
+  type = type.cast<LLVMPointerType>().getElementType();
+  auto structIndices = gepOp.getStructIndices().getValues<int32_t>();
+  unsigned numIndices = gepOp.getStructIndices().getNumElements();
+  for (unsigned pos = 1; pos < numIndices && type; ++pos) {
+    auto structType = type.dyn_cast<LLVMStructType>();
+    if (!structType) {
+      // Arrays and vectors: any index yields the element type. Other types
+      // cannot be indexed into, which ends the walk.
+      type = llvm::TypeSwitch<Type, Type>(type)
+                 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
+                       LLVMArrayType>([](auto containerType) -> Type {
+                   return containerType.getElementType();
+                 })
+                 .Default([](Type) { return Type(); });
+      continue;
+    }
+
+    int32_t staticIndex = structIndices[pos];
+    if (staticIndex == LLVM::GEPOp::kDynamicIndex)
+      return gepOp.emitOpError() << "expected index " << pos
+                                 << " indexing a struct to be constant";
+    ArrayRef<Type> body = structType.getBody();
+    if (staticIndex < 0 || static_cast<size_t>(staticIndex) >= body.size())
+      return gepOp.emitOpError()
+             << "index " << pos << " indexing a struct is out of bounds";
+    type = body[staticIndex];
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 321e6190c066c..be16764f5eb66 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -501,8 +501,7 @@ func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 // CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
 // CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, i64)>>
 // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
+// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], 3]
 // CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
 // CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
 // CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : i64

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 70ba47d2d176b..d790b2b14140d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -547,12 +547,11 @@ func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
 // CHECK: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
 // CHECK-SAME:   : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
 
-// CHECK: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : i64
 
 // CHECK: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
-// CHECK-SAME:   %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
-// CHECK-SAME:   i64)>>, i64, i32) -> !llvm.ptr<i64>
+// CHECK-SAME:   %[[C0_]], 2] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
+// CHECK-SAME:   i64)>>, i64) -> !llvm.ptr<i64>
 
 // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %{{.*}} : i64

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
index effc9befb2889..ea68dc9d57189 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
@@ -10,7 +10,7 @@ spv.func @access_chain() "None" {
   %0 = spv.Constant 1: i32
   %1 = spv.Variable : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>
   // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32, i32) -> !llvm.ptr<f32>
+  // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32) -> !llvm.ptr<f32>
   %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>, i32, i32
   spv.Return
 }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 41919ceaaed03..82561a75e13bc 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1234,3 +1234,19 @@ func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
   nvvm.cp.async.shared.global %arg0, %arg1, 32
   return
 }
+
+// -----
+// A runtime value cannot select a struct member; only constants may.
+func @gep_struct_variable(%arg0: !llvm.ptr<struct<(i32)>>, %arg1: i32, %arg2: i32) {
+  // expected-error @below {{op expected index 1 indexing a struct to be constant}}
+  llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr<struct<(i32)>>, i32, i32) -> !llvm.ptr<i32>
+  return
+}
+
+// -----
+// Index 2 selects into the nested struct<(i32, f32)>, which has 2 members.
+func @gep_out_of_bounds(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64) {
+  // expected-error @below {{index 2 indexing a struct is out of bounds}}
+  llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  return
+}

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 04a65f845148d..c4a5434144180 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1444,7 +1444,7 @@ llvm.mlir.global linkonce @take_self_address() : !llvm.struct<(i32, !llvm.ptr<i3
   %z32 = llvm.mlir.constant(0 : i32) : i32
   %0 = llvm.mlir.undef : !llvm.struct<(i32, !llvm.ptr<i32>)>
   %1 = llvm.mlir.addressof @take_self_address : !llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>
-  %2 = llvm.getelementptr %1[%z32, %z32] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32, i32) -> !llvm.ptr<i32>
+  %2 = llvm.getelementptr %1[%z32, 0] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32) -> !llvm.ptr<i32>
   %3 = llvm.insertvalue %z32, %0[0 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
   %4 = llvm.insertvalue %2, %3[1 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
   llvm.return %4 : !llvm.struct<(i32, !llvm.ptr<i32>)>


        


More information about the Mlir-commits mailing list