[Mlir-commits] [mlir] fb047e4 - [mlir][llvm] Type consistency transformations

Théo Degioanni llvmlistbot at llvm.org
Mon Jul 3 02:14:02 PDT 2023


Author: Théo Degioanni
Date: 2023-07-03T11:12:39+02:00
New Revision: fb047e4ae1b2a17ed068d560f94cc8e9fd94d0e6

URL: https://github.com/llvm/llvm-project/commit/fb047e4ae1b2a17ed068d560f94cc8e9fd94d0e6
DIFF: https://github.com/llvm/llvm-project/commit/fb047e4ae1b2a17ed068d560f94cc8e9fd94d0e6.diff

LOG: [mlir][llvm] Type consistency transformations

This revision introduces new rewrites to improve the type consistency of
a program expressed in the LLVM dialect.

Type consistency means that a given opaque pointer is consistently used
assuming the same pointee type, on a best-effort basis. The introduced
rewrites modify the program to improve type consistency while preserving
the same semantics. This is useful for two main reasons:

- Transformation passes in the LLVM dialect like SROA or Mem2Reg can
  analyze code better if type information and structure are used in a
  consistent manner. Opaque pointers make this difficult to enforce, but
  type consistency improvements increase the number of cases where
  analysis can reasonably pick up on transformable patterns.
- While LLVM IR is not particularly picky about inconsistent type uses,
  it may be of interest to lift LLVM IR into higher level dialects.
  Having more instances of consistent type information would help
  lifting into dialects that do care about consistent types.

In order to detect cases of inconsistent uses, operations returning an
LLVMPointer can implement the GetResultPtrElementType interface, which
provides a hint of the type the returned pointer's pointee should be
interpreted as, if such a hint is available. The provided rewrites then
use this hint to modify operations using the pointer so that they use the
hinted type when dealing with it.

Two transformations have been implemented in this revision:

- When a load/store uses a struct pointer directly to access the first
  element of the struct, introduce a GEP to the first element so the
  type structure is preserved.
- When a GEP statically indexes into a pointer with a base type
  inconsistent with the hint, try to find indices using the hint as a
  base type that would result in the same offset, and replace the GEP
  with this indexing.

More transformations are possible and I hope this is only a beginning
for this simplification effort.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D153973

Added: 
    mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
    mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
    mlir/test/Dialect/LLVMIR/type-consistency.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
    mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index dc5d86d8bac862..9f230bf0be87ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -167,6 +167,30 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
   ];
 }
 
+def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
+  let description = [{
+    An interface for operations that yield an LLVMPointer. Allows the
+    operation to provide the type of the element an LLVMPointer points to,
+    if known. This is only a hint as to how to interpret a given pointer,
+    reflecting how the current operation understands it.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        [{Returns the element type hint of the result
+                         LLVMPointer, if known. Returns nullptr if the
+                         result is not an LLVMPointer or if the
+                         element type is unknown.}],
+      /*returnType=*/  "Type",
+      /*methodName=*/  "getResultPtrElementType",
+      /*args=*/        (ins)
+      >
+  ];
+}
+
+
 //===----------------------------------------------------------------------===//
 // LLVM dialect type interfaces.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 43f5b633e251b6..993748fe629264 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -178,7 +178,8 @@ def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp<
 // Memory-related operations.
 def LLVM_AllocaOp : LLVM_Op<"alloca",
     [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
-     DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]>,
+     DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>,
+     DeclareOpInterfaceMethods<GetResultPtrElementType>]>,
   LLVM_MemOpPatterns {
   let arguments = (ins AnyInteger:$arraySize,
                    OptionalAttr<I64Attr>:$alignment,
@@ -239,7 +240,8 @@ def LLVM_AllocaOp : LLVM_Op<"alloca",
 def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
     DeclareOpInterfaceMethods<PromotableOpInterface>,
     DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
-    DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
+    DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+    DeclareOpInterfaceMethods<GetResultPtrElementType>]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
                    Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices,
                    DenseI32ArrayAttr:$rawConstantIndices,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
index eefbbcbf66a5ca..7e61bd2419d650 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
 #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index dcf6615f889fd3..e967cea6ae5186 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -30,6 +30,17 @@ def LLVMRequestCWrappers
   let constructor = "::mlir::LLVM::createRequestCWrappersPass()";
 }
 
+def LLVMTypeConsistency
+    : Pass<"llvm-type-consistency", "::mlir::LLVM::LLVMFuncOp"> {
+  let summary = "Rewrites to improve type consistency";
+  let description = [{
+    Set of rewrites to improve the consistency of types within an LLVM dialect
+    program. Operations that use pointers are adjusted so they interpret the
+    pointee type of each pointer as consistently as possible.
+  }];
+  let constructor = "::mlir::LLVM::createTypeConsistencyPass()";
+}
+
 def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {
   let summary = "Optimize NVVM IR";
   let constructor = "::mlir::NVVM::createOptimizeForTargetPass()";

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
new file mode 100644
index 00000000000000..469feedc477d41
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -0,0 +1,71 @@
+//===- TypeConsistency.h - Rewrites to improve type consistency -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Set of rewrites to improve the coherency of types within an LLVM dialect
+// program. This will adjust operations around a given pointer so they interpret
+// its pointee type as consistently as possible.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace LLVM {
+
+#define GEN_PASS_DECL_LLVMTYPECONSISTENCY
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+/// Creates a pass that adjusts operations operating on pointers so they
+/// interpret pointee types as consistently as possible.
+std::unique_ptr<Pass> createTypeConsistencyPass();
+
+/// Transforms uses of pointers to a whole struct to uses of pointers to the
+/// first element of a struct. This is achieved by inserting a GEP to the first
+/// element when possible.
+///
+/// Only the `LoadOp` and `StoreOp` instantiations are implemented, through
+/// the explicit specializations of `matchAndRewrite` declared below.
+template <class User>
+class AddFieldGetterToStructDirectUse : public OpRewritePattern<User> {
+public:
+  using OpRewritePattern<User>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(User user,
+                                PatternRewriter &rewriter) const override;
+};
+
+// The explicit specializations are declared here so that every translation
+// unit sees them before any use that would implicitly instantiate them.
+template <>
+LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
+    LoadOp load, PatternRewriter &rewriter) const;
+template <>
+LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
+    StoreOp store, PatternRewriter &rewriter) const;
+
+/// Canonicalizes GEPs of which the base type and the pointer's type hint do not
+/// match. This is done by replacing the original GEP with a GEP that uses the
+/// type hint as its base type, when an element of the hinted type aligns with
+/// the original GEP.
+class CanonicalizeAlignedGep : public OpRewritePattern<GEPOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GEPOp gep,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 31a2b93c9521df..f914c044d7e75a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -281,6 +281,12 @@ LogicalResult AllocaOp::verify() {
   return success();
 }
 
+Type AllocaOp::getResultPtrElementType() {
+  if (std::optional<Type> elemType = getElemType()) // Explicit element type.
+    return *elemType;
+  return getResult().getType().getElementType(); // Typed-pointer fallback.
+}
+
 //===----------------------------------------------------------------------===//
 // LLVM::BrOp
 //===----------------------------------------------------------------------===//
@@ -771,6 +777,42 @@ Type LLVM::GEPOp::getSourceElementType() {
       .getElementType();
 }
 
+Type GEPOp::getResultPtrElementType() {
+  // Per the interface contract, non-LLVMPointer results get no hint.
+  if (!isa<LLVMPointerType>(getType()))
+    return nullptr;
+
+  // Ensures all indices are static and fetches them.
+  SmallVector<IntegerAttr> indices;
+  for (auto index : getIndices()) {
+    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
+    if (!indexInt)
+      return nullptr;
+    indices.push_back(indexInt);
+  }
+
+  // The first index only steps over whole source elements; each following
+  // index selects a subelement of the type currently being indexed.
+  Type selectedType = getSourceElementType();
+  for (IntegerAttr index : llvm::drop_begin(indices)) {
+    // Ensure the structure of the type being indexed can be reasoned about.
+    // This includes rejecting any potential typed pointer.
+    auto destructurable =
+        llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
+    if (!destructurable)
+      return nullptr;
+
+    // Follow the type at the index the gep is accessing, making it the new type
+    // used for indexing.
+    selectedType = destructurable.getTypeAtIndex(index);
+    if (!selectedType)
+      return nullptr;
+  }
+
+  // Once all indices are consumed, this is the pointee type of the result.
+  return selectedType;
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 43d17fe24c3a07..8248a1e08a25ce 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -28,18 +28,11 @@ using namespace mlir;
 // Interfaces for AllocaOp
 //===----------------------------------------------------------------------===//
 
-static Type getAllocaElementType(LLVM::AllocaOp alloca) {
-  // This will become trivial once non-opaque pointers are gone.
-  return alloca.getElemType().has_value()
-             ? *alloca.getElemType()
-             : alloca.getResult().getType().getElementType();
-}
-
 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
   if (!getOperation()->getBlock()->isEntryBlock())
     return {};
 
-  return {MemorySlot{getResult(), getAllocaElementType(*this)}};
+  return {MemorySlot{getResult(), getResultPtrElementType()}};
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
@@ -68,7 +61,7 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
   if (!mlir::matchPattern(getArraySize(), m_One()))
     return {};
 
-  Type elemType = getAllocaElementType(*this);
+  Type elemType = getResultPtrElementType();
   auto destructurable = dyn_cast<DestructurableTypeInterface>(elemType);
   if (!destructurable)
     return {};
@@ -292,52 +285,10 @@ DeletionKind LLVM::GEPOp::removeBlockingUses(
   return DeletionKind::Delete;
 }
 
-/// Returns the type the resulting pointer of the GEP points to. If such a type
-/// is not clear, returns null type.
-static Type computeReachedGEPType(LLVM::GEPOp gep) {
-  if (gep.getIndices().empty())
-    return {};
-
-  // Ensures all indices are static and fetches them.
-  SmallVector<IntegerAttr> indices;
-  for (auto index : gep.getIndices()) {
-    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
-    if (!indexInt)
-      return {};
-    indices.push_back(indexInt);
-  }
-
-  // Check the pointer indexing only targets the first element.
-  if (indices[0].getInt() != 0)
-    return {};
-
-  // Set the initial type currently being used for indexing. This will be
-  // updated as the indices get walked over.
-  std::optional<Type> maybeSelectedType = gep.getElemType();
-  if (!maybeSelectedType)
-    return {};
-  Type selectedType = *maybeSelectedType;
-
-  // Follow the indexed elements in the gep.
-  for (IntegerAttr index : llvm::drop_begin(indices)) {
-    // Ensure the structure of the type being indexed can be reasoned about.
-    // This includes rejecting any potential typed pointer.
-    auto destructurable =
-        llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
-    if (!destructurable)
-      return {};
-
-    // Follow the type at the index the gep is accessing, making it the new type
-    // used for indexing.
-    Type field = destructurable.getTypeAtIndex(index);
-    if (!field)
-      return {};
-    selectedType = field;
-  }
-
-  // When there are no more indices, the type currently being used for indexing
-  // is the type of the value pointed at by the returned indexed pointer.
-  return selectedType;
+static bool isFirstIndexZero(LLVM::GEPOp gep) {
+  if (gep.getIndices().empty()) return false; // GEP with no index at all.
+  auto index = llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]);
+  return index && index.getInt() == 0;
+}
 
 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
@@ -346,7 +297,9 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
     return success();
   if (slot.elemType != getElemType())
     return failure();
-  Type reachedType = computeReachedGEPType(*this);
+  if (!isFirstIndexZero(*this))
+    return failure();
+  Type reachedType = getResultPtrElementType();
   if (!reachedType)
     return failure();
   mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
@@ -367,7 +320,9 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
 
   if (getBase() != slot.ptr || slot.elemType != getElemType())
     return false;
-  Type reachedType = computeReachedGEPType(*this);
+  if (!isFirstIndexZero(*this))
+    return false;
+  Type reachedType = getResultPtrElementType();
   if (!reachedType || getIndices().size() < 2)
     return false;
   auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);

diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index e5a9446dfd2fdb..fac33b29a511c8 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
   LegalizeForExport.cpp
   OptimizeForNVVM.cpp
   RequestCWrappers.cpp
+  TypeConsistency.cpp
 
   DEPENDS
   MLIRLLVMPassIncGen

diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
new file mode 100644
index 00000000000000..f696625f486e77
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -0,0 +1,411 @@
+//===- TypeConsistency.cpp - Rewrites to improve type consistency ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace LLVM {
+#define GEN_PASS_DEF_LLVMTYPECONSISTENCY
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+using namespace LLVM;
+
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
+/// Checks that a pointer value has a pointee type hint consistent with the
+/// expected type. Returns the type it actually hints to if it differs, or
+/// nullptr if the type is consistent or impossible to analyze.
+static Type isElementTypeInconsistent(Value addr, Type expectedType) {
+  auto defOp = dyn_cast_or_null<GetResultPtrElementType>(addr.getDefiningOp());
+  if (!defOp)
+    return nullptr;
+
+  Type elemType = defOp.getResultPtrElementType();
+  if (!elemType)
+    return nullptr;
+
+  if (elemType == expectedType)
+    return nullptr;
+
+  return elemType;
+}
+
+/// Returns true if `type` is an LLVM pointer or a vector of LLVM pointers.
+static bool isPointerOrPointerVector(Type type) {
+  if (LLVM::isCompatibleVectorType(type))
+    type = LLVM::getVectorElementType(type);
+  return isa<LLVMPointerType>(type);
+}
+
+/// Checks that two types are the same or can be bitcast into one another.
+/// Aggregates can never be bitcast. Distinct types involving pointers are
+/// rejected as well: `llvm.bitcast` cannot convert between pointers and
+/// non-pointers, nor between pointers of different address spaces.
+static bool areCastCompatible(DataLayout &layout, Type lhs, Type rhs) {
+  if (lhs == rhs)
+    return true;
+  if (isa<LLVMStructType, LLVMArrayType>(lhs) ||
+      isa<LLVMStructType, LLVMArrayType>(rhs))
+    return false;
+  if (isPointerOrPointerVector(lhs) || isPointerOrPointerVector(rhs))
+    return false;
+  return layout.getTypeSize(lhs) == layout.getTypeSize(rhs);
+}
+
+/// Returns true if `index` can be used as a constant GEP index. Constant GEP
+/// indices are stored as signed 32-bit integers, so larger values would wrap.
+static bool fitsConstantGEPIndex(uint64_t index) {
+  return index <= static_cast<uint64_t>(std::numeric_limits<int32_t>::max());
+}
+
+//===----------------------------------------------------------------------===//
+// AddFieldGetterToStructDirectUse
+//===----------------------------------------------------------------------===//
+
+/// Gets the type of the first subelement of `type` if `type` is destructurable,
+/// nullptr otherwise.
+static Type getFirstSubelementType(Type type) {
+  auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
+  if (!destructurable)
+    return nullptr;
+
+  // `getTypeAtIndex` may return null; that is forwarded as "no subelement".
+  return destructurable.getTypeAtIndex(
+      IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
+}
+
+/// Extracts a pointer to the first field of an `elemType` from the address
+/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
+/// instead.
+template <class MemOp>
+static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
+                                   Type elemType) {
+  PatternRewriter::InsertionGuard guard(rewriter);
+
+  rewriter.setInsertionPointAfterValue(op.getAddr());
+  SmallVector<GEPArg> firstTypeIndices{0, 0};
+
+  // Reuse the type of the original address so that the new pointer stays in
+  // the same address space.
+  Value properPtr =
+      rewriter.create<GEPOp>(op->getLoc(), op.getAddr().getType(), elemType,
+                             op.getAddr(), firstTypeIndices);
+
+  rewriter.updateRootInPlace(op,
+                             [&]() { op.getAddrMutable().assign(properPtr); });
+}
+
+template <>
+LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
+    LoadOp load, PatternRewriter &rewriter) const {
+  PatternRewriter::InsertionGuard guard(rewriter);
+
+  // Load from typed pointers are not supported.
+  if (!load.getAddr().getType().isOpaque())
+    return failure();
+
+  Type inconsistentElementType =
+      isElementTypeInconsistent(load.getAddr(), load.getType());
+  if (!inconsistentElementType)
+    return failure();
+  Type firstType = getFirstSubelementType(inconsistentElementType);
+  if (!firstType)
+    return failure();
+  DataLayout layout = DataLayout::closest(load);
+  if (!areCastCompatible(layout, firstType, load.getResult().getType()))
+    return failure();
+
+  insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
+
+  // If the load does not use the first type but a type that can be casted from
+  // it, add a bitcast and change the load type.
+  if (firstType != load.getResult().getType()) {
+    rewriter.setInsertionPointAfterValue(load.getResult());
+    BitcastOp bitcast = rewriter.create<BitcastOp>(
+        load->getLoc(), load.getResult().getType(), load.getResult());
+    rewriter.updateRootInPlace(load,
+                               [&]() { load.getResult().setType(firstType); });
+    rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
+                                  bitcast);
+  }
+
+  return success();
+}
+
+template <>
+LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
+    StoreOp store, PatternRewriter &rewriter) const {
+  PatternRewriter::InsertionGuard guard(rewriter);
+
+  // Store to typed pointers are not supported.
+  if (!store.getAddr().getType().isOpaque())
+    return failure();
+
+  Type inconsistentElementType =
+      isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
+  if (!inconsistentElementType)
+    return failure();
+  Type firstType = getFirstSubelementType(inconsistentElementType);
+  if (!firstType)
+    return failure();
+
+  DataLayout layout = DataLayout::closest(store);
+  // Check that the first field has the right type or can at least be bitcast
+  // to the right type.
+  if (!areCastCompatible(layout, firstType, store.getValue().getType()))
+    return failure();
+
+  insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
+
+  Value replaceValue = store.getValue();
+  if (firstType != store.getValue().getType()) {
+    rewriter.setInsertionPointAfterValue(store.getValue());
+    replaceValue = rewriter.create<BitcastOp>(store->getLoc(), firstType,
+                                              store.getValue());
+  }
+
+  rewriter.updateRootInPlace(
+      store, [&]() { store.getValueMutable().assign(replaceValue); });
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CanonicalizeAlignedGep
+//===----------------------------------------------------------------------===//
+
+/// Returns the amount of bytes the provided GEP elements will offset the
+/// pointer by. Returns nullopt if the offset could not be computed.
+static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, Type base,
+                                               ArrayRef<uint32_t> indices) {
+  // A GEP without any index does not move the pointer.
+  if (indices.empty())
+    return 0;
+
+  uint64_t offset = indices[0] * layout.getTypeSize(base);
+
+  Type currentType = base;
+  for (uint32_t index : llvm::drop_begin(indices)) {
+    bool shouldCancel =
+        TypeSwitch<Type, bool>(currentType)
+            .Case([&](LLVMArrayType arrayType) {
+              if (arrayType.getNumElements() <= index)
+                return true;
+              offset += index * layout.getTypeSize(arrayType.getElementType());
+              currentType = arrayType.getElementType();
+              return false;
+            })
+            .Case([&](LLVMStructType structType) {
+              ArrayRef<Type> body = structType.getBody();
+              if (body.size() <= index)
+                return true;
+              for (uint32_t i = 0; i < index; i++) {
+                if (!structType.isPacked())
+                  offset = llvm::alignTo(offset,
+                                         layout.getTypeABIAlignment(body[i]));
+                offset += layout.getTypeSize(body[i]);
+              }
+              // The selected field itself starts at an offset aligned to its
+              // own ABI alignment, which may skip padding after the previous
+              // field.
+              if (!structType.isPacked())
+                offset = llvm::alignTo(
+                    offset, layout.getTypeABIAlignment(body[index]));
+              currentType = body[index];
+              return false;
+            })
+            .Default([](Type) { return true; });
+
+    if (shouldCancel)
+      return std::nullopt;
+  }
+
+  return offset;
+}
+
+/// Fills in `equivalentIndicesOut` with GEP indices that would be equivalent to
+/// offsetting a pointer by `offset` bytes, assuming the GEP has `base` as base
+/// type.
+static LogicalResult
+findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset,
+                     SmallVectorImpl<GEPArg> &equivalentIndicesOut) {
+  uint64_t baseSize = layout.getTypeSize(base);
+  // A zero-sized base type cannot be stepped over (and would cause a division
+  // by zero below).
+  if (baseSize == 0)
+    return failure();
+  uint64_t rootIndex = offset / baseSize;
+  if (!fitsConstantGEPIndex(rootIndex))
+    return failure();
+  equivalentIndicesOut.push_back(static_cast<int32_t>(rootIndex));
+
+  uint64_t distanceToStart = rootIndex * baseSize;
+
+#ifndef NDEBUG
+  auto isWithinCurrentType = [&](Type currentType) {
+    return offset < distanceToStart + layout.getTypeSize(currentType);
+  };
+#endif
+
+  Type currentType = base;
+  while (distanceToStart < offset) {
+    // While an index that does not perfectly align with offset has not been
+    // reached...
+
+    assert(isWithinCurrentType(currentType));
+
+    bool shouldCancel =
+        TypeSwitch<Type, bool>(currentType)
+            .Case([&](LLVMArrayType arrayType) {
+              // Find which element of the array contains the offset.
+              uint64_t elemSize =
+                  layout.getTypeSize(arrayType.getElementType());
+              if (elemSize == 0)
+                return true;
+              uint64_t index = (offset - distanceToStart) / elemSize;
+              // Only continue if the element in question can be indexed using
+              // a constant GEP index.
+              if (!fitsConstantGEPIndex(index))
+                return true;
+              equivalentIndicesOut.push_back(static_cast<int32_t>(index));
+              distanceToStart += index * elemSize;
+
+              // Then, try to find where in the element the offset is. If the
+              // offset is exactly the beginning of the element, the loop is
+              // complete.
+              currentType = arrayType.getElementType();
+              return false;
+            })
+            .Case([&](LLVMStructType structType) {
+              ArrayRef<Type> body = structType.getBody();
+              uint32_t index = 0;
+
+              // Walk over the elements of the struct to find in which of them
+              // the offset is.
+              for (Type elem : body) {
+                uint64_t elemSize = layout.getTypeSize(elem);
+                if (!structType.isPacked()) {
+                  distanceToStart = llvm::alignTo(
+                      distanceToStart, layout.getTypeABIAlignment(elem));
+                  // If the offset is in padding, cancel the rewrite.
+                  if (offset < distanceToStart)
+                    return true;
+                }
+
+                if (offset < distanceToStart + elemSize) {
+                  // The offset is within this element, stop iterating the
+                  // struct and look within the current element.
+                  equivalentIndicesOut.push_back(static_cast<int32_t>(index));
+                  currentType = elem;
+                  return false;
+                }
+
+                // The offset is not within this element, continue walking over
+                // the struct.
+                distanceToStart += elemSize;
+                index++;
+              }
+
+              // The offset was supposed to be within this struct but is not.
+              // This can happen if the offset points into final padding.
+              // Anyway, nothing can be done.
+              return true;
+            })
+            .Default([](Type) {
+              // If the offset is within a type that cannot be split, no indices
+              // will yield this offset. This can happen if the offset is not
+              // perfectly aligned with a leaf type.
+              // TODO: support vectors.
+              return true;
+            });
+
+    if (shouldCancel)
+      return failure();
+  }
+
+  return success();
+}
+
+LogicalResult
+CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
+                                        PatternRewriter &rewriter) const {
+  // GEP of typed pointers are not supported.
+  std::optional<Type> maybeBaseType = gep.getElemType();
+  if (!maybeBaseType)
+    return failure();
+  Type baseType = *maybeBaseType;
+
+  Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType);
+  if (!typeHint)
+    return failure();
+
+  // Ensures all indices are static and fetches them. Negative indices are
+  // rejected: stored as unsigned 32-bit values they would wrap around and
+  // yield a bogus (huge) byte offset.
+  SmallVector<uint32_t> indices;
+  for (auto index : gep.getIndices()) {
+    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
+    if (!indexInt)
+      return failure();
+    int64_t indexValue = indexInt.getInt();
+    if (indexValue < 0 || indexValue > std::numeric_limits<uint32_t>::max())
+      return failure();
+    indices.push_back(static_cast<uint32_t>(indexValue));
+  }
+
+  DataLayout layout = DataLayout::closest(gep);
+  std::optional<uint64_t> desiredOffset =
+      gepToByteOffset(layout, baseType, indices);
+  if (!desiredOffset)
+    return failure();
+
+  SmallVector<GEPArg> newIndices;
+  if (failed(
+          findIndicesForOffset(layout, typeHint, *desiredOffset, newIndices)))
+    return failure();
+
+  // Reuse the original result type so that the address space (and vector
+  // shape, if any) of the produced pointer is preserved.
+  rewriter.replaceOpWithNewOp<GEPOp>(gep, gep.getType(), typeHint,
+                                     gep.getBase(), newIndices,
+                                     gep.getInbounds());
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type consistency pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LLVMTypeConsistencyPass
+    : public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
+  void runOnOperation() override {
+    RewritePatternSet rewritePatterns(&getContext());
+    rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
+    rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
+        &getContext());
+    rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
+    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> LLVM::createTypeConsistencyPass() {
+  return std::make_unique<LLVMTypeConsistencyPass>();
+}

diff  --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
new file mode 100644
index 00000000000000..f8cfca90921826
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -0,0 +1,150 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(llvm-type-consistency))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @same_address
+llvm.func @same_address(%val: i32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
+  %gep = llvm.getelementptr %alloca[8] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.store %val, %gep : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @same_address_keep_inbounds
+llvm.func @same_address_keep_inbounds(%val: i32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: = llvm.getelementptr inbounds %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
+  %gep = llvm.getelementptr inbounds %alloca[8] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.store %val, %gep : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field
+llvm.func @struct_store_instead_of_first_field(%val: i32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]] : i32
+  llvm.store %val, %alloca : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field_same_size
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @struct_store_instead_of_first_field_same_size(%val: f32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
+  // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
+  // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32
+  llvm.store %val, %alloca : f32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field
+llvm.func @struct_load_instead_of_first_field() -> i32 {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
+  // CHECK: %[[RES:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32
+  %loaded = llvm.load %alloca : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %loaded : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field_same_size
+llvm.func @struct_load_instead_of_first_field_same_size() -> f32 {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
+  // CHECK: %[[LOADED:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32
+  // CHECK: %[[RES:.*]] = llvm.bitcast %[[LOADED]] : i32 to f32
+  %loaded = llvm.load %alloca : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[RES]] : f32
+  llvm.return %loaded : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @index_in_final_padding
+llvm.func @index_in_final_padding(%val: i32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i8)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i8)> : (i32) -> !llvm.ptr
+  // CHECK: = llvm.getelementptr %[[ALLOCA]][7] : (!llvm.ptr) -> !llvm.ptr, i8
+  %gep = llvm.getelementptr %alloca[7] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.store %val, %gep : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @index_out_of_bounds
+llvm.func @index_out_of_bounds(%val: i32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+  // CHECK: = llvm.getelementptr %[[ALLOCA]][9] : (!llvm.ptr) -> !llvm.ptr, i8
+  %gep = llvm.getelementptr %alloca[9] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.store %val, %gep : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @index_in_padding
+llvm.func @index_in_padding(%val: i16) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i16, i32)> : (i32) -> !llvm.ptr
+  // CHECK: = llvm.getelementptr %[[ALLOCA]][2] : (!llvm.ptr) -> !llvm.ptr, i8
+  %gep = llvm.getelementptr %alloca[2] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.store %val, %gep : i16, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @index_not_in_padding_because_packed
+llvm.func @index_not_in_padding_because_packed(%val: i16) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", packed (i16, i32)> : (i32) -> !llvm.ptr
+  // CHECK: = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32)>
+  %gep = llvm.getelementptr %alloca[2] : (!llvm.ptr) -> !llvm.ptr, i8
+  llvm.store %val, %gep : i16, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @index_to_struct
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+llvm.func @index_to_struct(%val: i32) {
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)>
+  %alloca = llvm.alloca %c1 x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)>
+  // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"bar", (i32, i32)>
+  %gep = llvm.getelementptr %alloca[4] : (!llvm.ptr) -> !llvm.ptr, i8
+  // CHECK: llvm.store %[[ARG]], %[[GEP1]]
+  llvm.store %val, %gep : i32, !llvm.ptr
+  llvm.return
+}


        


More information about the Mlir-commits mailing list