[Mlir-commits] [mlir] [MLIR] Incorrect result for RuntimeVerifiableOpInterface on MemRef::R… (PR #96580)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 24 18:05:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: None (hmalgewatta)
<details>
<summary>Changes</summary>
…einterpretCastOpInterface
Fixes an issue where the upper bound of the memref produced by a reinterpret cast operation is calculated incorrectly.
The fix calculates the inclusive upper bound by subtracting one from each dimension's size and then using those values in the linear-index calculation.
Adds a test case
Fixes: #<!-- -->94864
---
Full diff: https://github.com/llvm/llvm-project/pull/96580.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+4)
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+36-3)
- (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+34)
- (modified) mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir (+10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..2034872359ebe 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -260,6 +260,13 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
std::pair<AffineExpr, SmallVector<OpFoldResult>>
computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
ArrayRef<Value> indices);
+/// Like `computeLinearIndex`, but builds the inclusive linear index
+/// `sourceOffset + sum_i (indices[i] - 1) * strides[i]`. With dimension sizes
+/// as `indices` (all >= 1) this is the index of the last element.
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices);
//===----------------------------------------------------------------------===//
// Utilities for decomposing larger shapes
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7..e5fe9152f4ce4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -179,6 +179,18 @@ Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
return getValueOrCreateConstantIndexOp(builder, loc, index);
}
+/// Returns an index Value holding the inclusive linear index
+/// `offset + sum_i (indices[i] - 1) * strides[i]`.
+Value computeInclusiveLinearIndex(OpBuilder &builder, Location loc,
+ OpFoldResult offset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices) {
+ auto [expr, values] = computeInclusiveLinearIndex(offset, strides, indices);
+ auto index =
+ affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+ return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
/// Returns two Values representing the bounds of the provided strided layout
/// metadata. The bounds are returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -192,6 +204,19 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
return {lowerBound, upperBound};
}
+/// Returns two Values representing the bounds of the provided strided layout
+/// metadata. The bounds are returned as a closed interval -- [low, high].
+std::pair<Value, Value> computeLinearInclusiveBounds(
+ OpBuilder &builder, Location loc, OpFoldResult offset,
+ ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> sizes) {
+ auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+ auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+ auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+ auto upperBound =
+ computeInclusiveLinearIndex(builder, loc, offset, strides, sizes);
+ return {lowerBound, upperBound};
+}
+
/// Returns two Values representing the bounds of the memref. The bounds are
/// returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
@@ -203,6 +228,18 @@ std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
return computeLinearBounds(builder, loc, offset, strides, sizes);
}
+/// Returns two Values representing the bounds of the memref. The bounds are
+/// returned as a closed interval -- [low, high].
+std::pair<Value, Value>
+computeLinearInclusiveBounds(OpBuilder &builder, Location loc,
+ TypedValue<BaseMemRefType> memref) {
+ auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+ auto offset = runtimeMetadata.getConstifiedMixedOffset();
+ auto strides = runtimeMetadata.getConstifiedMixedStrides();
+ auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+ return computeLinearInclusiveBounds(builder, loc, offset, strides, sizes);
+}
+
/// Verifies that the linear bounds of a reinterpret_cast op are within the
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
struct ReinterpretCastOpInterface
@@ -221,15 +258,18 @@ struct ReinterpretCastOpInterface
- auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+ // Use inclusive bounds for the base memref too, so that both sides of the
+ // comparisons below follow the same [low, high] convention.
+ auto [baseLow, baseHigh] =
+ computeLinearInclusiveBounds(builder, loc, baseMemref);
// Compute the linear bounds of the resulting memref
- auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
+ auto [low, high] = computeLinearInclusiveBounds(builder, loc, resultMemref);
// Check low >= baseLow
auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, low, baseLow);
- // Check high <= baseHigh
+ // Check high <= baseHigh (both bounds are inclusive)
auto leHigh = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, high, baseHigh);
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..c8e56e318bfc9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -305,6 +305,40 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
return {expr, values};
}
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeInclusiveLinearIndex(OpFoldResult sourceOffset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices) {
+ assert(strides.size() == indices.size());
+ auto sourceRank = static_cast<unsigned>(strides.size());
+
+ // Slots: s0 = sourceOffset; s(2i+1) = indices[i]; s(2i+2) = strides[i].
+ SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+ SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+ bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
+ AffineExpr expr = symbols.front();
+ AffineExpr constOneExpr = getAffineConstantExpr(1, getContext(sourceOffset));
+ values[0] = sourceOffset;
+
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ // Stride of dimension i, used unchanged.
+ OpFoldResult origStride = strides[i];
+
+ // Symbol slots for dimension i's index and stride.
+ unsigned baseIdxForDim = 1 + 2 * i;
+ unsigned subOffsetForDim = baseIdxForDim;
+ unsigned origStrideForDim = baseIdxForDim + 1;
+ // Subtract 1 so a dimension size becomes the index of its last element.
+ expr = expr + (symbols[subOffsetForDim] - constOneExpr) *
+ symbols[origStrideForDim];
+ values[subOffsetForDim] = indices[i];
+ values[origStrideForDim] = origStride;
+ }
+
+ return {expr, values};
+}
+
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
ArrayRef<Value> indices) {
diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
index 2239ba50b6626..5d1a945cc5d44 100644
--- a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir
@@ -26,6 +26,11 @@ func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index
return
}
+func.func @reinterpret_cast_upper_bound(%arg0: memref<768xf32>) -> (memref<12x64xf32>) {
+ %reinterpret_result = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12, 64], strides: [64, 1] : memref<768xf32> to memref<12x64xf32>
+ return %reinterpret_result : memref<12x64xf32>
+}
+
func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -34,6 +39,7 @@ func.func @main() {
%5 = arith.constant 5 : index
%alloca_1 = memref.alloca() : memref<1xf32>
+ %alloca_5 = memref.alloca() : memref<768xf32>
%alloca_4 = memref.alloca() : memref<4xf32>
%alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
@@ -71,5 +77,9 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
+ // upper bound valid: a 12x64 view with strides [64, 1] spans exactly the 768-element base
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %upper_bound_result = func.call @reinterpret_cast_upper_bound(%alloca_5) : (memref<768xf32>) -> (memref<12x64xf32>)
+
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/96580
More information about the Mlir-commits
mailing list