[Mlir-commits] [mlir] [mlir][IntRangeInference] Infer values for {memref, tensor}.dim (PR #122945)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Jan 17 20:09:22 PST 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/122945

>From db5d7865008555e8f4104099fde70f24147066c3 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 14 Jan 2025 18:46:46 +0000
Subject: [PATCH 1/2] [mlir][IntRangeInference] Infer values for
 {memref,tensor}.dim

Implement the integer range inference interface for memref.dim and
tensor.dim using shared code. The inference will infer the size of
dynamic dimensions to be [0, index_max] (the signed maximum of the
result type) and take the union of all the dimensions that the `index`
operand could validly be referring to.
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRef.h  |  1 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  6 +-
 mlir/include/mlir/Dialect/Tensor/IR/Tensor.h  |  1 +
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  4 +-
 .../Interfaces/Utils/InferIntRangeCommon.h    |  8 +++
 mlir/lib/Dialect/MemRef/IR/CMakeLists.txt     |  2 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  7 +++
 mlir/lib/Dialect/Tensor/IR/CMakeLists.txt     |  2 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  8 +++
 mlir/lib/Interfaces/Utils/CMakeLists.txt      |  1 +
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 44 +++++++++++++
 .../Dialect/MemRef/int-range-inference.mlir   | 61 +++++++++++++++++++
 .../Dialect/Tensor/int-range-inference.mlir   | 61 +++++++++++++++++++
 13 files changed, 203 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/MemRef/int-range-inference.mlir
 create mode 100644 mlir/test/Dialect/Tensor/int-range-inference.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 72463dca715ca32..ac383ab46e7a504 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237ac..c3ee3968abc16dc 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     MemRefsNormalizable,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     }]>,
 
     // Builder that infers the result layout map. The result shape must be
-    // specified. Otherwise, the op may be ambiguous. The output shape for 
+    // specified. Otherwise, the op may be ambiguous. The output shape for
     // the op will be inferred using the inferOutputShape() method.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
                "ArrayRef<ReassociationIndices>":$reassociation)>,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0a21c9922b223bf..bd96337a55407a8 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac2098450204..38874513a4cc003 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -195,7 +196,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
 def Tensor_DimOp : Tensor_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `tensor.dim` operation takes a tensor and a dimension operand of type
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 3988a8826498a94..e46358ccfc46f72 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,8 @@
 #include <optional>
 
 namespace mlir {
+class ShapedDimOpInterface;
+
 namespace intrange {
 /// Function that performs inference on an array of `ConstantIntRanges`,
 /// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
                                  const ConstantIntRanges &lhs,
                                  const ConstantIntRanges &rhs);
 
+/// Returns the integer range for the result of a `ShapedDimOpInterface` given
+/// the (possibly uninitialized) range `maybeDim` of its dimension index.
+/// Unranked shapes or dynamic dimensions yield [0, signed_max(type(result))].
+ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                            const IntegerValueRange &maybeDim);
+
 } // namespace intrange
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 845914ebd107a26..734294bd014c6e8 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDialect
   MLIRDialectUtils
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288dc9..f0aee7a68e0bfff 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  auto range = intrange::inferShapedDimOpInterface(*this, argRanges[1]);
+  setResultRange(getResult(), range); // argRanges[1]: the `index` operand.
+}
+
 /// Return a map with key being elements in `vals` and data being number of
 /// occurences of it. Use std::map, since the `vals` here are strides and the
 /// dynamic stride value is the same as the tombstone value for
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index cfdd3847761a49a..5425615dac39324 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRParallelCombiningOpInterface
   MLIRShapedOpInterfaces
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d5531531981..e0853cab60fb948 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -23,7 +23,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  auto range = intrange::inferShapedDimOpInterface(*this, argRanges[1]);
+  setResultRange(getResult(), range); // argRanges[1]: the `index` operand.
+}
+
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
   auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index ece6c8e46ffea9c..8c45f6699742719 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
     MLIRInferIntRangeInterfaceIncGen
 
     LINK_LIBS PUBLIC
+    MLIRShapedOpInterfaces
     MLIRInferIntRangeInterface
     MLIRIR
 )
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 1eab4139488bdd3..2f47939df5a0222 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
     return false;
   return std::nullopt;
 }
+
+//===----------------------------------------------------------------------===//
+// Shaped type dimension accessors / ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                          const IntegerValueRange &maybeDim) {
+  // Fallback used whenever nothing more precise is known: [0, signed_max].
+  unsigned bitwidth =
+      ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
+  ConstantIntRanges unknownSize = ConstantIntRanges::fromSigned(
+      APInt::getZero(bitwidth), APInt::getSignedMaxValue(bitwidth));
+
+  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
+  // Unranked values carry no per-dimension information.
+  if (!shapedTy.hasRank())
+    return unknownSize;
+
+  // Candidate dimension indices: [0, rank - 1], narrowed further by the
+  // inferred range of the index operand when that range is available.
+  int64_t lo = 0;
+  int64_t hi = shapedTy.getRank() - 1;
+  if (!maybeDim.isUninitialized()) {
+    lo = std::max(lo, maybeDim.getValue().smin().getSExtValue());
+    hi = std::min(hi, maybeDim.getValue().smax().getSExtValue());
+  }
+
+  // Union the size ranges of every dimension the index could select.
+  std::optional<ConstantIntRanges> joined;
+  for (int64_t idx = lo; idx <= hi; ++idx) {
+    int64_t size = shapedTy.getDimSize(idx);
+    ConstantIntRanges sizeRange =
+        ShapedType::isDynamic(size)
+            ? unknownSize
+            : ConstantIntRanges::constant(APInt(bitwidth, size));
+    joined = joined ? joined->rangeUnion(sizeRange) : sizeRange;
+  }
+
+  // No in-bounds dimension could be selected (e.g. rank 0 or an index range
+  // entirely outside [0, rank - 1]): fall back to the unconstrained range.
+  return joined.value_or(unknownSize);
+}
diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
new file mode 100644
index 000000000000000..e2aa487eaaa25b4
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: memref<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<3x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = arith.maxsi %x, %c1 : index
+  %1 = memref.dim %m, %0 : memref<?x3x5xi32>
+  %2 = test.reflect_bounds %1 : index
+  return %2 : index
+}
diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir
new file mode 100644
index 000000000000000..384ae781e0e3300
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: tensor<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %m, %c0 : tensor<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %m, %x : tensor<3x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: tensor<?x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %m, %c0 : tensor<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: tensor<?x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %m, %x : tensor<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: tensor<?x3x5xi32>, %x: index) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = arith.maxsi %x, %c1 : index
+  %1 = tensor.dim %m, %0 : tensor<?x3x5xi32>
+  %2 = test.reflect_bounds %1 : index
+  return %2 : index
+}

>From fdb82c188c5e41f091dfde0787b83aa8284fa060 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Sat, 18 Jan 2025 03:54:16 +0000
Subject: [PATCH 2/2] Add test for the unranked case

---
 .../Dialect/MemRef/int-range-inference.mlir   | 13 ++++++++
 .../Dialect/Tensor/int-range-inference.mlir   | 33 +++++++++++++------
 2 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
index e2aa487eaaa25b4..34568d1d1d5200f 100644
--- a/mlir/test/Dialect/MemRef/int-range-inference.mlir
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -59,3 +59,16 @@ func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index
   %2 = test.reflect_bounds %1 : index
   return %2 : index
 }
+
+// -----
+
+// CHECK-LABEL: @dim_unranked
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_unranked(%m: memref<*xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<*xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir
index 384ae781e0e3300..e90ebf5fccb8ea8 100644
--- a/mlir/test/Dialect/Tensor/int-range-inference.mlir
+++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir
@@ -3,9 +3,9 @@
 // CHECK-LABEL: @dim_const
 // CHECK: %[[ret:.+]] = arith.constant 3 : index
 // CHECK: return %[[ret]]
-func.func @dim_const(%m: tensor<3x5xi32>) -> index {
+func.func @dim_const(%t: tensor<3x5xi32>) -> index {
   %c0 = arith.constant 0 : index
-  %0 = tensor.dim %m, %c0 : tensor<3x5xi32>
+  %0 = tensor.dim %t, %c0 : tensor<3x5xi32>
   return %0 : index
 }
 
@@ -15,8 +15,8 @@ func.func @dim_const(%m: tensor<3x5xi32>) -> index {
 // CHECK: %[[op:.+]] = tensor.dim
 // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
 // CHECK: return %[[ret]]
-func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index {
-  %0 = tensor.dim %m, %x : tensor<3x5xi32>
+func.func @dim_any_static(%t: tensor<3x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %t, %x : tensor<3x5xi32>
   %1 = test.reflect_bounds %0 : index
   return %1 : index
 }
@@ -27,9 +27,9 @@ func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index {
 // CHECK: %[[op:.+]] = tensor.dim
 // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
 // CHECK: return %[[ret]]
-func.func @dim_dynamic(%m: tensor<?x5xi32>) -> index {
+func.func @dim_dynamic(%t: tensor<?x5xi32>) -> index {
   %c0 = arith.constant 0 : index
-  %0 = tensor.dim %m, %c0 : tensor<?x5xi32>
+  %0 = tensor.dim %t, %c0 : tensor<?x5xi32>
   %1 = test.reflect_bounds %0 : index
   return %1 : index
 }
@@ -40,8 +40,8 @@ func.func @dim_dynamic(%m: tensor<?x5xi32>) -> index {
 // CHECK: %[[op:.+]] = tensor.dim
 // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
 // CHECK: return %[[ret]]
-func.func @dim_any_dynamic(%m: tensor<?x5xi32>, %x: index) -> index {
-  %0 = tensor.dim %m, %x : tensor<?x5xi32>
+func.func @dim_any_dynamic(%t: tensor<?x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %t, %x : tensor<?x5xi32>
   %1 = test.reflect_bounds %0 : index
   return %1 : index
 }
@@ -52,10 +52,23 @@ func.func @dim_any_dynamic(%m: tensor<?x5xi32>, %x: index) -> index {
 // CHECK: %[[op:.+]] = tensor.dim
 // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
 // CHECK: return %[[ret]]
-func.func @dim_some_omitting_dynamic(%m: tensor<?x3x5xi32>, %x: index) -> index {
+func.func @dim_some_omitting_dynamic(%t: tensor<?x3x5xi32>, %x: index) -> index {
   %c1 = arith.constant 1 : index
   %0 = arith.maxsi %x, %c1 : index
-  %1 = tensor.dim %m, %0 : tensor<?x3x5xi32>
+  %1 = tensor.dim %t, %0 : tensor<?x3x5xi32>
   %2 = test.reflect_bounds %1 : index
   return %2 : index
 }
+
+// -----
+
+// CHECK-LABEL: @dim_unranked
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_unranked(%t: tensor<*xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %t, %c0 : tensor<*xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}



More information about the Mlir-commits mailing list