[Mlir-commits] [mlir] [MLIR] Incorrect result for RuntimeVerifiableOpInterface on MemRef::R… (PR #96580)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 25 10:23:39 PDT 2024
https://github.com/hmalgewatta updated https://github.com/llvm/llvm-project/pull/96580
>From b9b97a82572a2a4151b0b09de4fa29d07001ce44 Mon Sep 17 00:00:00 2001
From: Hasitha Algewatta <hasithaalgewatta at Hasithas-MacBook-Pro.local>
Date: Mon, 24 Jun 2024 20:44:23 -0400
Subject: [PATCH 1/2] [MLIR] Incorrect result for RuntimeVerifiableOpInterface
on MemRef::ReinterpretCastOpInterface
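Fixes an issue where the upper bound of the memref produced by a reinterpret_cast operation is calculated incorrectly.
The fix calculates an inclusive upper bound by subtracting one from each dimension's size before scaling it by the
corresponding stride. For example, a 12x64 view with strides [64, 1] over a memref<768xf32> has inclusive upper
bound 11*64 + 63*1 = 767, which is in bounds, whereas scaling the full sizes (12*64 + 64*1 = 832) wrongly reports it as out of bounds.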
Adds a test case
Fixes: #94864
---
.../mlir/Dialect/Utils/IndexingUtils.h | 4 ++
.../Transforms/RuntimeOpVerification.cpp | 39 +++++++++++++++++--
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 34 ++++++++++++++++
...reinterpret-cast-runtime-verification.mlir | 10 +++++
4 files changed, 84 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..2034872359ebe 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -260,6 +260,10 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
std::pair<AffineExpr, SmallVector<OpFoldResult>>
computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
ArrayRef<Value> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices);
//===----------------------------------------------------------------------===//
// Utilities for decomposing larger shapes
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7..e5fe9152f4ce4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -179,6 +179,16 @@ Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
return getValueOrCreateConstantIndexOp(builder, loc, index);
}
+Value computeInclusiveLinearIndex(OpBuilder &builder, Location loc,
+ OpFoldResult offset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices) {
+ auto [expr, values] = computeInclusiveLinearIndex(offset, strides, indices);
+ auto index =
+ affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+ return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
/// Returns two Values representing the bounds of the provided strided layout
/// metadata. The bounds are returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -192,6 +202,17 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
return {lowerBound, upperBound};
}
+std::pair<Value, Value> computeLinearInclusiveBounds(
+ OpBuilder &builder, Location loc, OpFoldResult offset,
+ ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> sizes) {
+ auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+ auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+ auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+ auto upperBound =
+ computeInclusiveLinearIndex(builder, loc, offset, strides, sizes);
+ return {lowerBound, upperBound};
+}
+
/// Returns two Values representing the bounds of the memref. The bounds are
/// returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -203,6 +224,18 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
return computeLinearBounds(builder, loc, offset, strides, sizes);
}
+/// Returns two Values representing the bounds of the memref. The bounds are
+/// returned as a closed (inclusive) interval -- [low, high].
+std::pair<Value, Value>
+computeLinearInclusiveBounds(OpBuilder &builder, Location loc,
+ TypedValue<BaseMemRefType> memref) {
+ auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+ auto offset = runtimeMetadata.getConstifiedMixedOffset();
+ auto strides = runtimeMetadata.getConstifiedMixedStrides();
+ auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+ return computeLinearInclusiveBounds(builder, loc, offset, strides, sizes);
+}
+
/// Verifies that the linear bounds of a reinterpret_cast op are within the
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
struct ReinterpretCastOpInterface
@@ -221,15 +254,15 @@ struct ReinterpretCastOpInterface
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
// Compute the linear bounds of the resulting memref
- auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
+ auto [low, high] = computeLinearInclusiveBounds(builder, loc, resultMemref);
// Check low >= baseLow
auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, low, baseLow);
- // Check high <= baseHigh
+ // Check high < baseHigh
auto leHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, high, baseHigh);
+ loc, arith::CmpIPredicate::slt, high, baseHigh);
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..c8e56e318bfc9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -305,6 +305,40 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
return {expr, values};
}
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices) {
+ assert(strides.size() == indices.size());
+ auto sourceRank = static_cast<unsigned>(strides.size());
+
+ // Hold the affine symbols and values for the computation of the offset.
+ SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+ SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+ bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
+ AffineExpr expr = symbols.front();
+ AffineExpr constOneExpr = getAffineConstantExpr(1, getContext(sourceOffset));
+ values[0] = sourceOffset;
+
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ // Compute the stride.
+ OpFoldResult origStride = strides[i];
+
+ // Build up the computation of the offset.
+ unsigned baseIdxForDim = 1 + 2 * i;
+ unsigned subOffsetForDim = baseIdxForDim;
+ unsigned origStrideForDim = baseIdxForDim + 1;
+ // Subtract 1 from the index to get the inclusive bound
+ expr = expr + (symbols[subOffsetForDim] - constOneExpr) *
+ symbols[origStrideForDim];
+ values[subOffsetForDim] = indices[i];
+ values[origStrideForDim] = origStride;
+ }
+
+ return {expr, values};
+}
+
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
ArrayRef<Value> indices) {
diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
index 2239ba50b6626..5d1a945cc5d44 100644
--- a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
@@ -26,6 +26,11 @@ func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index
return
}
+func.func @reinterpret_cast_upper_bound(%arg0: memref<768xf32>) -> (memref<12x64xf32>) {
+ %reinterpret_result = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12, 64], strides: [64, 1] : memref<768xf32> to memref<12x64xf32>
+ return %reinterpret_result : memref<12x64xf32>
+}
+
func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -34,6 +39,7 @@ func.func @main() {
%5 = arith.constant 5 : index
%alloca_1 = memref.alloca() : memref<1xf32>
+ %alloca_5 = memref.alloca() : memref<768xf32>
%alloca_4 = memref.alloca() : memref<4xf32>
%alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
@@ -71,5 +77,9 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
+ // upper bound valid
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ //func.call @reinterpret_cast_upper_bound(%alloca_5) : (memref<768xf32>) -> (memref<12x64xf32>)
+
return
}
>From bca2bdb33dcf742a387230cbfcabf0d9369002e7 Mon Sep 17 00:00:00 2001
From: Hasitha Algewatta <hasithaalgewatta at Hasithas-MacBook-Pro.local>
Date: Tue, 25 Jun 2024 13:23:10 -0400
Subject: [PATCH 2/2] Adds a doc comment, renames variables, uses
computeLinearInclusiveBounds for the base memref, and subtracts the int64_t
literal 1 directly instead of building a separate constant AffineExpr
---
mlir/include/mlir/Dialect/Utils/IndexingUtils.h | 3 +++
.../Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 9 +++++----
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 8 ++++----
3 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 2034872359ebe..8eec67e3bc3b4 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -260,6 +260,9 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
std::pair<AffineExpr, SmallVector<OpFoldResult>>
computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
ArrayRef<Value> indices);
+/// Compute the inclusive linear index from provided strides and sizes, assuming
+/// a strided layout. Unlike the above, `indices` are treated as dimension sizes
+/// and 1 is subtracted from each before it is multiplied by its stride.
std::pair<AffineExpr, SmallVector<OpFoldResult>>
computeInclusiveLinearIndex(OpFoldResult sourceOffset,
ArrayRef<OpFoldResult> strides,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index e5fe9152f4ce4..63344b35e855f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -237,7 +237,7 @@ computeLinearInclusiveBounds(OpBuilder &builder, Location loc,
}
/// Verifies that the linear bounds of a reinterpret_cast op are within the
-/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
+/// linear bounds of the base memref: low >= baseLow && high <= baseHigh.
struct ReinterpretCastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ReinterpretCastOpInterface, ReinterpretCastOp> {
@@ -251,7 +251,8 @@ struct ReinterpretCastOpInterface
builder.setInsertionPointAfter(op);
// Compute the linear bounds of the base memref
- auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+ auto [baseLow, baseHigh] =
+ computeLinearInclusiveBounds(builder, loc, baseMemref);
// Compute the linear bounds of the resulting memref
auto [low, high] = computeLinearInclusiveBounds(builder, loc, resultMemref);
@@ -260,9 +261,9 @@ struct ReinterpretCastOpInterface
auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, low, baseLow);
- // Check high < baseHigh
+ // Check high <= baseHigh
auto leHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, high, baseHigh);
+ loc, arith::CmpIPredicate::sle, high, baseHigh);
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index c8e56e318bfc9..3d0e7866116fd 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -318,7 +318,6 @@ mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
AffineExpr expr = symbols.front();
- AffineExpr constOneExpr = getAffineConstantExpr(1, getContext(sourceOffset));
values[0] = sourceOffset;
for (unsigned i = 0; i < sourceRank; ++i) {
@@ -329,9 +328,10 @@ mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
unsigned baseIdxForDim = 1 + 2 * i;
unsigned subOffsetForDim = baseIdxForDim;
unsigned origStrideForDim = baseIdxForDim + 1;
- // Subtract 1 from the index to get the inclusive bound
- expr = expr + (symbols[subOffsetForDim] - constOneExpr) *
- symbols[origStrideForDim];
+ AffineExpr dimSize = symbols[subOffsetForDim];
+ AffineExpr stride = symbols[origStrideForDim];
+ // Subtract 1 from the dimension size to get the inclusive bound
+ expr = expr + (dimSize - 1) * stride;
values[subOffsetForDim] = indices[i];
values[origStrideForDim] = origStride;
}
More information about the Mlir-commits
mailing list