[Mlir-commits] [mlir] 37da2a1 - [mlir][LLVM] Rework the API of GEPOp
    Markus Böck 
    llvmlistbot at llvm.org
       
    Fri Jul 29 09:32:42 PDT 2022
    
    
  
Author: Markus Böck
Date: 2022-07-29T18:22:54+02:00
New Revision: 37da2a141c6aa02e3ef86a9010ef2b2d793a2c66
URL: https://github.com/llvm/llvm-project/commit/37da2a141c6aa02e3ef86a9010ef2b2d793a2c66
DIFF: https://github.com/llvm/llvm-project/commit/37da2a141c6aa02e3ef86a9010ef2b2d793a2c66.diff
LOG: [mlir][LLVM] Rework the API of GEPOp
The implementation and API of GEP Op have gotten a bit convoluted over time. Issues with it are:
* Misleading naming: `indices` actually only contains the dynamic indices, not all of them. To get the number of indices you need to query the size of `structIndices`.
* Very difficult to iterate over all indices properly: One had to iterate over `structIndices`, check whether it contains the magic constant `kDynamicIndex`, and if it does, access the next value in `indices`, etc.
* Inconvenient to build: One either has to create lots of constant ops for every index or has to use an odd split, passing both a `ValueRange` and an `ArrayRef<int32_t>` filled with `kDynamicIndex` at the correct places.
* Implementation doing verification in the build method
and more.
This patch attempts to address all these issues via convenience classes and reworking the way GEP Op works:
* Adds `GEPArg` class which is a sum type of an `int32_t` and a `Value` and is used to have a single convenient, easy-to-use `ArrayRef<GEPArg>` in the builders instead of the previous `ValueRange` + `ArrayRef<int32_t>` builders.
* Adds `GEPIndicesAdaptor` which is a class used for easy random access and iteration over the indices of a GEP. It is generic and flexible enough to instead return, e.g., a corresponding `Attribute` for an operand inside of `fold`.
* Rename `structIndices` to `rawConstantIndices` and `indices` to `dynamicIndices`: `rawConstantIndices` signifies that one shouldn't access it directly as it is encoded, and `dynamicIndices` is more accurate and also frees up the `indices` name.
* Add `getIndices` returning a `GEPIndicesAdaptor` to easily iterate over the GEP Op's indices.
* Move the verification/asserts out of the build method and into the `verify` method emitting op error messages.
* Add convenient builder methods making use of `GEPArg`.
* Add canonicalizer turning dynamic indices with constant values into constant indices to have a canonical representation.
The only breaking change is for any users building GEPOps that have so far used the old `ValueRange` + `ArrayRef<int32_t>` builder as well as those using the generic syntax.
A follow-up patch then goes through the upstream code and makes use of the new `ArrayRef<GEPArg>` builder to remove a lot of code building constants for GEP indices.
Differential Revision: https://reviews.llvm.org/D130730
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/test/Dialect/LLVMIR/canonicalize.mlir
    mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 107701303099..cfdf4e31bdf6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/ThreadLocalCache.h"
+#include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -72,6 +73,36 @@ struct LLVMDialectImpl;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
 
+namespace mlir {
+namespace LLVM {
+template <typename Values>
+class GEPIndicesAdaptor;
+
+/// Bit-width of a 'GEPConstantIndex' within GEPArg.
+constexpr int kGEPConstantBitWidth = 29;
+/// Wrapper around a int32_t for use in a PointerUnion.
+using GEPConstantIndex =
+    llvm::PointerEmbeddedInt<int32_t, kGEPConstantBitWidth>;
+
+/// Class used for building a 'llvm.getelementptr'. A single instance represents
+/// a sum type that is either a 'Value' or a constant 'GEPConstantIndex' index.
+/// The former represents a dynamic index in a GEP operation, while the later is
+/// a constant index as is required for indices into struct types.
+class GEPArg : public PointerUnion<Value, GEPConstantIndex> {
+  using BaseT = PointerUnion<Value, GEPConstantIndex>;
+
+public:
+  /// Constructs a GEPArg with a constant index.
+  /*implicit*/ GEPArg(int32_t integer) : BaseT(integer) {}
+
+  /// Constructs a GEPArg with a dynamic index.
+  /*implicit*/ GEPArg(Value value) : BaseT(value) {}
+
+  using BaseT::operator=;
+};
+} // namespace LLVM
+} // namespace mlir
+
 ///// Ops /////
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"
@@ -82,6 +113,114 @@ struct LLVMDialectImpl;
 
 namespace mlir {
 namespace LLVM {
+
+/// Class used for convenient access and iteration over GEP indices.
+/// This class is templated to support not only retrieving the dynamic operands
+/// of a GEP operation, but also as an adaptor during folding or conversion to
+/// LLVM IR.
+///
+/// GEP indices may either be constant indices or dynamic indices. The
+/// 'rawConstantIndices' is specially encoded by GEPOp and contains either the
+/// constant index or the information that an index is a dynamic index.
+///
+/// When an access to such an index is made it is done through the
+/// 'DynamicRange' of this class. This way it can be used as getter in GEPOp via
+/// 'GEPIndicesAdaptor<ValueRange>' or during folding via
+/// 'GEPIndicesAdaptor<ArrayRef<Attribute>>'.
+template <typename DynamicRange>
+class GEPIndicesAdaptor {
+public:
+  /// Return type of 'operator[]' and the iterators 'operator*'. It is depended
+  /// upon the value type of 'DynamicRange'. If 'DynamicRange' contains
+  /// Attributes or subclasses thereof, then value_type is 'Attribute'. In
+  /// all other cases it is a pointer union between the value type of
+  /// 'DynamicRange' and IntegerAttr.
+  using value_type = std::conditional_t<
+      std::is_base_of<Attribute,
+                      llvm::detail::ValueOfRange<DynamicRange>>::value,
+      Attribute,
+      PointerUnion<IntegerAttr, llvm::detail::ValueOfRange<DynamicRange>>>;
+
+  /// Constructs a GEPIndicesAdaptor with the raw constant indices of a GEPOp
+  /// and the range that is indexed into for retrieving dynamic indices.
+  GEPIndicesAdaptor(DenseI32ArrayAttr rawConstantIndices, DynamicRange values)
+      : rawConstantIndices(rawConstantIndices), values(std::move(values)) {}
+
+  /// Returns the GEP index at the given position. Note that this operation has
+  /// a linear complexity in regards to the accessed position. To iterate over
+  /// all indices, use the iterators.
+  ///
+  /// This operation is invalid if the index is out of bounds.
+  value_type operator[](size_t index) const {
+    assert(index < size() && "index out of bounds");
+    return *std::next(begin(), index);
+  }
+
+  /// Returns whether the GEP index at the given position is a dynamic index.
+  bool isDynamicIndex(size_t index) const {
+    return rawConstantIndices[index] == GEPOp::kDynamicIndex;
+  }
+
+  /// Returns the amount of indices of the GEPOp.
+  size_t size() const { return rawConstantIndices.size(); }
+
+  /// Returns true if this GEPOp does not have any indices.
+  bool empty() const { return rawConstantIndices.empty(); }
+
+  class iterator
+      : public llvm::iterator_facade_base<iterator, std::forward_iterator_tag,
+                                          value_type, std::ptrdiff_t,
+                                          value_type *, value_type> {
+  public:
+    iterator(const GEPIndicesAdaptor *base,
+             ArrayRef<int32_t>::iterator rawConstantIter,
+             llvm::detail::IterOfRange<const DynamicRange> valuesIter)
+        : base(base), rawConstantIter(rawConstantIter), valuesIter(valuesIter) {
+    }
+
+    value_type operator*() const {
+      if (*rawConstantIter == GEPOp::kDynamicIndex)
+        return *valuesIter;
+
+      return IntegerAttr::get(
+          ElementsAttr::getElementType(base->rawConstantIndices),
+          *rawConstantIter);
+    }
+
+    iterator &operator++() {
+      if (*rawConstantIter == GEPOp::kDynamicIndex)
+        valuesIter++;
+      rawConstantIter++;
+      return *this;
+    }
+
+    bool operator==(const iterator &rhs) const {
+      return base == rhs.base && rawConstantIter == rhs.rawConstantIter &&
+             valuesIter == rhs.valuesIter;
+    }
+
+  private:
+    const GEPIndicesAdaptor *base;
+    ArrayRef<int32_t>::const_iterator rawConstantIter;
+    llvm::detail::IterOfRange<const DynamicRange> valuesIter;
+  };
+
+  /// Returns the begin iterator, iterating over all GEP indices.
+  iterator begin() const {
+    return iterator(this, rawConstantIndices.asArrayRef().begin(),
+                    values.begin());
+  }
+
+  /// Returns the end iterator, iterating over all GEP indices.
+  iterator end() const {
+    return iterator(this, rawConstantIndices.asArrayRef().end(), values.end());
+  }
+
+private:
+  DenseI32ArrayAttr rawConstantIndices;
+  DynamicRange values;
+};
+
 /// Create an LLVM global containing the string "value" at the module containing
 /// surrounding the insertion point of builder. Obtain the address of that
 /// global and use it to compute the address of the first character in the
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index afeec12b1b74..1eb861c3b8de 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -423,51 +423,75 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
 
 def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
-                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices,
-                   I32ElementsAttr:$structIndices,
+                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices,
+                   DenseI32ArrayAttr:$rawConstantIndices,
                    OptionalAttr<TypeAttr>:$elem_type);
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
   let skipDefaultBuilders = 1;
+
+  let description = [{
+    This operation mirrors LLVM IRs 'getelementptr' operation that is used to
+    perform pointer arithmetic.
+
+    Like in LLVM IR, it is possible to use both constants as well as SSA values
+    as indices. In the case of indexing within a structure, it is required to
+    either use constant indices directly, or supply a constant SSA value.
+
+    Examples:
+
+    ```mlir
+    // GEP with an SSA value offset
+    %0 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+
+    // GEP with a constant offset
+    %0 = llvm.getelementptr %1[3] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
+
+    // GEP with constant offsets into a structure
+    %0 = llvm.getelementptr %1[0, 1]
+       : (!llvm.ptr<struct(i32, f32)>) -> !llvm.ptr<f32>
+    ```
+  }];
+
   let builders = [
-    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
-               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
-    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
-               "ArrayRef<int32_t>":$structIndices,
-               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr,
                "ValueRange":$indices,
                CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
-	OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr,
-               "ValueRange":$indices,
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ArrayRef<GEPArg>":$indices,
                CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr,
-               "ValueRange":$indices, "ArrayRef<int32_t>":$structIndices,
-               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+               "ArrayRef<GEPArg>":$indices,
+               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
   ];
   let llvmBuilder = [{
     SmallVector<llvm::Value *> indices;
-    indices.reserve($structIndices.size());
-    unsigned operandIdx = 0;
-    for (int32_t structIndex : $structIndices.getValues<int32_t>()) {
-      if (structIndex == GEPOp::kDynamicIndex)
-        indices.push_back($indices[operandIdx++]);
+    indices.reserve($rawConstantIndices.size());
+    GEPIndicesAdaptor<decltype($dynamicIndices)>
+        gepIndices(op.getRawConstantIndicesAttr(), $dynamicIndices);
+    for (PointerUnion<IntegerAttr, llvm::Value*> valueOrAttr : gepIndices) {
+      if (llvm::Value* value = valueOrAttr.dyn_cast<llvm::Value*>())
+        indices.push_back(value);
       else
-        indices.push_back(builder.getInt32(structIndex));
+        indices.push_back(
+            builder.getInt32(valueOrAttr.get<IntegerAttr>().getInt()));
     }
     Type baseElementType = op.getSourceElementType();
     llvm::Type *elementType = moduleTranslation.convertType(baseElementType);
     $res = builder.CreateGEP(elementType, $base, indices);
   }];
   let assemblyFormat = [{
-    $base `[` custom<GEPIndices>($indices, $structIndices) `]` attr-dict
+    $base `[` custom<GEPIndices>($dynamicIndices, $rawConstantIndices) `]` attr-dict
     `:` functional-type(operands, results) (`,` $elem_type^)?
   }];
 
   let extraClassDeclaration = [{
-    constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
+    constexpr static int32_t kDynamicIndex = std::numeric_limits<int32_t>::min();
 
     /// Returns the type pointed to by the pointer argument of this GEP.
     Type getSourceElementType();
+
+    GEPIndicesAdaptor<ValueRange> getIndices();
   }];
   let hasFolder = 1;
   let hasVerifier = 1;
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index a88dc7af9985..29f377b214e6 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -755,7 +755,10 @@ class DenseArrayAttr : public DenseArrayBaseAttr {
 
   /// Implicit conversion to ArrayRef<T>.
   operator ArrayRef<T>() const;
-  ArrayRef<T> asArrayRef() { return ArrayRef<T>{*this}; }
+  ArrayRef<T> asArrayRef() const { return ArrayRef<T>{*this}; }
+
+  /// Random access to elements.
+  T operator[](std::size_t index) const { return asArrayRef()[index]; }
 
   /// Builder from ArrayRef<T>.
   static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
@@ -1017,6 +1020,14 @@ struct PointerLikeTypeTraits<mlir::StringAttr>
   }
 };
 
+template <>
+struct PointerLikeTypeTraits<mlir::IntegerAttr>
+    : public PointerLikeTypeTraits<mlir::Attribute> {
+  static inline mlir::IntegerAttr getFromVoidPointer(void *p) {
+    return mlir::IntegerAttr::getFromOpaquePointer(p);
+  }
+};
+
 template <>
 struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
     : public PointerLikeTypeTraits<mlir::Attribute> {
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f635c9c21434..e77f5a118c0e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -418,7 +418,159 @@ SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
 // Code for LLVM::GEPOp.
 //===----------------------------------------------------------------------===//
 
-constexpr int GEPOp::kDynamicIndex;
+constexpr int32_t GEPOp::kDynamicIndex;
+
+GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
+  return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
+                                       getDynamicIndices());
+}
+
+/// Returns the elemental type of any LLVM-compatible vector type or self.
+static Type extractVectorElementType(Type type) {
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return vectorType.getElementType();
+  if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
+    return scalableVectorType.getElementType();
+  if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
+    return fixedVectorType.getElementType();
+  return type;
+}
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ArrayRef<GEPArg> indices,
+                  ArrayRef<NamedAttribute> attributes) {
+  auto ptrType =
+      extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>();
+  assert(!ptrType.isOpaque() &&
+         "expected non-opaque pointer, provide elementType explicitly when "
+         "opaque pointers are used");
+  build(builder, result, resultType, ptrType.getElementType(), basePtr, indices,
+        attributes);
+}
+
+/// Destructures the 'indices' parameter into 'rawConstantIndices' and
+/// 'dynamicIndices', encoding the former in the process. In the process,
+/// dynamic indices which are used to index into a structure type are converted
+/// to constant indices when possible. To do this, the GEPs element type should
+/// be passed as first parameter.
+static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
+                               SmallVectorImpl<int32_t> &rawConstantIndices,
+                               SmallVectorImpl<Value> &dynamicIndices) {
+  for (const GEPArg &iter : indices) {
+    // If the thing we are currently indexing into is a struct we must turn
+    // any integer constants into constant indices. If this is not possible
+    // we don't do anything here. The verifier will catch it and emit a proper
+    // error. All other canonicalization is done in the fold method.
+    bool requiresConst = !rawConstantIndices.empty() &&
+                         currType.isa_and_nonnull<LLVMStructType>();
+    if (Value val = iter.dyn_cast<Value>()) {
+      APInt intC;
+      if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
+          intC.isSignedIntN(kGEPConstantBitWidth)) {
+        rawConstantIndices.push_back(intC.getSExtValue());
+      } else {
+        rawConstantIndices.push_back(GEPOp::kDynamicIndex);
+        dynamicIndices.push_back(val);
+      }
+    } else {
+      rawConstantIndices.push_back(iter.get<GEPConstantIndex>());
+    }
+
+    // Skip for very first iteration of this loop. First index does not index
+    // within the aggregates, but is just a pointer offset.
+    if (rawConstantIndices.size() == 1 || !currType)
+      continue;
+
+    currType =
+        TypeSwitch<Type, Type>(currType)
+            .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
+                  LLVMArrayType>([](auto containerType) {
+              return containerType.getElementType();
+            })
+            .Case([&](LLVMStructType structType) -> Type {
+              int64_t memberIndex = rawConstantIndices.back();
+              if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
+                                          structType.getBody().size())
+                return structType.getBody()[memberIndex];
+              return nullptr;
+            })
+            .Default(Type(nullptr));
+  }
+}
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
+                  ArrayRef<NamedAttribute> attributes) {
+  SmallVector<int32_t> rawConstantIndices;
+  SmallVector<Value> dynamicIndices;
+  destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
+
+  result.addTypes(resultType);
+  result.addAttributes(attributes);
+  result.addAttribute(getRawConstantIndicesAttrName(result.name),
+                      builder.getDenseI32ArrayAttr(rawConstantIndices));
+  if (extractVectorElementType(basePtr.getType())
+          .cast<LLVMPointerType>()
+          .isOpaque())
+    result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
+  result.addOperands(basePtr);
+  result.addOperands(dynamicIndices);
+}
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange indices,
+                  ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultType, basePtr, SmallVector<GEPArg>(indices),
+        attributes);
+}
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Type elementType, Value basePtr, ValueRange indices,
+                  ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultType, elementType, basePtr,
+        SmallVector<GEPArg>(indices), attributes);
+}
+
+static ParseResult
+parseGEPIndices(OpAsmParser &parser,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
+                DenseI32ArrayAttr &rawConstantIndices) {
+  SmallVector<int32_t> constantIndices;
+
+  auto idxParser = [&]() -> ParseResult {
+    int32_t constantIndex;
+    OptionalParseResult parsedInteger =
+        parser.parseOptionalInteger(constantIndex);
+    if (parsedInteger.hasValue()) {
+      if (failed(parsedInteger.getValue()))
+        return failure();
+      constantIndices.push_back(constantIndex);
+      return success();
+    }
+
+    constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
+    return parser.parseOperand(indices.emplace_back());
+  };
+  if (parser.parseCommaSeparatedList(idxParser))
+    return failure();
+
+  rawConstantIndices =
+      DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
+  return success();
+}
+
+static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
+                            OperandRange indices,
+                            DenseI32ArrayAttr rawConstantIndices) {
+  llvm::interleaveComma(
+      GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
+      [&](PointerUnion<IntegerAttr, Value> cst) {
+        if (Value val = cst.dyn_cast<Value>())
+          printer.printOperand(val);
+        else
+          printer << cst.get<IntegerAttr>().getInt();
+      });
+}
 
 namespace {
 /// Base class for llvm::Error related to GEP index.
@@ -467,69 +619,33 @@ char GEPIndexOutOfBoundError::ID = 0;
 char GEPStaticIndexError::ID = 0;
 
 /// For the given `structIndices` and `indices`, check if they're complied
-/// with `baseGEPType`, especially check against LLVMStructTypes nested within,
-/// and refine/promote struct index from `indices` to `updatedStructIndices`
-/// if the latter argument is not null.
-static llvm::Error
-recordStructIndices(Type baseGEPType, unsigned indexPos,
-                    ArrayRef<int32_t> structIndices, ValueRange indices,
-                    SmallVectorImpl<int32_t> *updatedStructIndices,
-                    SmallVectorImpl<Value> *remainingIndices) {
-  if (indexPos >= structIndices.size())
+/// with `baseGEPType`, especially check against LLVMStructTypes nested within.
+static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos,
+                                       GEPIndicesAdaptor<ValueRange> indices) {
+  if (indexPos >= indices.size())
     // Stop searching
     return llvm::Error::success();
 
-  int32_t gepIndex = structIndices[indexPos];
-  bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex;
-
-  unsigned dynamicIndexPos = indexPos;
-  if (!isStaticIndex)
-    dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1),
-                                  LLVM::GEPOp::kDynamicIndex) -
-                      1;
-
   return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
       .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
-        // We don't always want to refine the index (e.g. when performing
-        // verification), so we only refine when updatedStructIndices is not
-        // null.
-        if (!isStaticIndex && updatedStructIndices) {
-          // Try to refine.
-          APInt staticIndexValue;
-          isStaticIndex = matchPattern(indices[dynamicIndexPos],
-                                       m_ConstantInt(&staticIndexValue));
-          if (isStaticIndex) {
-            assert(staticIndexValue.getBitWidth() <= 64 &&
-                   llvm::isInt<32>(staticIndexValue.getLimitedValue()) &&
-                   "struct index can't fit within int32_t");
-            gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
-          }
-        }
-        if (!isStaticIndex)
+        if (!indices[indexPos].is<IntegerAttr>())
           return llvm::make_error<GEPStaticIndexError>(indexPos);
 
+        int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
         ArrayRef<Type> elementTypes = structType.getBody();
         if (gepIndex < 0 ||
             static_cast<size_t>(gepIndex) >= elementTypes.size())
           return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
 
-        if (updatedStructIndices)
-          (*updatedStructIndices)[indexPos] = gepIndex;
-
-        // Instead of recusively going into every children types, we only
+        // Instead of recursively going into every children types, we only
         // dive into the one indexed by gepIndex.
-        return recordStructIndices(elementTypes[gepIndex], indexPos + 1,
-                                   structIndices, indices, updatedStructIndices,
-                                   remainingIndices);
+        return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
+                                   indices);
       })
       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
             LLVMArrayType>([&](auto containerType) -> llvm::Error {
-        // Currently we don't refine non-struct index even if it's static.
-        if (remainingIndices)
-          remainingIndices->push_back(indices[dynamicIndexPos]);
-        return recordStructIndices(containerType.getElementType(), indexPos + 1,
-                                   structIndices, indices, updatedStructIndices,
-                                   remainingIndices);
+        return verifyStructIndices(containerType.getElementType(), indexPos + 1,
+                                   indices);
       })
       .Default(
           [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
@@ -537,122 +653,9 @@ recordStructIndices(Type baseGEPType, unsigned indexPos,
 
 /// Driver function around `recordStructIndices`. Note that we always check
 /// from the second GEP index since the first one is always dynamic.
-static llvm::Error
-findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices,
-                  ValueRange indices,
-                  SmallVectorImpl<int32_t> *updatedStructIndices = nullptr,
-                  SmallVectorImpl<Value> *remainingIndices = nullptr) {
-  if (remainingIndices)
-    // The first GEP index is always dynamic.
-    remainingIndices->push_back(indices[0]);
-  return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices,
-                             indices, updatedStructIndices, remainingIndices);
-}
-
-void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
-                  Value basePtr, ValueRange operands,
-                  ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, resultType, basePtr, operands,
-        SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes);
-}
-
-/// Returns the elemental type of any LLVM-compatible vector type or self.
-static Type extractVectorElementType(Type type) {
-  if (auto vectorType = type.dyn_cast<VectorType>())
-    return vectorType.getElementType();
-  if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
-    return scalableVectorType.getElementType();
-  if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
-    return fixedVectorType.getElementType();
-  return type;
-}
-
-void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
-                  Type elementType, Value basePtr, ValueRange indices,
-                  ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, resultType, elementType, basePtr, indices,
-        SmallVector<int32_t>(indices.size(), kDynamicIndex), attributes);
-}
-
-void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
-                  Value basePtr, ValueRange indices,
-                  ArrayRef<int32_t> structIndices,
-                  ArrayRef<NamedAttribute> attributes) {
-  auto ptrType =
-      extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>();
-  assert(!ptrType.isOpaque() &&
-         "expected non-opaque pointer, provide elementType explicitly when "
-         "opaque pointers are used");
-  build(builder, result, resultType, ptrType.getElementType(), basePtr, indices,
-        structIndices, attributes);
-}
-
-void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
-                  Type elementType, Value basePtr, ValueRange indices,
-                  ArrayRef<int32_t> structIndices,
-                  ArrayRef<NamedAttribute> attributes) {
-  SmallVector<Value> remainingIndices;
-  SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
-                                            structIndices.end());
-  if (llvm::Error err =
-          findStructIndices(elementType, structIndices, indices,
-                            &updatedStructIndices, &remainingIndices))
-    llvm::report_fatal_error(StringRef(llvm::toString(std::move(err))));
-
-  assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
-                                        updatedStructIndices, kDynamicIndex)) &&
-         "expected as many index operands as dynamic index attr elements");
-
-  result.addTypes(resultType);
-  result.addAttributes(attributes);
-  result.addAttribute("structIndices",
-                      builder.getI32TensorAttr(updatedStructIndices));
-  if (extractVectorElementType(basePtr.getType())
-          .cast<LLVMPointerType>()
-          .isOpaque())
-    result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
-  result.addOperands(basePtr);
-  result.addOperands(remainingIndices);
-}
-
-static ParseResult
-parseGEPIndices(OpAsmParser &parser,
-                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
-                DenseIntElementsAttr &structIndices) {
-  SmallVector<int32_t> constantIndices;
-
-  auto idxParser = [&]() -> ParseResult {
-    int32_t constantIndex;
-    OptionalParseResult parsedInteger =
-        parser.parseOptionalInteger(constantIndex);
-    if (parsedInteger.hasValue()) {
-      if (failed(parsedInteger.getValue()))
-        return failure();
-      constantIndices.push_back(constantIndex);
-      return success();
-    }
-
-    constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
-    return parser.parseOperand(indices.emplace_back());
-  };
-  if (parser.parseCommaSeparatedList(idxParser))
-    return failure();
-
-  structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
-  return success();
-}
-
-static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
-                            OperandRange indices,
-                            DenseIntElementsAttr structIndices) {
-  unsigned operandIdx = 0;
-  llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
-                        [&](int32_t cst) {
-                          if (cst == LLVM::GEPOp::kDynamicIndex)
-                            printer.printOperand(indices[operandIdx++]);
-                          else
-                            printer << cst;
-                        });
+static llvm::Error verifyStructIndices(Type baseGEPType,
+                                       GEPIndicesAdaptor<ValueRange> indices) {
+  return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices);
 }
 
 LogicalResult LLVM::GEPOp::verify() {
@@ -662,14 +665,14 @@ LogicalResult LLVM::GEPOp::verify() {
           getElemType())))
     return failure();
 
-  auto structIndexRange = getStructIndices().getValues<int32_t>();
-  // structIndexRange is a kind of iterator, which cannot be converted
-  // to ArrayRef directly.
-  SmallVector<int32_t> structIndices(structIndexRange.size());
-  for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size()))
-    structIndices[i] = structIndexRange[i];
-  if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices,
-                                          getIndices()))
+  if (static_cast<size_t>(
+          llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
+      getDynamicIndices().size())
+    return emitOpError("expected as many dynamic indices as specified in '")
+           << getRawConstantIndicesAttrName().getValue() << "'";
+
+  if (llvm::Error err =
+          verifyStructIndices(getSourceElementType(), getIndices()))
     return emitOpError() << llvm::toString(std::move(err));
 
   return success();
@@ -2697,10 +2700,49 @@ OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
+  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
+                                                 operands.drop_front());
+
   // gep %x:T, 0 -> %x
-  if (getBase().getType() == getType() && getIndices().size() == 1 &&
-      getStructIndices().size() == 1 && matchPattern(getIndices()[0], m_Zero()))
-    return getBase();
+  if (getBase().getType() == getType() && indices.size() == 1)
+    if (auto integer = indices[0].dyn_cast_or_null<IntegerAttr>())
+      if (integer.getValue().isZero())
+        return getBase();
+
+  // Canonicalize any dynamic indices of constant value to constant indices.
+  bool changed = false;
+  SmallVector<GEPArg> gepArgs;
+  for (auto &iter : llvm::enumerate(indices)) {
+    auto integer = iter.value().dyn_cast_or_null<IntegerAttr>();
+    // Constant indices can only be int32_t, so if integer does not fit we
+    // are forced to keep it dynamic, despite being a constant.
+    if (!indices.isDynamicIndex(iter.index()) || !integer ||
+        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
+
+      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
+      if (Value val = existing.dyn_cast<Value>())
+        gepArgs.emplace_back(val);
+      else
+        gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
+
+      continue;
+    }
+
+    changed = true;
+    gepArgs.emplace_back(integer.getInt());
+  }
+  if (changed) {
+    SmallVector<int32_t> rawConstantIndices;
+    SmallVector<Value> dynamicIndices;
+    destructureIndices(getSourceElementType(), gepArgs, rawConstantIndices,
+                       dynamicIndices);
+
+    getDynamicIndicesMutable().assign(dynamicIndices);
+    setRawConstantIndicesAttr(
+        DenseI32ArrayAttr::get(getContext(), rawConstantIndices));
+    return Value{*this};
+  }
+
   return {};
 }
 
diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 1d33ccd708d3..675fcbad5340 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -1072,24 +1072,23 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     Value basePtr = processValue(gep->getOperand(0));
     Type sourceElementType = processType(gep->getSourceElementType());
 
-    SmallVector<Value> indices;
-    for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) {
-      indices.push_back(processValue(operand));
-      if (!indices.back())
-        return failure();
-    }
     // Treat every indices as dynamic since GEPOp::build will refine those
     // indices into static attributes later. One small downside of this
     // approach is that many unused `llvm.mlir.constant` would be emitted
     // at first place.
-    SmallVector<int32_t> structIndices(indices.size(),
-                                       LLVM::GEPOp::kDynamicIndex);
+    SmallVector<GEPArg> indices;
+    for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) {
+      Value val = processValue(operand);
+      if (!val)
+        return failure();
+      indices.push_back(val);
+    }
 
     Type type = processType(inst->getType());
     if (!type)
       return failure();
-    instMap[inst] = b.create<GEPOp>(loc, type, sourceElementType, basePtr,
-                                    indices, structIndices);
+    instMap[inst] =
+        b.create<GEPOp>(loc, type, sourceElementType, basePtr, indices);
     return success();
   }
   case llvm::Instruction::InsertValue: {
diff  --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 459fa2f8d362..d989b6d4116b 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -102,8 +102,7 @@ llvm.func @fold_gep(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
 
 // CHECK-LABEL: fold_gep_neg
 // CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: %[[C:.*]] = arith.constant 0
-// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][%[[C]], 1]
+// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][0, 1]
 // CHECK-NEXT: llvm.return %[[RES]]
 llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
   %c0 = arith.constant 0 : i32
@@ -111,6 +110,17 @@ llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
   llvm.return %0 : !llvm.ptr
 }
 
+// CHECK-LABEL: fold_gep_canon
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][2]
+// CHECK-NEXT: llvm.return %[[RES]]
+llvm.func @fold_gep_canon(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c2 = arith.constant 2 : i32
+  %c = llvm.getelementptr %x[%c2] : (!llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
+  llvm.return %c : !llvm.ptr<i8>
+}
+
+
 // -----
 
 // Check that LLVM constants participate in cross-dialect constant folding. The
diff  --git a/mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir b/mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir
index 52282bc586fb..b71fa84e0509 100644
--- a/mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir
+++ b/mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir
@@ -6,7 +6,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
     // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32)
     %0 = llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.getelementptr %[[ARG0]][%[[C0]], 1, %[[ARG1]]]
-    %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {structIndices = dense<[-2147483648, 1, -2147483648]> : tensor<3xi32>} : (!llvm.ptr<struct<"my_struct", (struct<"sub_struct", (i32, i8)>, array<4 x i32>)>>, i32, i32) -> !llvm.ptr<i32>
+    %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {rawConstantIndices = [:i32 -2147483648, 1, -2147483648]} : (!llvm.ptr<struct<"my_struct", (struct<"sub_struct", (i32, i8)>, array<4 x i32>)>>, i32, i32) -> !llvm.ptr<i32>
     llvm.return
   }
 }
diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 70b7de516477..e5e15be3b90d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -146,6 +146,13 @@ func.func @gep_non_function_type(%pos : i64, %base : !llvm.ptr<f32>) {
 
 // -----
 
+func.func @gep_too_few_dynamic(%base : !llvm.ptr<f32>) {
+  // expected-error@+1 {{expected as many dynamic indices as specified in 'rawConstantIndices'}}
+  %1 = "llvm.getelementptr"(%base) {rawConstantIndices = [:i32 -2147483648]} : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
+}
+
+// -----
+
 func.func @load_non_llvm_type(%foo : memref<f32>) {
   // expected-error@+1 {{expected LLVM pointer type}}
   llvm.load %foo : memref<f32>
        
    
    
More information about the Mlir-commits
mailing list