[Mlir-commits] [mlir] [MLIR] Incorrect result for RuntimeVerifiableOpInterface on MemRef::R… (PR #96580)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 2 14:48:36 PDT 2024


https://github.com/hmalgewatta updated https://github.com/llvm/llvm-project/pull/96580

>From b9b97a82572a2a4151b0b09de4fa29d07001ce44 Mon Sep 17 00:00:00 2001
From: Hasitha Algewatta <hasithaalgewatta at Hasithas-MacBook-Pro.local>
Date: Mon, 24 Jun 2024 20:44:23 -0400
Subject: [PATCH 1/2] [MLIR] Incorrect result for RuntimeVerifiableOpInterface
 on MemRef::ReinterpretCastOpInterface

Fixes an issue where the upper bound of the memref produced by a reinterpret_cast operation is calculated incorrectly.

The fix calculates the inclusive upper bound by subtracting one from each dimension's size and then using
that value in the calculation.

Adds a test case

Fixes:  #94864
---
 .../mlir/Dialect/Utils/IndexingUtils.h        |  4 ++
 .../Transforms/RuntimeOpVerification.cpp      | 39 +++++++++++++++++--
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      | 34 ++++++++++++++++
 ...reinterpret-cast-runtime-verification.mlir | 10 +++++
 4 files changed, 84 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..2034872359ebe 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -260,6 +260,10 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                    ArrayRef<Value> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+                            ArrayRef<OpFoldResult> strides,
+                            ArrayRef<OpFoldResult> indices);
 
 //===----------------------------------------------------------------------===//
 // Utilities for decomposing larger shapes
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7..e5fe9152f4ce4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -179,6 +179,16 @@ Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
   return getValueOrCreateConstantIndexOp(builder, loc, index);
 }
 
+Value computeInclusiveLinearIndex(OpBuilder &builder, Location loc,
+                                  OpFoldResult offset,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<OpFoldResult> indices) {
+  auto [expr, values] = computeInclusiveLinearIndex(offset, strides, indices);
+  auto index =
+      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+  return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
 /// Returns two Values representing the bounds of the provided strided layout
 /// metadata. The bounds are returned as a half open interval -- [low, high).
 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -192,6 +202,17 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
   return {lowerBound, upperBound};
 }
 
+std::pair<Value, Value> computeLinearInclusiveBounds(
+    OpBuilder &builder, Location loc, OpFoldResult offset,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> sizes) {
+  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+  auto upperBound =
+      computeInclusiveLinearIndex(builder, loc, offset, strides, sizes);
+  return {lowerBound, upperBound};
+}
+
 /// Returns two Values representing the bounds of the memref. The bounds are
 /// returned as a half open interval -- [low, high).
 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -203,6 +224,18 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
   return computeLinearBounds(builder, loc, offset, strides, sizes);
 }
 
+/// Returns two Values representing the bounds of the memref. The bounds are
+/// returned as a closed (inclusive) interval -- [low, high].
+std::pair<Value, Value>
+computeLinearInclusiveBounds(OpBuilder &builder, Location loc,
+                             TypedValue<BaseMemRefType> memref) {
+  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+  auto offset = runtimeMetadata.getConstifiedMixedOffset();
+  auto strides = runtimeMetadata.getConstifiedMixedStrides();
+  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+  return computeLinearInclusiveBounds(builder, loc, offset, strides, sizes);
+}
+
 /// Verifies that the linear bounds of a reinterpret_cast op are within the
 /// linear bounds of the base memref: low >= baseLow && high <= baseHigh
 struct ReinterpretCastOpInterface
@@ -221,15 +254,15 @@ struct ReinterpretCastOpInterface
     auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
 
     // Compute the linear bounds of the resulting memref
-    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
+    auto [low, high] = computeLinearInclusiveBounds(builder, loc, resultMemref);
 
     // Check low >= baseLow
     auto geLow = builder.createOrFold<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sge, low, baseLow);
 
-    // Check high <= baseHigh
+    // Check high < baseHigh
     auto leHigh = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, high, baseHigh);
+        loc, arith::CmpIPredicate::slt, high, baseHigh);
 
     auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
 
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..c8e56e318bfc9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -305,6 +305,40 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
   return {expr, values};
 }
 
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<OpFoldResult> indices) {
+  assert(strides.size() == indices.size());
+  auto sourceRank = static_cast<unsigned>(strides.size());
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  AffineExpr constOneExpr = getAffineConstantExpr(1, getContext(sourceOffset));
+  values[0] = sourceOffset;
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride = strides[i];
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    // Subtract 1 from the index to get the inclusive bound
+    expr = expr + (symbols[subOffsetForDim] - constOneExpr) *
+                      symbols[origStrideForDim];
+    values[subOffsetForDim] = indices[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  return {expr, values};
+}
+
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                          ArrayRef<Value> indices) {
diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
index 2239ba50b6626..5d1a945cc5d44 100644
--- a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
@@ -26,6 +26,11 @@ func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index
     return
 }
 
+func.func @reinterpret_cast_upper_bound(%arg0: memref<768xf32>) -> (memref<12x64xf32>) {
+  %reinterpret_result = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12, 64], strides: [64, 1] : memref<768xf32> to memref<12x64xf32>
+  return %reinterpret_result : memref<12x64xf32>
+}
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -34,6 +39,7 @@ func.func @main() {
   %5 = arith.constant 5 : index
 
   %alloca_1 = memref.alloca() : memref<1xf32>
+  %alloca_5 = memref.alloca() : memref<768xf32>
   %alloca_4 = memref.alloca() : memref<4xf32>
   %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
 
@@ -71,5 +77,9 @@ func.func @main() {
   //  CHECK-NOT: ERROR: Runtime op verification failed
   func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
 
+  // upper bound valid
+  //    CHECK-NOT: ERROR: Runtime op verification failed
+  //func.call @reinterpret_cast_upper_bound(%alloca_5) : (memref<768xf32>) -> (memref<12x64xf32>)
+
   return
 }

>From bca2bdb33dcf742a387230cbfcabf0d9369002e7 Mon Sep 17 00:00:00 2001
From: Hasitha Algewatta <hasithaalgewatta at Hasithas-MacBook-Pro.local>
Date: Tue, 25 Jun 2024 13:23:10 -0400
Subject: [PATCH 2/2] Adds class comment, refactors variable names and uses
 computeInclusiveLinearIndex for basememref, and directly uses int64_t instead
 of separate AffineExpr

---
 mlir/include/mlir/Dialect/Utils/IndexingUtils.h          | 3 +++
 .../Dialect/MemRef/Transforms/RuntimeOpVerification.cpp  | 9 +++++----
 mlir/lib/Dialect/Utils/IndexingUtils.cpp                 | 8 ++++----
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 2034872359ebe..8eec67e3bc3b4 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -260,6 +260,9 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                    ArrayRef<Value> indices);
+/// Compute linear index from provided strides and indices, assuming strided
+/// layout. Unlike the above, this version computes the inclusive linear index
+/// by subtracting 1 from each dimension size.
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeInclusiveLinearIndex(OpFoldResult sourceOffset,
                             ArrayRef<OpFoldResult> strides,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index e5fe9152f4ce4..63344b35e855f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -237,7 +237,7 @@ computeLinearInclusiveBounds(OpBuilder &builder, Location loc,
 }
 
 /// Verifies that the linear bounds of a reinterpret_cast op are within the
-/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
+/// linear bounds of the base memref: low >= baseLow && high <= baseHigh.
 struct ReinterpretCastOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           ReinterpretCastOpInterface, ReinterpretCastOp> {
@@ -251,7 +251,8 @@ struct ReinterpretCastOpInterface
     builder.setInsertionPointAfter(op);
 
     // Compute the linear bounds of the base memref
-    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+    auto [baseLow, baseHigh] =
+        computeLinearInclusiveBounds(builder, loc, baseMemref);
 
     // Compute the linear bounds of the resulting memref
     auto [low, high] = computeLinearInclusiveBounds(builder, loc, resultMemref);
@@ -260,9 +261,9 @@ struct ReinterpretCastOpInterface
     auto geLow = builder.createOrFold<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sge, low, baseLow);
 
-    // Check high < baseHigh
+    // Check high <= baseHigh
     auto leHigh = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, high, baseHigh);
+        loc, arith::CmpIPredicate::sle, high, baseHigh);
 
     auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
 
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index c8e56e318bfc9..3d0e7866116fd 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -318,7 +318,6 @@ mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
 
   bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
   AffineExpr expr = symbols.front();
-  AffineExpr constOneExpr = getAffineConstantExpr(1, getContext(sourceOffset));
   values[0] = sourceOffset;
 
   for (unsigned i = 0; i < sourceRank; ++i) {
@@ -329,9 +328,10 @@ mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
     unsigned baseIdxForDim = 1 + 2 * i;
     unsigned subOffsetForDim = baseIdxForDim;
     unsigned origStrideForDim = baseIdxForDim + 1;
-    // Subtract 1 from the index to get the inclusive bound
-    expr = expr + (symbols[subOffsetForDim] - constOneExpr) *
-                      symbols[origStrideForDim];
+    AffineExpr dimSize = symbols[subOffsetForDim];
+    AffineExpr stride = symbols[origStrideForDim];
+    // Subtract 1 from the dimension size to get the inclusive bound
+    expr = expr + (dimSize - 1) * stride;
     values[subOffsetForDim] = indices[i];
     values[origStrideForDim] = origStride;
   }



More information about the Mlir-commits mailing list