[Mlir-commits] [mlir] [mlir] Implement inferResultRanges for vector.step op (PR #151536)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 31 08:21:43 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: None (Max191)

<details>
<summary>Changes</summary>

Implements the `inferResultRanges` method of `InferIntRangeInterface` for `vector.step`. For a fixed-length result with N lanes the inferred range is [0, N-1]; no range is inferred for scalable vectors.

---
Full diff: https://github.com/llvm/llvm-project/pull/151536.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4-1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+25) 
- (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3885439e11f89..1d9a9d3f699ac 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2876,7 +2876,10 @@ def Vector_ScanOp :
 // VectorStepOp
 //===----------------------------------------------------------------------===//
 
-def Vector_StepOp : Vector_Op<"step", [Pure]> {
+def Vector_StepOp : Vector_Op<"step", [
+    Pure,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
+  ]> {
   let summary = "A linear sequence of values from 0 to N";
   let description = [{
     A `step` operation produces an index vector, i.e. a 1-D vector of values of
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8789f55707267..144091bd0f70e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7197,6 +7197,31 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
   return selectPassthru(b, mask, result, acc);
 }
 
+//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the integer range of a `vector.step` result. A fixed-length result
+/// with N lanes holds exactly the values 0, 1, ..., N-1, so the range
+/// [0, N-1] is reported. Nothing is reported for scalable vectors.
+void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRanges) {
+  auto resultType = cast<VectorType>(getType());
+  if (resultType.isScalable())
+    return;
+  Type elementType = resultType.getElementType();
+  unsigned bitwidth = elementType.isIndex()
+                          ? IndexType::kInternalStorageBitWidth
+                          : elementType.getIntOrFloatBitWidth();
+  int64_t size = resultType.getShape()[0];
+  assert(size > 0 && "Zero-sized vectors are not allowed");
+  // The bounds are known in closed form, so the range is built directly
+  // instead of visiting every lane.
+  APInt low(bitwidth, 0);
+  APInt high(bitwidth, size - 1);
+  setResultRanges(getResult(), ConstantIntRanges::fromUnsigned(low, high));
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 2563b48cdd506..c60c21fadb668 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -99,3 +99,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
   %2 = test.reflect_bounds %1 : vector<2xi32>
   func.return %2 : vector<2xi32>
 }
+
+// CHECK-LABEL: func @vector_step
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @vector_step() -> vector<8xindex> {
+  %step = vector.step : vector<8xindex>
+  %bounds = test.reflect_bounds %step : vector<8xindex>
+  func.return %bounds : vector<8xindex>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/151536


More information about the Mlir-commits mailing list