[Mlir-commits] [mlir] [MLIR] Incorrect result for RuntimeVerifiableOpInterface on MemRef::R… (PR #96580)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 24 18:05:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: None (hmalgewatta)

<details>
<summary>Changes</summary>

…einterpretCastOpInterface

Fixes an issue where the upper bound of the memref produced by a reinterpret cast operation is calculated incorrectly.

The fix calculates the inclusive upper bound by subtracting one from each dimension's size and using the result in the linear index computation.

Adds a test case

Fixes:  #<!-- -->94864

---
Full diff: https://github.com/llvm/llvm-project/pull/96580.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+4) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+36-3) 
- (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+34) 
- (modified) mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir (+10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..2034872359ebe 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -260,6 +260,10 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                    ArrayRef<Value> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+                            ArrayRef<OpFoldResult> strides,
+                            ArrayRef<OpFoldResult> indices);
 
 //===----------------------------------------------------------------------===//
 // Utilities for decomposing larger shapes
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7..e5fe9152f4ce4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -179,6 +179,16 @@ Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
   return getValueOrCreateConstantIndexOp(builder, loc, index);
 }
 
+Value computeInclusiveLinearIndex(OpBuilder &builder, Location loc,
+                                  OpFoldResult offset,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<OpFoldResult> indices) {
+  auto [expr, values] = computeInclusiveLinearIndex(offset, strides, indices);
+  auto index =
+      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+  return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
 /// Returns two Values representing the bounds of the provided strided layout
 /// metadata. The bounds are returned as a half open interval -- [low, high).
 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -192,6 +202,17 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
   return {lowerBound, upperBound};
 }
 
+std::pair<Value, Value> computeLinearInclusiveBounds(
+    OpBuilder &builder, Location loc, OpFoldResult offset,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> sizes) {
+  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+  auto upperBound =
+      computeInclusiveLinearIndex(builder, loc, offset, strides, sizes);
+  return {lowerBound, upperBound};
+}
+
 /// Returns two Values representing the bounds of the memref. The bounds are
 /// returned as a half open interval -- [low, high).
 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -203,6 +224,18 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
   return computeLinearBounds(builder, loc, offset, strides, sizes);
 }
 
+/// Returns two Values representing the bounds of the memref. The bounds are
+/// returned as a closed interval -- [low, high].
+std::pair<Value, Value>
+computeLinearInclusiveBounds(OpBuilder &builder, Location loc,
+                             TypedValue<BaseMemRefType> memref) {
+  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+  auto offset = runtimeMetadata.getConstifiedMixedOffset();
+  auto strides = runtimeMetadata.getConstifiedMixedStrides();
+  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+  return computeLinearInclusiveBounds(builder, loc, offset, strides, sizes);
+}
+
 /// Verifies that the linear bounds of a reinterpret_cast op are within the
 /// linear bounds of the base memref: low >= baseLow && high <= baseHigh
 struct ReinterpretCastOpInterface
@@ -221,15 +254,15 @@ struct ReinterpretCastOpInterface
     auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
 
     // Compute the linear bounds of the resulting memref
-    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
+    auto [low, high] = computeLinearInclusiveBounds(builder, loc, resultMemref);
 
     // Check low >= baseLow
     auto geLow = builder.createOrFold<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sge, low, baseLow);
 
-    // Check high <= baseHigh
+    // Check high < baseHigh
     auto leHigh = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, high, baseHigh);
+        loc, arith::CmpIPredicate::slt, high, baseHigh);
 
     auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
 
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..c8e56e318bfc9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -305,6 +305,40 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
   return {expr, values};
 }
 
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<OpFoldResult> indices) {
+  assert(strides.size() == indices.size());
+  auto sourceRank = static_cast<unsigned>(strides.size());
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  AffineExpr constOneExpr = getAffineConstantExpr(1, getContext(sourceOffset));
+  values[0] = sourceOffset;
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride = strides[i];
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    // Subtract 1 from the index to get the inclusive bound
+    expr = expr + (symbols[subOffsetForDim] - constOneExpr) *
+                      symbols[origStrideForDim];
+    values[subOffsetForDim] = indices[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  return {expr, values};
+}
+
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                          ArrayRef<Value> indices) {
diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
index 2239ba50b6626..5d1a945cc5d44 100644
--- a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
@@ -26,6 +26,11 @@ func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index
     return
 }
 
+func.func @reinterpret_cast_upper_bound(%arg0: memref<768xf32>) -> (memref<12x64xf32>) {
+  %reinterpret_result = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12, 64], strides: [64, 1] : memref<768xf32> to memref<12x64xf32>
+  return %reinterpret_result : memref<12x64xf32>
+}
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -34,6 +39,7 @@ func.func @main() {
   %5 = arith.constant 5 : index
 
   %alloca_1 = memref.alloca() : memref<1xf32>
+  %alloca_5 = memref.alloca() : memref<768xf32>
   %alloca_4 = memref.alloca() : memref<4xf32>
   %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
 
@@ -71,5 +77,9 @@ func.func @main() {
   //  CHECK-NOT: ERROR: Runtime op verification failed
   func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
 
+  // upper bound valid
+  //    CHECK-NOT: ERROR: Runtime op verification failed
+  //func.call @reinterpret_cast_upper_bound(%alloca_5) : (memref<768xf32>) -> (memref<12x64xf32>)
+
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/96580


More information about the Mlir-commits mailing list