[Mlir-commits] [mlir] c10f8bd - [mlir][LLVM] Add `SplitGEP` type-consistency pattern

Markus Böck llvmlistbot at llvm.org
Mon Jul 10 01:51:29 PDT 2023


Author: Markus Böck
Date: 2023-07-10T10:45:42+02:00
New Revision: c10f8bd6c3ec30b68900b6fac2f0813c338d99e6

URL: https://github.com/llvm/llvm-project/commit/c10f8bd6c3ec30b68900b6fac2f0813c338d99e6
DIFF: https://github.com/llvm/llvm-project/commit/c10f8bd6c3ec30b68900b6fac2f0813c338d99e6.diff

LOG: [mlir][LLVM] Add `SplitGEP` type-consistency pattern

The goal of this pattern is to eliminate all GEPs that have more than two indices by splitting them into multiple GEPs.
The advantage of this change is that the resulting GEPs only ever index into one aggregate at a time. This enables handling sub-aggregates in other patterns and also produces IR whose pointer element types are easier to deduce.

As a minor note, `getResultPtrElementType` for `GEPOp` was rewritten since it did not properly handle dynamic indices. The way GEPOp is specified, the resulting pointer element type can *always* be deduced from its base type and indices.

Differential Revision: https://reviews.llvm.org/D154692

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index 417fb04a829dcb..b32ac56d7079c6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -79,6 +79,17 @@ class BitcastStores : public OpRewritePattern<StoreOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Splits GEPs with more than two indices into multiple GEPs with exactly
+/// two indices. The created GEPs are then guaranteed to index into only
+/// one aggregate at a time.
+class SplitGEP : public OpRewritePattern<GEPOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GEPOp gepOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 82cbee3d548d1f..708f916125b34e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -778,34 +778,27 @@ Type LLVM::GEPOp::getSourceElementType() {
 }
 
 Type GEPOp::getResultPtrElementType() {
-  // Ensures all indices are static and fetches them.
-  SmallVector<IntegerAttr> indices;
-  for (auto index : getIndices()) {
-    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
-    if (!indexInt)
-      return nullptr;
-    indices.push_back(indexInt);
-  }
-
   // Set the initial type currently being used for indexing. This will be
   // updated as the indices get walked over.
   Type selectedType = getSourceElementType();
 
   // Follow the indexed elements in the gep.
-  for (IntegerAttr index : llvm::drop_begin(indices)) {
-    // Ensure the structure of the type being indexed can be reasoned about.
-    // This includes rejecting any potential typed pointer.
-    auto destructurable =
-        llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
-    if (!destructurable)
-      return nullptr;
-
-    // Follow the type at the index the gep is accessing, making it the new type
-    // used for indexing.
-    Type field = destructurable.getTypeAtIndex(index);
-    if (!field)
-      return nullptr;
-    selectedType = field;
+  auto indices = getIndices();
+  for (GEPIndicesAdaptor<ValueRange>::value_type index :
+       llvm::drop_begin(indices)) {
+    // GEPs can only index into aggregates which can be structs or arrays.
+
+    // The resulting type if indexing into an array type is always the element
+    // type, regardless of index.
+    if (auto arrayType = dyn_cast<LLVMArrayType>(selectedType)) {
+      selectedType = arrayType.getElementType();
+      continue;
+    }
+
+    // The GEP verifier ensures that any indices into structs are static and
+    // that they refer to a field within the struct.
+    selectedType = cast<DestructurableTypeInterface>(selectedType)
+                       .getTypeAtIndex(cast<IntegerAttr>(index));
   }
 
   // When there are no more indices, the type currently being used for indexing

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index a324712734f53a..2cdc5d9fff84e7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -349,7 +349,9 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
   Type reachedType = getResultPtrElementType();
   if (!reachedType || getIndices().size() < 2)
     return false;
-  auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
+  auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
+  if (!firstLevelIndex)
+    return false;
   assert(slot.elementPtrs.contains(firstLevelIndex));
   if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
     return false;

diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index 3f581cf9bd7092..aad6e19f9f9c24 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -565,6 +565,46 @@ LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
   return success();
 }
 
+LogicalResult SplitGEP::matchAndRewrite(GEPOp gepOp,
+                                        PatternRewriter &rewriter) const {
+  FailureOr<Type> typeHint = getRequiredConsistentGEPType(gepOp);
+  if (succeeded(typeHint) || gepOp.getIndices().size() <= 2) {
+    // GEP is not yet canonical or indexes at most one aggregate; nothing to do.
+    return failure();
+  }
+
+  auto indexToGEPArg =
+      [](GEPIndicesAdaptor<ValueRange>::value_type index) -> GEPArg {
+    if (auto integerAttr = dyn_cast<IntegerAttr>(index))
+      return integerAttr.getValue().getSExtValue();
+    return cast<Value>(index);
+  };
+
+  GEPIndicesAdaptor<ValueRange> indices = gepOp.getIndices();
+
+  auto splitIter = std::next(indices.begin(), 2);
+
+  // Split off the first GEP using the first two indices.
+  auto subGepOp = rewriter.create<GEPOp>(
+      gepOp.getLoc(), gepOp.getType(), gepOp.getSourceElementType(),
+      gepOp.getBase(),
+      llvm::map_to_vector(llvm::make_range(indices.begin(), splitIter),
+                          indexToGEPArg),
+      gepOp.getInbounds());
+
+  // The second GEP indexes into the result pointer element type of the
+  // previous one, using a leading zero followed by all remaining indices. If
+  // this GEP still has more than two indices it'll be further split in
+  // subsequent pattern applications.
+  SmallVector<GEPArg> newIndices = {0};
+  llvm::transform(llvm::make_range(splitIter, indices.end()),
+                  std::back_inserter(newIndices), indexToGEPArg);
+  rewriter.replaceOpWithNewOp<GEPOp>(gepOp, gepOp.getType(),
+                                     subGepOp.getResultPtrElementType(),
+                                     subGepOp, newIndices, gepOp.getInbounds());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -580,6 +620,7 @@ struct LLVMTypeConsistencyPass
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
     rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
     rewritePatterns.add<BitcastStores>(&getContext());
+    rewritePatterns.add<SplitGEP>(&getContext());
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))

diff  --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 489dfcb8a072b7..7c345d454f264c 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -433,3 +433,63 @@ llvm.func @bitcast_insertion(%arg: i32) {
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @gep_split
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @gep_split(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x struct<"foo", (i64)>>
+  %1 = llvm.alloca %0 x !llvm.array<2 x struct<"foo", (i64)>> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64)>
+  // CHECK: llvm.store %[[ARG]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_subaggregate
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_subaggregate(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, struct<(i32, i32)>)> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
+
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<(i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @gep_result_ptr_type_dynamic
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @gep_result_ptr_type_dynamic(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x struct<"foo", (i64)>>
+  %1 = llvm.alloca %0 x !llvm.array<2 x struct<"foo", (i64)>> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, %arg, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, %[[ARG]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64)>
+  // CHECK: llvm.store %[[ARG]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}


        


More information about the Mlir-commits mailing list