[Mlir-commits] [mlir] c9b3638 - [mlir][scf][bufferize] Fix bufferizesToMemoryRead with 0 loop iterations
Matthias Springer
llvmlistbot at llvm.org
Mon Oct 24 05:34:50 PDT 2022
Author: Matthias Springer
Date: 2022-10-24T14:34:41+02:00
New Revision: c9b3638126e520917ad42d3ec38ad31fd389e4b5
URL: https://github.com/llvm/llvm-project/commit/c9b3638126e520917ad42d3ec38ad31fd389e4b5
DIFF: https://github.com/llvm/llvm-project/commit/c9b3638126e520917ad42d3ec38ad31fd389e4b5.diff
LOG: [mlir][scf][bufferize] Fix bufferizesToMemoryRead with 0 loop iterations
There was a bug in scf.for loop bufferization that could lead to a missing buffer copy (alloc was there, but not the copy).
Differential Revision: https://reviews.llvm.org/D135053
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 6c54ffea6e3ea..2771857766612 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -454,6 +454,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
yieldedRanked.getMemorySpaceAsInt());
}
+/// Conservatively determine whether the given `scf.for` loop may run 0 times.
+bool mayHaveZeroIterations(scf::ForOp forOp) {
+ Optional<int64_t> lowerBound = getConstantIntValue(forOp.getLowerBound());
+ Optional<int64_t> upperBound = getConstantIntValue(forOp.getUpperBound());
+ if (lowerBound.has_value() && upperBound.has_value())
+ return *lowerBound >= *upperBound;
+ return true; // A bound is not a known constant: assume 0 iterations possible.
+}
+
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
@@ -461,9 +470,15 @@ struct ForOpInterface
scf::ForOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
+ auto forOp = cast<scf::ForOp>(op);
+
+ // If the loop may have zero iterations, the results of the op may be their
+ // corresponding init_args, meaning that the init_args bufferize to a read.
+ if (mayHaveZeroIterations(forOp))
+ return true;
+
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
- auto forOp = cast<scf::ForOp>(op);
return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
}
@@ -1039,6 +1054,19 @@ struct YieldOpInterface
}
};
+/// Return `true` if the given loop may have 0 iterations, i.e., if some
+/// num_threads operand is not a known constant or is the constant 0. Each
+/// operand is tested on its own instead of multiplying them all together, so
+/// large constant thread counts cannot overflow a signed 64-bit product.
+bool mayHaveZeroIterations(scf::ForeachThreadOp foreachThreadOp) {
+ for (Value v : foreachThreadOp.getNumThreads()) {
+ Optional<int64_t> c = getConstantIntValue(v);
+ if (!c.has_value() || *c == 0)
+ return true;
+ }
+ return false;
+}
+
/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
@@ -1048,9 +1076,16 @@ struct ForeachThreadOpInterface
ForeachThreadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
+ auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+ // If the loop may have zero iterations, the results of the op may be their
+ // corresponding shared_outs, meaning that the shared_outs bufferize to a
+ // read.
+ if (mayHaveZeroIterations(foreachThreadOp))
+ return true;
+
// scf::ForeachThreadOp alone doesn't bufferize to a memory read, one of the
// uses of its matching bbArg may.
- auto foreachThreadOp = cast<ForeachThreadOp>(op);
return state.isValueRead(foreachThreadOp.getTiedBlockArgument(&opOperand));
}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
index fa651a6df408f..69c4ef4f3166f 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
@@ -177,6 +177,7 @@ func.func @non_reading_scf_for(%t1: tensor<?xf32> {bufferization.writable = true
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
%cst = arith.constant 0.0 : f32
// Write to %t1.
@@ -186,7 +187,7 @@ func.func @non_reading_scf_for(%t1: tensor<?xf32> {bufferization.writable = true
// This loop does not read from %t1. It only writes to it.
// CHECK: scf.for
- %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
+ %r, %v3 = scf.for %i = %c0 to %c10 step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
// Write to %t1 via %t2. (Overwrite %t3.)
// CHECK: linalg.generic
// CHECK-SAME: __inplace_operands_attr__ = ["true"]
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 6d3dbeee1c8e3..3640f82a25a96 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -38,6 +38,32 @@ func.func @scf_for_yield_only(
// -----
+// CHECK-LABEL: func @scf_for_is_reading(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[B:.*]]: memref<?xf32, strided<[?], offset: ?>>
+func.func @scf_for_is_reading(%A : tensor<?xf32>, %B : tensor<?xf32>,
+ %lb : index, %ub : index)
+ -> (f32, f32)
+{
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.0 : f32
+
+ // Regression test (loop may have 0 iterations): alloc + copy must be emitted.
+
+ // CHECK: %[[alloc:.*]] = memref.alloc
+ // CHECK: memref.copy %[[A]], %[[alloc]]
+ // CHECK: %[[clone:.*]] = bufferization.clone %[[alloc]]
+ // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[clone]])
+ %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%1 = %A) -> tensor<?xf32> {
+ %r = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+ scf.yield %B : tensor<?xf32>
+ }
+ %1 = tensor.extract %0[%c1] : tensor<?xf32>
+ %2 = tensor.extract %A[%c1] : tensor<?xf32>
+ return %1, %2 : f32, f32
+}
+
+// -----
+
// Ensure that the function bufferizes without error. This tests pre-order
// traversal of scf.for loops during bufferization. No need to check the IR,
// just want to make sure that it does not crash.
More information about the Mlir-commits
mailing list