[Mlir-commits] [mlir] Reland "[mlir] Add strided metadata range dataflow analysis" (#163403)" (PR #163408)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 14 07:42:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

<details>
<summary>Changes</summary>

This relands commit aa8499863ad23350da0912d99d189f306d0ea139 after fixing shared lib builds.

---

Patch is 37.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163408.diff


18 Files Affected:

- (added) mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h (+54) 
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRef.h (+1) 
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+2) 
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1) 
- (modified) mlir/include/mlir/Interfaces/InferIntRangeInterface.h (+11-1) 
- (added) mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h (+145) 
- (added) mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td (+45) 
- (modified) mlir/lib/Analysis/CMakeLists.txt (+2) 
- (added) mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp (+127) 
- (modified) mlir/lib/Dialect/MemRef/IR/CMakeLists.txt (+2-1) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+59) 
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+16) 
- (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+19) 
- (added) mlir/lib/Interfaces/InferStridedMetadataInterface.cpp (+36) 
- (added) mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir (+67) 
- (modified) mlir/test/lib/Analysis/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp (+86) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
new file mode 100644
index 0000000000000..72ac2477435db
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
@@ -0,0 +1,54 @@
+//===- StridedMetadataRangeAnalysis.h - Strided metadata ranges -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// This lattice element represents the strided metadata of an SSA value.
+class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
+public:
+  using Lattice::Lattice;
+};
+
+/// Strided metadata range analysis determines the strided metadata ranges of
+/// SSA values using operations that implement `InferStridedMetadataOpInterface`.
+///
+/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
+/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
+/// loaded in the same solver context.
+class StridedMetadataRangeAnalysis
+    : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
+public:
+  StridedMetadataRangeAnalysis(DataFlowSolver &solver,
+                               int32_t indexBitwidth = 64);
+
+  /// At an entry point, we cannot reason about strided metadata ranges unless
+  /// the type also encodes the data, e.g. a memref with a static layout.
+  void setToEntryState(StridedMetadataRangeLattice *lattice) override;
+
+  /// Visit an operation. Invoke the transfer function on each operation that
+  /// implements `InferStridedMetadataOpInterface`.
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const StridedMetadataRangeLattice *> operands,
+                 ArrayRef<StridedMetadataRangeLattice *> results) override;
+
+private:
+  /// Index bitwidth to use when operating with the int-ranges (must be > 0).
+  int32_t indexBitwidth = 64;
+};
+} // namespace dataflow
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 30f33ed2fd1d6..69447f74ec403 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 89bd0f103d9f3..b39207fc30dd7 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/InferStridedMetadataInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -2085,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index a5feb592045c0..72ed046a1ba5d 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
 add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
+add_mlir_interface(InferStridedMetadataInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(MemOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..a6de3d1885eec 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -117,7 +117,8 @@ class IntegerValueRange {
   IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
 
   /// Create an integer value range lattice value.
-  IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+  explicit IntegerValueRange(
+      std::optional<ConstantIntRanges> value = std::nullopt)
       : value(std::move(value)) {}
 
   /// Whether the range is uninitialized. This happens when the state hasn't
@@ -167,6 +168,15 @@ using SetIntRangeFn =
 using SetIntLatticeFn =
     llvm::function_ref<void(Value, const IntegerValueRange &)>;
 
+/// Callback type used to query the integer range of an SSA value.
+using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
+
+/// Helper function to collect the integer range values of an array of op fold
+/// results; `getIntRange` presumably resolves SSA values (TODO: confirm).
+SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
+                                                 GetIntRangeFn getIntRange,
+                                                 int32_t indexBitwidth);
+
 class InferIntRangeInterface;
 
 namespace intrange::detail {
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
new file mode 100644
index 0000000000000..0c572e0196a03
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
@@ -0,0 +1,145 @@
+//===- InferStridedMetadataInterface.h - Strided metadata -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the strided metadata inference interface
+// defined in `InferStridedMetadataInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+/// A class that represents the strided metadata range information, including
+/// offsets, sizes, and strides as integer ranges.
+class StridedMetadataRange {
+public:
+  /// Default constructor creates uninitialized ranges.
+  StridedMetadataRange() = default;
+
+  /// Returns a ranked strided metadata range built from the given vectors.
+  static StridedMetadataRange
+  getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
+            SmallVectorImpl<ConstantIntRanges> &&sizes,
+            SmallVectorImpl<ConstantIntRanges> &&strides) {
+    return StridedMetadataRange(std::move(offsets), std::move(sizes),
+                                std::move(strides));
+  }
+
+  /// Returns a strided metadata range whose every entry is the maximal range.
+  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+                                           int32_t offsetsRank,
+                                           int32_t sizeRank,
+                                           int32_t stridedRank) {
+    return StridedMetadataRange(
+        SmallVector<ConstantIntRanges>(
+            offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
+        SmallVector<ConstantIntRanges>(
+            sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
+        SmallVector<ConstantIntRanges>(
+            stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
+  }
+
+  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+                                           int32_t rank) {
+    return getMaxRanges(indexBitwidth, 1, rank, rank);
+  }
+
+  /// Returns whether the metadata is uninitialized.
+  bool isUninitialized() const { return !offsets.has_value(); }
+
+  /// Get the offsets range; empty if the metadata is uninitialized.
+  ArrayRef<ConstantIntRanges> getOffsets() const {
+    return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
+  }
+  MutableArrayRef<ConstantIntRanges> getOffsets() {
+    return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
+  }
+
+  /// Get the sizes ranges.
+  ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
+  MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
+
+  /// Get the strides ranges.
+  ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
+  MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
+
+  /// Compare two strided metadata ranges element-wise.
+  bool operator==(const StridedMetadataRange &other) const {
+    return offsets == other.offsets && sizes == other.sizes &&
+           strides == other.strides;
+  }
+
+  /// Print the strided metadata range.
+  void print(raw_ostream &os) const;
+
+  /// Join two strided metadata ranges by taking the element-wise union of the
+  /// metadata. An uninitialized operand acts as the identity of the join.
+  static StridedMetadataRange join(const StridedMetadataRange &lhs,
+                                   const StridedMetadataRange &rhs) {
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+
+    // Helper function to compute the range union of constant ranges.
+    auto rangeUnion =
+        +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
+        -> ConstantIntRanges {
+      return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
+    };
+
+    // Get the elementwise range union. Note that `zip_equal` asserts if the
+    // lengths of the corresponding `lhs` and `rhs` vectors differ.
+    SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
+        llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
+    SmallVector<ConstantIntRanges> sizes =
+        llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
+    SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
+        llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
+
+    // Return the joined metadata.
+    return StridedMetadataRange(std::move(offsets), std::move(sizes),
+                                std::move(strides));
+  }
+
+private:
+  /// Create a strided metadata range with the given offset, sizes, and strides.
+  StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
+                       SmallVectorImpl<ConstantIntRanges> &&sizes,
+                       SmallVectorImpl<ConstantIntRanges> &&strides)
+      : offsets(std::move(offsets)), sizes(std::move(sizes)),
+        strides(std::move(strides)) {}
+
+  /// The offsets range; `std::nullopt` encodes the uninitialized state.
+  std::optional<SmallVector<ConstantIntRanges>> offsets;
+
+  /// The sizes ranges.
+  SmallVector<ConstantIntRanges> sizes;
+
+  /// The strides ranges.
+  SmallVector<ConstantIntRanges> strides;
+};
+
+/// Print `range` to `os` by forwarding to `StridedMetadataRange::print`.
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const StridedMetadataRange &range) {
+  range.print(os);
+  return os;
+}
+
+/// Callback function type for setting the strided metadata of a value.
+using SetStridedMetadataRangeFn =
+    function_ref<void(Value, const StridedMetadataRange &)>;
+} // namespace mlir
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
new file mode 100644
index 0000000000000..ee5b0942f683e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
@@ -0,0 +1,45 @@
+//===- InferStridedMetadataInterface.td - Strided MD -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for strided metadata range analysis
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferStridedMetadataOpInterface :
+    OpInterface<"InferStridedMetadataOpInterface"> {
+  let description = [{
+    Allows operations to participate in strided metadata analysis by providing
+    methods that allow them to specify bounds on offsets, sizes, and strides
+    of their result(s) given bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the strided metadata bounds on the results of this op given the
+      bounds on its operands. For each result of interest, the method should
+      call `setMetadata` with that result `Value` as an argument; block
+      arguments are not accepted by `StridedMetadataRangeAnalysis`.
+      The `operands` parameter contains the strided metadata ranges for all the
+      operands of the operation in order. The `getIntRange` callback returns
+      the int-range analysis result for a given value, and `indexBitwidth` is
+      the bitwidth to use when operating with the int-ranges.
+    }],
+    "void", "inferStridedMetadataRanges",
+    (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
+         "::mlir::GetIntRangeFn":$getIntRange,
+         "::mlir::SetStridedMetadataRangeFn":$setMetadata,
+         "int32_t":$indexBitwidth)>
+  ];
+}
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34309829..db10ebcf2c311 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
   DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/LivenessAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
+  DataFlow/StridedMetadataRangeAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
@@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
   MLIRDataLayoutInterfaces
   MLIRFunctionInterfaces
   MLIRInferIntRangeInterface
+  MLIRInferStridedMetadataInterface
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRPresburger
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000000000..01c9dafaddf10
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Strided metadata ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for strided metadata range
+// inference. It computes integer ranges for the offsets, sizes, and strides
+// of SSA values such as memrefs, using operations that implement
+// `InferStridedMetadataOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much information as possible.
+///
+/// Offsets and strides are signed quantities, so they are materialized with a
+/// signed `APInt`: with the unsigned constructor a negative static value would
+/// trip `APInt`'s "not an N-bit unsigned value" assertion whenever
+/// `indexBitwidth` is smaller than 64. Static sizes are non-negative, so the
+/// unsigned constructor suffices for them.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+  // TODO: generalize this method with a type interface.
+  auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+  // If not a memref or it's un-ranked, don't infer any metadata.
+  if (!mTy || !mTy.hasRank())
+    return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+  // Get the top state: one offset, plus one size and one stride per dimension.
+  auto metadata =
+      StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+  // Compute the offset and strides; keep the top state if they cannot be
+  // computed from the type.
+  int64_t offset = 0;
+  SmallVector<int64_t> strides;
+  if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+    return metadata;
+
+  // Refine the metadata with every static value known from the type.
+  if (!ShapedType::isDynamic(offset)) {
+    metadata.getOffsets()[0] = ConstantIntRanges::constant(
+        APInt(indexBitwidth, offset, /*isSigned=*/true));
+  }
+  for (auto &&[size, range] :
+       llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+    if (ShapedType::isDynamic(size))
+      continue;
+    range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
+  }
+  for (auto &&[stride, range] :
+       llvm::zip_equal(strides, metadata.getStrides())) {
+    if (ShapedType::isDynamic(stride))
+      continue;
+    range = ConstantIntRanges::constant(
+        APInt(indexBitwidth, stride, /*isSigned=*/true));
+  }
+
+  return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+    DataFlowSolver &solver, int32_t indexBitwidth)
+    : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
+  // Only strictly positive widths are meaningful for integer ranges.
+  assert(indexBitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+    StridedMetadataRangeLattice *lattice) {
+  // Seed the lattice with whatever the anchor's type alone reveals, then
+  // notify dependents only if the join actually changed the lattice.
+  StridedMetadataRange entryState =
+      getEntryStateImpl(lattice->getAnchor(), indexBitwidth);
+  ChangeResult changed = lattice->join(entryState);
+  propagateIfChanged(lattice, changed);
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+    ArrayRef<StridedMetadataRangeLattice *> results) {
+  auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
+
+  // Bail if we cannot reason about the op.
+  if (!inferrable) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LDBG() << "Inferring metadata for: "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+  // Helper function to retrieve int range values.
+  auto getIntRange = [&](Value value) -> IntegerValueRange {
+    auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
+        getProgramPointAfter(op), value);
+    return lattice ? lattice->getValue() : IntegerValueRange();
+  };
+
+  // Convert the arguments lattices to a vector.
+  SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+      operands, [](const StridedMetadataRangeLattice *lattice) {
+        return lattice->getValue();
+      });
+
+  // Callback to set metadata on a result.
+  auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
+    auto result = cast<OpResult>(v);
+    assert(llvm::is_contained(op->getResults(), result));
+    LDBG() << "- Inferred metadata: " << md;
+    StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
+    ChangeResult changed = lattice->join(md);
+    LDBG() << "- Joined metadata: " << lattice->getValue();
+    propagateIfChanged(lattice, changed);
+  };
+
+  // Infer the metadata.
+  inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallbac...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/163408


More information about the Mlir-commits mailing list