[Mlir-commits] [mlir] [mlir] Add strided metadata range dataflow analysis (PR #161280)

Fabian Mora llvmlistbot at llvm.org
Thu Oct 9 04:29:52 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/161280

>From 670c4827b8685ab6b3ed9bc50946937430a99631 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 29 Sep 2025 21:15:54 +0000
Subject: [PATCH 1/2] add strided metadata range analysis

Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
 .../DataFlow/StridedMetadataRangeAnalysis.h   |  54 +++++++
 .../IR/InferStridedMetadataInterfaceImpl.h    |  25 +++
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   1 +
 .../mlir/Interfaces/InferIntRangeInterface.h  |   3 +-
 .../InferStridedMetadataInterface.h           | 148 ++++++++++++++++++
 .../InferStridedMetadataInterface.td          |  43 +++++
 mlir/lib/Analysis/CMakeLists.txt              |   1 +
 .../DataFlow/StridedMetadataRangeAnalysis.cpp | 127 +++++++++++++++
 mlir/lib/Dialect/MemRef/IR/CMakeLists.txt     |   4 +-
 .../IR/InferStridedMetadataInterfaceImpl.cpp  | 118 ++++++++++++++
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 .../InferStridedMetadataInterface.cpp         |  36 +++++
 mlir/lib/RegisterAllDialects.cpp              |   2 +
 .../test-strided-metadata-range-analysis.mlir |  67 ++++++++
 mlir/test/lib/Analysis/CMakeLists.txt         |   1 +
 .../TestStridedMetadataRangeAnalysis.cpp      |  86 ++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 17 files changed, 718 insertions(+), 2 deletions(-)
 create mode 100644 mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
 create mode 100644 mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
 create mode 100644 mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
 create mode 100644 mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
 create mode 100644 mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
 create mode 100644 mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp
 create mode 100644 mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
 create mode 100644 mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir
 create mode 100644 mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp

diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
new file mode 100644
index 0000000000000..72ac2477435db
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
@@ -0,0 +1,54 @@
+//===- StridedMetadataRangeAnalysis.h - Strided metadata ranges -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// Lattice element holding the `StridedMetadataRange` inferred for a value.
+class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
+public:
+  using Lattice<StridedMetadataRange>::Lattice;
+};
+
+/// Strided metadata range analysis computes the strided metadata ranges of
+/// SSA values through ops implementing `InferStridedMetadataOpInterface`.
+///
+/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
+/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
+/// loaded in the same solver context.
+class StridedMetadataRangeAnalysis
+    : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
+public:
+  explicit StridedMetadataRangeAnalysis(DataFlowSolver &solver,
+                                        int32_t indexBitwidth = 64);
+
+  /// At an entry point, we cannot reason about strided metadata ranges unless
+  /// the type also encodes the data. For example, a memref with static layout.
+  void setToEntryState(StridedMetadataRangeLattice *lattice) override;
+
+  /// Visit an operation. Invoke the transfer function on each operation that
+  /// implements `InferStridedMetadataOpInterface`.
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const StridedMetadataRangeLattice *> operands,
+                 ArrayRef<StridedMetadataRangeLattice *> results) override;
+
+private:
+  /// Index bitwidth to use when operating with the int-ranges.
+  int32_t indexBitwidth = 64;
+};
+} // namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
new file mode 100644
index 0000000000000..ca3bc78648ab2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
@@ -0,0 +1,25 @@
+//===- InferStridedMetadataInterfaceImpl.h - Impl. of infer strided md ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Register external models of `InferStridedMetadataOpInterface` for `memref`
+/// dialect ops. These models assume that the strided metadata of a ranked
+/// memref consists of exactly one offset, plus one size and one stride per
+/// dimension.
+void registerInferStridedMetadataOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index a5feb592045c0..72ed046a1ba5d 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
 add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
+add_mlir_interface(InferStridedMetadataInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(MemOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..a9e3e82acdc4f 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -117,7 +117,8 @@ class IntegerValueRange {
   IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
 
   /// Create an integer value range lattice value.
-  IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+  explicit IntegerValueRange(
+      std::optional<ConstantIntRanges> value = std::nullopt)
       : value(std::move(value)) {}
 
   /// Whether the range is uninitialized. This happens when the state hasn't
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
new file mode 100644
index 0000000000000..ee37db3b4380e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
@@ -0,0 +1,148 @@
+//===- InferStridedMetadataInterface.h - Strided MD inference ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the strided metadata inference interface
+// defined in `InferStridedMetadataInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+/// A class that represents the strided metadata range information, including
+/// offsets, sizes, and strides as integer ranges.
+class StridedMetadataRange {
+public:
+  /// Default constructor creates uninitialized ranges.
+  StridedMetadataRange() = default;
+
+  /// Returns a ranked strided metadata range.
+  static StridedMetadataRange
+  getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
+            SmallVectorImpl<ConstantIntRanges> &&sizes,
+            SmallVectorImpl<ConstantIntRanges> &&strides) {
+    return StridedMetadataRange(std::move(offsets), std::move(sizes),
+                                std::move(strides));
+  }
+
+  /// Returns a strided metadata range with maximum ranges.
+  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+                                           int32_t offsetsRank,
+                                           int32_t sizesRank,
+                                           int32_t stridesRank) {
+    return StridedMetadataRange(
+        SmallVector<ConstantIntRanges>(
+            offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
+        SmallVector<ConstantIntRanges>(
+            sizesRank, ConstantIntRanges::maxRange(indexBitwidth)),
+        SmallVector<ConstantIntRanges>(
+            stridesRank, ConstantIntRanges::maxRange(indexBitwidth)));
+  }
+
+  /// Returns maximum ranges for one offset, and `rank` sizes and strides.
+  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+                                           int32_t rank) {
+    return getMaxRanges(indexBitwidth, 1, rank, rank);
+  }
+
+  /// Returns whether the metadata is uninitialized.
+  bool isUninitialized() const { return !offsets.has_value(); }
+
+  /// Get the offsets range.
+  ArrayRef<ConstantIntRanges> getOffsets() const {
+    return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
+  }
+  MutableArrayRef<ConstantIntRanges> getOffsets() {
+    return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
+  }
+
+  /// Get the sizes ranges.
+  ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
+  MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
+
+  /// Get the strides ranges.
+  ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
+  MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
+
+  /// Compare two strided metadata ranges.
+  bool operator==(const StridedMetadataRange &other) const {
+    return offsets == other.offsets && sizes == other.sizes &&
+           strides == other.strides;
+  }
+
+  /// Print the strided metadata range.
+  void print(raw_ostream &os) const;
+
+  /// Join two strided metadata ranges, by taking the element-wise union of the
+  /// metadata.
+  static StridedMetadataRange join(const StridedMetadataRange &lhs,
+                                   const StridedMetadataRange &rhs) {
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+
+    // Helper to union a zipped pair of ranges. Taking the tuple of references
+    // generically avoids copying both ranges into a temporary tuple.
+    auto rangeUnion = [](const auto &lhsRhs) -> ConstantIntRanges {
+      return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
+    };
+
+    // Get the elementwise range union. Note that `zip_equal` will assert if
+    // sizes are not equal.
+    SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
+        llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
+    SmallVector<ConstantIntRanges> sizes =
+        llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
+    SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
+        llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
+
+    // Return the joined metadata.
+    return StridedMetadataRange(std::move(offsets), std::move(sizes),
+                                std::move(strides));
+  }
+
+private:
+  /// Create a strided metadata range with the given offset, sizes, and strides.
+  StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
+                       SmallVectorImpl<ConstantIntRanges> &&sizes,
+                       SmallVectorImpl<ConstantIntRanges> &&strides)
+      : offsets(std::move(offsets)), sizes(std::move(sizes)),
+        strides(std::move(strides)) {}
+
+  /// The offsets range.
+  std::optional<SmallVector<ConstantIntRanges>> offsets;
+
+  /// The sizes ranges.
+  SmallVector<ConstantIntRanges> sizes;
+
+  /// The strides ranges.
+  SmallVector<ConstantIntRanges> strides;
+};
+
+/// Stream operator for `StridedMetadataRange`, forwarding to `print`.
+inline raw_ostream &operator<<(raw_ostream &stream,
+                               const StridedMetadataRange &metadata) {
+  metadata.print(stream);
+  return stream;
+}
+
+/// Callback used to query the integer range known for a value.
+using GetIntRangeFn = llvm::function_ref<IntegerValueRange(Value)>;
+
+/// Callback used to report the strided metadata inferred for a value.
+using SetStridedMetadataRangeFn =
+    llvm::function_ref<void(Value, const StridedMetadataRange &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
new file mode 100644
index 0000000000000..87f92dcbf2615
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
@@ -0,0 +1,43 @@
+//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for strided metadata range analysis
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferStridedMetadataOpInterface :
+    OpInterface<"InferStridedMetadataOpInterface"> {
+  let description = [{
+    Allows operations to participate in strided metadata analysis by providing
+    methods that allow them to specify bounds on offsets, sizes, and strides
+    of their result(s) given bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the strided metadata bounds on the results of this op given the
+      bounds on its operands (`operands` has one entry per op operand). For
+      each result value or block argument, the method should call
+      `setMetadata` with that `Value` as an argument. `getIntRange` returns
+      the int-range analysis result for a given value, and `indexBitwidth` is
+      the bitwidth to use for index-typed integer ranges.
+    }],
+    "void", "inferStridedMetadataRanges",
+    (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
+         "::mlir::GetIntRangeFn":$getIntRange,
+         "::mlir::SetStridedMetadataRangeFn":$setMetadata,
+         "int32_t":$indexBitwidth)>
+  ];
+}
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34309829..bef189600d8e7 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
   DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/LivenessAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
+  DataFlow/StridedMetadataRangeAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000000000..01c9dafaddf10
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Strided metadata ranges ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for strided metadata range
+// inference. It computes integer ranges for the offsets, sizes, and strides of
+// memref-like values, using `InferStridedMetadataOpInterface` and the results
+// of the integer range analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much information as possible.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+  // TODO: generalize this method with a type interface.
+  auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+  // If not a memref or it's un-ranked, don't infer any metadata.
+  if (!mTy || !mTy.hasRank())
+    return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+  // Get the top state.
+  auto metadata =
+      StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+  // Compute the offset and strides.
+  int64_t offset = 0;
+  SmallVector<int64_t> strides;
+  if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+    return metadata;
+
+  // Static layout values are `int64_t`; build constants as signed APInts.
+  auto getConstant = [&](int64_t value) {
+    return ConstantIntRanges::constant(
+        APInt(indexBitwidth, value, /*isSigned=*/true));
+  };
+  // Refine the metadata if we know it from the type.
+  if (!ShapedType::isDynamic(offset))
+    metadata.getOffsets()[0] = getConstant(offset);
+  for (auto &&[size, range] :
+       llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+    if (!ShapedType::isDynamic(size))
+      range = getConstant(size);
+  }
+  for (auto &&[stride, range] :
+       llvm::zip_equal(strides, metadata.getStrides())) {
+    if (!ShapedType::isDynamic(stride))
+      range = getConstant(stride);
+  }
+  return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+    DataFlowSolver &solver, int32_t bitwidth)
+    : SparseForwardDataFlowAnalysis(solver), indexBitwidth(bitwidth) {
+  assert(bitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+    StridedMetadataRangeLattice *lattice) {
+  auto entry = getEntryStateImpl(lattice->getAnchor(), indexBitwidth);
+  propagateIfChanged(lattice, lattice->join(entry));
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+    ArrayRef<StridedMetadataRangeLattice *> results) {
+  // Without the interface nothing can be inferred: use the entry states.
+  auto iface = dyn_cast<InferStridedMetadataOpInterface>(op);
+  if (!iface) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LDBG() << "Inferring metadata for: "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+  // Snapshot the strided metadata currently known for every operand.
+  SmallVector<StridedMetadataRange> operandRanges;
+  operandRanges.reserve(operands.size());
+  for (const StridedMetadataRangeLattice *operandLattice : operands)
+    operandRanges.push_back(operandLattice->getValue());
+
+  // Query the integer range lattice of `value`, anchored after `op`.
+  auto queryIntRange = [&](Value value) -> IntegerValueRange {
+    const auto *intLattice = getOrCreateFor<IntegerValueRangeLattice>(
+        getProgramPointAfter(op), value);
+    if (!intLattice)
+      return IntegerValueRange();
+    return intLattice->getValue();
+  };
+
+  // Join inferred metadata into the lattice of the matching op result.
+  auto joinResult = [&](Value v, const StridedMetadataRange &md) {
+    auto opResult = cast<OpResult>(v);
+    assert(llvm::is_contained(op->getResults(), opResult));
+    LDBG() << "- Inferred metadata: " << md;
+    auto *resLattice = results[opResult.getResultNumber()];
+    ChangeResult change = resLattice->join(md);
+    LDBG() << "- Joined metadata: " << resLattice->getValue();
+    propagateIfChanged(resLattice, change);
+  };
+
+  iface.inferStridedMetadataRanges(operandRanges, queryIntRange, joinResult,
+                                   indexBitwidth);
+  return success();
+}
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a0121a3359..9707dc0cc64e9 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -2,10 +2,11 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefMemorySlot.cpp
   MemRefOps.cpp
+  InferStridedMetadataInterfaceImpl.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+  ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR
 
   DEPENDS
   MLIRMemRefOpsIncGen
@@ -18,6 +19,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRDialectUtils
   MLIRInferIntRangeCommon
   MLIRInferIntRangeInterface
+  MLIRInferStridedMetadataInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp
new file mode 100644
index 0000000000000..4bc4edc0357e8
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp
@@ -0,0 +1,118 @@
+//===- InferStridedMetadataInterfaceImpl.cpp - Impl. of infer strided md --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+/// Collect the integer range values on a set of op fold results. This function
+/// returns failure if any of the value ranges is still uninitialized.
+static FailureOr<SmallVector<ConstantIntRanges>>
+getIntValueRanges(ArrayRef<OpFoldResult> values, GetIntRangeFn getIntRange,
+                  int32_t indexBitwidth) {
+  SmallVector<ConstantIntRanges> ranges;
+  ranges.reserve(values.size());
+  for (OpFoldResult ofr : values) {
+    if (auto value = dyn_cast<Value>(ofr)) {
+      IntegerValueRange range = getIntRange(value);
+      // Bail if the int-range analysis has no information for `value` yet.
+      if (range.isUninitialized())
+        return failure();
+      ranges.push_back(range.getValue());
+      continue;
+    }
+
+    // Static entries are integer attributes; treat them as constants.
+    auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
+    ranges.emplace_back(ConstantIntRanges::constant(
+        attr.getValue().sextOrTrunc(indexBitwidth)));
+  }
+  return ranges;
+}
+
+namespace {
+/// Implementation of `InferStridedMetadataOpInterface` for the `memref.subview`
+/// operation.
+struct SubViewOpInterface
+    : public InferStridedMetadataOpInterface::ExternalModel<SubViewOpInterface,
+                                                            SubViewOp> {
+  void inferStridedMetadataRanges(Operation *op,
+                                  ArrayRef<StridedMetadataRange> ranges,
+                                  GetIntRangeFn getIntRange,
+                                  SetStridedMetadataRangeFn setMetadata,
+                                  int32_t indexBitwidth) const {
+    auto subViewOp = cast<SubViewOp>(op);
+
+    // Bail early if any offset, size, or stride range is not ready yet.
+    FailureOr<SmallVector<ConstantIntRanges>> offsetOperands =
+        getIntValueRanges(subViewOp.getMixedOffsets(), getIntRange,
+                          indexBitwidth);
+    if (failed(offsetOperands))
+      return;
+
+    FailureOr<SmallVector<ConstantIntRanges>> sizeOperands = getIntValueRanges(
+        subViewOp.getMixedSizes(), getIntRange, indexBitwidth);
+    if (failed(sizeOperands))
+      return;
+
+    FailureOr<SmallVector<ConstantIntRanges>> stridesOperands =
+        getIntValueRanges(subViewOp.getMixedStrides(), getIntRange,
+                          indexBitwidth);
+    if (failed(stridesOperands))
+      return;
+
+    const StridedMetadataRange &sourceRange =
+        ranges[subViewOp.getSourceMutable().getOperandNumber()];
+    if (sourceRange.isUninitialized())
+      return;
+
+    ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+
+    // Get the dropped dims.
+    llvm::SmallBitVector droppedDims = subViewOp.getDroppedDims();
+
+    // Compute the new offset, strides and sizes.
+    ConstantIntRanges offset = sourceRange.getOffsets()[0];
+    SmallVector<ConstantIntRanges> strides, sizes;
+
+    for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+      bool dropped = droppedDims.test(i);
+      // Accumulate `offsets[i] * srcStrides[i]`, also for dropped dims.
+      ConstantIntRanges off =
+          intrange::inferMul({(*offsetOperands)[i], srcStrides[i]});
+      offset = intrange::inferAdd({offset, off});
+
+      // Dropped dimensions contribute no size or stride to the result.
+      if (dropped)
+        continue;
+      // The result stride is `strides[i] * srcStrides[i]`.
+      strides.push_back(
+          intrange::inferMul({(*stridesOperands)[i], srcStrides[i]}));
+      // The result size is the subview size.
+      sizes.push_back((*sizeOperands)[i]);
+    }
+
+    setMetadata(subViewOp.getResult(),
+                StridedMetadataRange::getRanked(
+                    SmallVector<ConstantIntRanges>({std::move(offset)}),
+                    std::move(sizes), std::move(strides)));
+  }
+};
+} // namespace
+
+void mlir::memref::registerInferStridedMetadataOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *context, MemRefDialect * /*dialect*/) {
+    SubViewOp::attachInterface<SubViewOpInterface>(*context);
+  });
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 388de1c3e5abf..ad020eb431ee0 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
   FunctionInterfaces.cpp
   IndexingMapOpInterface.cpp
   InferIntRangeInterface.cpp
+  InferStridedMetadataInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   MemOpInterfaces.cpp
@@ -64,6 +65,7 @@ add_mlir_library(MLIRFunctionInterfaces
 
 add_mlir_interface_library(IndexingMapOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
+add_mlir_interface_library(InferStridedMetadataInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 
 add_mlir_library(MLIRLoopLikeInterface
diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
new file mode 100644
index 0000000000000..483e9f192cdcd
--- /dev/null
+++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
@@ -0,0 +1,36 @@
+//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <optional>
+
+using namespace mlir;
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc"
+
+void StridedMetadataRange::print(raw_ostream &os) const {
+  if (isUninitialized()) {
+    os << "strided_metadata<None>";
+    return;
+  }
+  // Each list is printed as `<label>[{r0}, {r1}, ...]`.
+  auto printList = [&os](StringRef label, ArrayRef<ConstantIntRanges> list) {
+    os << label << "[";
+    llvm::interleaveComma(list, os, [&os](const ConstantIntRanges &range) {
+      os << "{" << range << "}";
+    });
+    os << "]";
+  };
+
+  printList("strided_metadata<offset = ", *offsets);
+  printList(", sizes = ", sizes);
+  printList(", strides = ", strides);
+  os << ">";
+}
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 258fed135a3e5..36a16af026b88 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -53,6 +53,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
@@ -178,6 +179,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+  memref::registerInferStridedMetadataOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
   ml_program::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir
new file mode 100644
index 0000000000000..808c1c2bfd2a8
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-strided-metadata-range-analysis %s 2>&1 | FileCheck %s
+
+// Exercises the memref.subview implementation of the strided metadata range
+// inference. %0 is bounded to [11, 13] and %1 to [5, 7] by test.with_bounds,
+// so every expected range below can be derived by hand from the source layouts.
+func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>, %arg1: memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>>, %arg2: memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>>, %arg3: index, %arg4: index, %arg5: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index
+  %1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index
+
+  // Test subview with unknown sizes, and constant offsets and strides.
+  // CHECK: Op:  %[[SV0:.*]] = memref.subview
+  // CHECK-NEXT: result[0]: strided_metadata<
+  // CHECK-SAME: offset = [{unsigned : [1, 1] signed : [1, 1]}]
+  // CHECK-SAME: sizes = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+  // CHECK-SAME: strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}]
+  %subview = memref.subview %arg0[%c0, %c0, %c1] [%arg3, %arg4, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+  // Test a subview of a subview, with bounded dynamic offsets.
+  // The offset is 1 + (64 + 4 + 1) * [5, 7] = [346, 484].
+  // CHECK: Op:  %[[SV1:.*]] = memref.subview
+  // CHECK-NEXT: result[0]: strided_metadata<
+  // CHECK-SAME: offset = [{unsigned : [346, 484] signed : [346, 484]}]
+  // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}]
+  // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}]
+  %subview_0 = memref.subview %subview[%1, %1, %1] [%c2, %c2, %c2] [%0, %0, %0] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+  // Test a subview of a subview, with constant operands.
+  // CHECK: Op:  %[[SV2:.*]] = memref.subview
+  // CHECK-NEXT: result[0]: strided_metadata<
+  // CHECK-SAME: offset = [{unsigned : [368, 510] signed : [368, 510]}]
+  // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}]
+  // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}]
+  %subview_1 = memref.subview %subview_0[%c0, %c0, %c2] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+  // Test a rank-reducing subview.
+  // CHECK: Op:  %[[SV3:.*]] = memref.subview
+  // CHECK-NEXT: result[0]: strided_metadata<
+  // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+  // CHECK-SAME: sizes = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [16, 16] signed : [16, 16]}]
+  // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+  %subview_2 = memref.subview %arg1[%arg4, %arg4, %arg4, %arg4, %arg4] [1, 64, 1, 16, 1] [%arg5, %arg5, %arg5, %arg5, %arg5] : memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>> to memref<64x16xf32, strided<[?, ?], offset: ?>>
+
+  // Test a subview of a rank-reducing subview.
+  // CHECK: Op:  %[[SV4:.*]] = memref.subview
+  // CHECK-NEXT: result[0]: strided_metadata<
+  // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+  // CHECK-SAME: sizes = [{unsigned : [5, 7] signed : [5, 7]}]
+  // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+  %subview_3 = memref.subview %subview_2[%c0, %0] [1, %1] [%c1, %c2] : memref<64x16xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+
+  // Test a subview with mixed bounded and unbounded dynamic sizes.
+  // CHECK: Op:  %[[SV5:.*]] = memref.subview
+  // CHECK-NEXT: result[0]: strided_metadata<
+  // CHECK-SAME: offset = [{unsigned : [32, 32] signed : [32, 32]}]
+  // CHECK-SAME: sizes = [{unsigned : [11, 13] signed : [11, 13]}, {unsigned : [5, 7] signed : [5, 7]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+  // CHECK-SAME: strides = [{unsigned : [1, 1] signed : [1, 1]}, {unsigned : [64, 64] signed : [64, 64]}, {unsigned : [8, 8] signed : [8, 8]}]
+  %subview_4 = memref.subview %arg2[%c0, %c0, %c2] [%0, %1, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  return
+}
+
+// The pass only prints analysis results, so the module printed afterwards
+// still contains the six subviews above, in their original order.
+// CHECK:       func.func @memref_subview
+// CHECK:       %[[A0:.*]]: memref<8x16x4xf32, strided<[64, 4, 1]>>
+// CHECK:       %[[SV0]] = memref.subview %[[A0]]
+// CHECK-NEXT:  %[[SV1]] = memref.subview
+// CHECK-NEXT:  %[[SV2]] = memref.subview
+// CHECK-NEXT:  %[[SV3]] = memref.subview
+// CHECK-NEXT:  %[[SV4]] = memref.subview
+// CHECK-NEXT:  %[[SV5]] = memref.subview
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 91879981bffd2..c37671ade37b3 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(MLIRTestAnalysis
   DataFlow/TestDenseForwardDataFlowAnalysis.cpp
   DataFlow/TestLivenessAnalysis.cpp
   DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+  DataFlow/TestStridedMetadataRangeAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000000000..6ac09fdeed136
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,86 @@
+//===- TestStridedMetadataRangeAnalysis.cpp - Test strided md analysis ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Print the strided metadata ranges inferred for the results of `op` to `os`.
+/// A result is skipped when the solver has no lattice for it, when its
+/// metadata is uninitialized, or when it has no offsets, sizes and strides at
+/// all. Nothing is printed for `op` if every result is skipped.
+static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
+                                 raw_ostream &os) {
+  using ResultAndState =
+      std::pair<unsigned, const StridedMetadataRangeLattice *>;
+  SmallVector<ResultAndState> printable;
+
+  for (OpResult res : op->getResults()) {
+    const StridedMetadataRangeLattice *lattice =
+        solver.lookupState<StridedMetadataRangeLattice>(res);
+    if (!lattice)
+      continue;
+
+    const StridedMetadataRange &metadata = lattice->getValue();
+    // Uninitialized metadata carries nothing worth printing.
+    if (metadata.isUninitialized())
+      continue;
+
+    // Neither does metadata without any offsets, sizes or strides.
+    bool hasNoEntries = metadata.getOffsets().empty() &&
+                        metadata.getSizes().empty() &&
+                        metadata.getStrides().empty();
+    if (hasNoEntries)
+      continue;
+
+    printable.emplace_back(res.getResultNumber(), lattice);
+  }
+
+  if (printable.empty())
+    return;
+
+  os << "Op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
+  for (const ResultAndState &entry : printable)
+    os << "  result[" << entry.first << "]: " << entry.second->getValue()
+       << "\n";
+  os << "\n";
+}
+
+namespace {
+/// Test pass that runs the strided metadata range analysis on the target
+/// operation. The dead-code, constant-propagation and integer-range analyses
+/// are loaded before it. The pass then dumps the inferred metadata of every
+/// nested operation to stderr.
+struct TestStridedMetadataRangeAnalysisPass
+    : public PassWrapper<TestStridedMetadataRangeAnalysisPass,
+                         OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestStridedMetadataRangeAnalysisPass)
+
+  /// Command-line flag used to invoke this pass.
+  StringRef getArgument() const override {
+    return "test-strided-metadata-range-analysis";
+  }
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+
+    // The strided metadata analysis is loaded last; it obtains operand integer
+    // ranges from the int-range analysis loaded just before it.
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<IntegerRangeAnalysis>();
+    solver.load<StridedMetadataRangeAnalysis>();
+    if (failed(solver.initializeAndRun(root))) {
+      signalPassFailure();
+      return;
+    }
+
+    // `walk` visits `root` itself as well as every operation nested in it.
+    root->walk([&](Operation *nested) {
+      printAnalysisResults(solver, nested, llvm::errs());
+    });
+  }
+};
+} // end anonymous namespace
+
+namespace mlir::test {
+/// Make the `test-strided-metadata-range-analysis` pass available to tools
+/// that call this hook (mlir-opt does so from its test pass registration).
+void registerTestStridedMetadataRangeAnalysisPass() {
+  PassRegistration<TestStridedMetadataRangeAnalysisPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 6432fae615f88..88421800fed1e 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -151,6 +151,7 @@ void registerTestSliceAnalysisPass();
 void registerTestSPIRVCPURunnerPipeline();
 void registerTestSPIRVFuncSignatureConversion();
 void registerTestSPIRVVectorUnrolling();
+void registerTestStridedMetadataRangeAnalysisPass();
 void registerTestTensorCopyInsertionPass();
 void registerTestTensorLikeAndBufferLikePass();
 void registerTestTensorTransforms();
@@ -299,6 +300,7 @@ void registerTestPasses() {
   mlir::test::registerTestSPIRVCPURunnerPipeline();
   mlir::test::registerTestSPIRVFuncSignatureConversion();
   mlir::test::registerTestSPIRVVectorUnrolling();
+  mlir::test::registerTestStridedMetadataRangeAnalysisPass();
   mlir::test::registerTestTensorCopyInsertionPass();
   mlir::test::registerTestTensorLikeAndBufferLikePass();
   mlir::test::registerTestTensorTransforms();

>From 6957fc5a5e694071bbb1911b2574852b8a2f62a2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Thu, 9 Oct 2025 11:27:41 +0000
Subject: [PATCH 2/2] address reviewer comments

---
 .../IR/InferStridedMetadataInterfaceImpl.h    |  25 ----
 mlir/include/mlir/Dialect/MemRef/IR/MemRef.h  |   1 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |   2 +
 .../mlir/Interfaces/InferIntRangeInterface.h  |   9 ++
 .../InferStridedMetadataInterface.h           |   3 -
 .../InferStridedMetadataInterface.td          |   6 +-
 mlir/lib/Dialect/MemRef/IR/CMakeLists.txt     |   1 -
 .../IR/InferStridedMetadataInterfaceImpl.cpp  | 118 ------------------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  59 +++++++++
 .../lib/Interfaces/InferIntRangeInterface.cpp |  19 +++
 mlir/lib/RegisterAllDialects.cpp              |   2 -
 11 files changed, 94 insertions(+), 151 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
 delete mode 100644 mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
deleted file mode 100644
index ca3bc78648ab2..0000000000000
--- a/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
+++ /dev/null
@@ -1,25 +0,0 @@
-//===- InferStridedMetadataOpInterfaceImpl.h - Impl. of infer strided md --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
-#define MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
-
-namespace mlir {
-class DialectRegistry;
-
-namespace memref {
-/// Register the external models for the infer strided metadata op interface,
-/// for the `memref` dialect. This implementation assumes that the strided
-/// metadata of a ranked memref consists of one offset, and zero or more sizes
-/// and strides.
-void registerInferStridedMetadataOpInterfaceExternalModels(
-    DialectRegistry &registry);
-} // namespace memref
-} // namespace mlir
-
-#endif // MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 30f33ed2fd1d6..69447f74ec403 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 40b7d7e33d5c2..47013e66b7d30 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/InferStridedMetadataInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -2084,6 +2085,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index a9e3e82acdc4f..a6de3d1885eec 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -168,6 +168,15 @@ using SetIntRangeFn =
 using SetIntLatticeFn =
     llvm::function_ref<void(Value, const IntegerValueRange &)>;
 
+/// Helper callback type to get the integer range of a value. Callers should be
+/// prepared for the returned range to be uninitialized (nothing known yet).
+using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
+
+/// Helper function to collect the integer range values of an array of op fold
+/// results. Returns one range per element of `values`, in the same order:
+/// `Value` elements are resolved through `getIntRange` (so the range may be
+/// uninitialized), and integer attribute elements become constant ranges,
+/// sign-extended or truncated to `indexBitwidth` bits.
+SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
+                                                 GetIntRangeFn getIntRange,
+                                                 int32_t indexBitwidth);
+
 class InferIntRangeInterface;
 
 namespace intrange::detail {
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
index ee37db3b4380e..0c572e0196a03 100644
--- a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
@@ -135,9 +135,6 @@ inline raw_ostream &operator<<(raw_ostream &os,
   return os;
 }
 
-/// Callback function type to get the integer range of a value.
-using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
-
 /// Callback function type for setting the strided metadata of a value.
 using SetStridedMetadataRangeFn =
     function_ref<void(Value, const StridedMetadataRange &)>;
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
index 87f92dcbf2615..ee5b0942f683e 100644
--- a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
@@ -28,8 +28,10 @@ def InferStridedMetadataOpInterface :
     InterfaceMethod<[{
       Infer the strided metadata bounds on the results of this op given
       the bounds on its operands.
-      For each result value or block argument, the method should call
-      `setMetadata` with that `Value` as an argument.
+      For each result value or block argument of interest, the method should
+      call `setMetadata` with that `Value` as an argument.
+      The `operands` parameter contains the strided metadata ranges for all the
+      operands of the operation in order.
       The `getIntRange` callback is provided for obtaining the int-range
       analysis result for a given value.
     }],
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 9707dc0cc64e9..1382c7aceea79 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefMemorySlot.cpp
   MemRefOps.cpp
-  InferStridedMetadataInterfaceImpl.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp
deleted file mode 100644
index 4bc4edc0357e8..0000000000000
--- a/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp
+++ /dev/null
@@ -1,118 +0,0 @@
-//===- InferStridedMetadataInterfaceImpl.cpp - Impl. of infer strided md --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h"
-
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Interfaces/InferStridedMetadataInterface.h"
-#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
-
-using namespace mlir;
-using namespace mlir::memref;
-
-/// Collect the integer range values on a set of op fold results. This function
-/// returns failure if any of the int ranges couldn't be collected.
-static FailureOr<SmallVector<ConstantIntRanges>>
-getIntValueRanges(SmallVector<OpFoldResult> values, GetIntRangeFn getIntRange,
-                  int32_t indexBitwidth) {
-  SmallVector<ConstantIntRanges> ranges;
-  ranges.reserve(values.size());
-  for (OpFoldResult ofr : values) {
-    if (auto value = dyn_cast<Value>(ofr)) {
-      IntegerValueRange range = getIntRange(value);
-      // Bail if the range is not available.
-      if (range.isUninitialized())
-        return failure();
-      ranges.push_back(range.getValue());
-      continue;
-    }
-
-    // Create a constant range.
-    auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
-    ranges.emplace_back(ConstantIntRanges::constant(
-        attr.getValue().sextOrTrunc(indexBitwidth)));
-  }
-  return ranges;
-}
-
-namespace {
-/// Implementation of `InferStridedMetadataOpInterface` for the `memref.subview`
-/// operation.
-struct SubViewOpInterface
-    : public InferStridedMetadataOpInterface::ExternalModel<SubViewOpInterface,
-                                                            SubViewOp> {
-  void inferStridedMetadataRanges(Operation *op,
-                                  ArrayRef<StridedMetadataRange> ranges,
-                                  GetIntRangeFn getIntRange,
-                                  SetStridedMetadataRangeFn setMetadata,
-                                  int32_t indexBitwidth) const {
-    auto subViewOp = cast<SubViewOp>(op);
-
-    // Bail early if any of the operands metadata is not ready:
-    FailureOr<SmallVector<ConstantIntRanges>> offsetOperands =
-        getIntValueRanges(subViewOp.getMixedOffsets(), getIntRange,
-                          indexBitwidth);
-    if (failed(offsetOperands))
-      return;
-
-    FailureOr<SmallVector<ConstantIntRanges>> sizeOperands = getIntValueRanges(
-        subViewOp.getMixedSizes(), getIntRange, indexBitwidth);
-    if (failed(sizeOperands))
-      return;
-
-    FailureOr<SmallVector<ConstantIntRanges>> stridesOperands =
-        getIntValueRanges(subViewOp.getMixedStrides(), getIntRange,
-                          indexBitwidth);
-    if (failed(stridesOperands))
-      return;
-
-    StridedMetadataRange sourceRange =
-        ranges[subViewOp.getSourceMutable().getOperandNumber()];
-    if (sourceRange.isUninitialized())
-      return;
-
-    ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
-
-    // Get the dropped dims.
-    llvm::SmallBitVector droppedDims = subViewOp.getDroppedDims();
-
-    // Compute the new offset, strides and sizes.
-    ConstantIntRanges offset = sourceRange.getOffsets()[0];
-    SmallVector<ConstantIntRanges> strides, sizes;
-
-    for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
-      bool dropped = droppedDims.test(i);
-      // Compute the new offset.
-      ConstantIntRanges off =
-          intrange::inferMul({(*offsetOperands)[i], srcStrides[i]});
-      offset = intrange::inferAdd({offset, off});
-
-      // Skip dropped dimensions.
-      if (dropped)
-        continue;
-      // Multiply the strides.
-      strides.push_back(
-          intrange::inferMul({(*stridesOperands)[i], srcStrides[i]}));
-      // Get the sizes.
-      sizes.push_back((*sizeOperands)[i]);
-    }
-
-    setMetadata(subViewOp.getResult(),
-                StridedMetadataRange::getRanked(
-                    SmallVector<ConstantIntRanges>({std::move(offset)}),
-                    std::move(sizes), std::move(strides)));
-  }
-};
-} // namespace
-
-void mlir::memref::registerInferStridedMetadataOpInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
-    memref::SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
-  });
-}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..507597b4707c4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
 }
 
+/// Infer the strided metadata range of the subview result.
+///
+/// With `o`, `sz` and `st` the subview offsets, sizes and strides, and
+/// `srcOff` / `srcSt` the source offset and strides, the result metadata is:
+///   offset     = srcOff + sum_i(o[i] * srcSt[i])   (over all source dims)
+///   strides[j] = st[i] * srcSt[i]                  (non-dropped dims only)
+///   sizes[j]   = sz[i]                             (non-dropped dims only)
+/// Nothing is set when any offset/size/stride range or the source metadata is
+/// not available yet, or when the source metadata does not match the source
+/// rank.
+void SubViewOp::inferStridedMetadataRanges(
+    ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
+    SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
+  // Take the range by reference: IntegerValueRange holds several APInts.
+  auto isUninitialized = [](const IntegerValueRange &range) {
+    return range.isUninitialized();
+  };
+
+  // Bail early if any of the operands metadata is not ready:
+  SmallVector<IntegerValueRange> offsetOperands =
+      getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
+  if (llvm::any_of(offsetOperands, isUninitialized))
+    return;
+
+  SmallVector<IntegerValueRange> sizeOperands =
+      getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
+  if (llvm::any_of(sizeOperands, isUninitialized))
+    return;
+
+  SmallVector<IntegerValueRange> stridesOperands =
+      getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
+  if (llvm::any_of(stridesOperands, isUninitialized))
+    return;
+
+  // Bind by reference: `ranges` outlives this call, and a copy would duplicate
+  // every APInt-backed range stored in the source metadata.
+  const StridedMetadataRange &sourceRange =
+      ranges[getSourceMutable().getOperandNumber()];
+  if (sourceRange.isUninitialized())
+    return;
+
+  // One bit per source dimension, set for the dimensions the result drops.
+  llvm::SmallBitVector droppedDims = getDroppedDims();
+
+  ArrayRef<ConstantIntRanges> srcOffsets = sourceRange.getOffsets();
+  ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+  // The loop below reads `srcOffsets[0]` and one source stride per source
+  // dimension; bail instead of indexing out of bounds on malformed metadata.
+  if (srcOffsets.empty() || srcStrides.size() != droppedDims.size())
+    return;
+
+  // Compute the new offset, strides and sizes.
+  ConstantIntRanges offset = srcOffsets[0];
+  SmallVector<ConstantIntRanges> strides, sizes;
+
+  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+    // Every source dimension contributes to the offset, dropped or not.
+    ConstantIntRanges off =
+        intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
+    offset = intrange::inferAdd({offset, off});
+
+    // Dropped dimensions have no size or stride in the result.
+    if (droppedDims.test(i))
+      continue;
+    // Multiply the strides.
+    strides.push_back(
+        intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
+    // Get the sizes.
+    sizes.push_back(sizeOperands[i].getValue());
+  }
+
+  setMetadata(getResult(),
+              StridedMetadataRange::getRanked(
+                  SmallVector<ConstantIntRanges>({std::move(offset)}),
+                  std::move(sizes), std::move(strides)));
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..84fc9b8b61a11 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
   return os;
 }
 
+/// Collect one integer range per element of `values`, in the same order:
+///   - an SSA `Value` maps to whatever `getIntRange` returns for it, which may
+///     be uninitialized;
+///   - an `IntegerAttr` maps to the constant range of its value, sign-extended
+///     or truncated to `indexBitwidth` bits;
+///   - any other attribute maps to an uninitialized range. Previously this
+///     asserted in debug builds and was undefined behavior in release builds
+///     (unchecked `cast`). Callers can treat it like a value whose range is not
+///     known yet.
+SmallVector<IntegerValueRange>
+mlir::getIntValueRanges(ArrayRef<OpFoldResult> values,
+                        GetIntRangeFn getIntRange, int32_t indexBitwidth) {
+  SmallVector<IntegerValueRange> ranges;
+  ranges.reserve(values.size());
+  for (OpFoldResult ofr : values) {
+    if (auto value = dyn_cast<Value>(ofr)) {
+      ranges.push_back(getIntRange(value));
+      continue;
+    }
+
+    // Only integer attributes can be turned into a constant range.
+    auto attr = dyn_cast<IntegerAttr>(cast<Attribute>(ofr));
+    if (!attr) {
+      // Default-constructed IntegerValueRange is uninitialized ("unknown").
+      ranges.emplace_back();
+      continue;
+    }
+
+    // Create a constant range.
+    ranges.emplace_back(ConstantIntRanges::constant(
+        attr.getValue().sextOrTrunc(indexBitwidth)));
+  }
+  return ranges;
+}
+
 void mlir::intrange::detail::defaultInferResultRanges(
     InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
     SetIntLatticeFn setResultRanges) {
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 36a16af026b88..258fed135a3e5 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -53,7 +53,6 @@
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
@@ -179,7 +178,6 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
-  memref::registerInferStridedMetadataOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
   ml_program::registerBufferizableOpInterfaceExternalModels(registry);



More information about the Mlir-commits mailing list