[Mlir-commits] [mlir] [mlir] Do not bufferize parallel_insert_slice dest to read for full slices (PR #112761)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 18 07:36:10 PDT 2024
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/112761
>From 4826e05651ff7a88cb57e54e39c47a8676e125ea Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 16 Oct 2024 10:42:56 -0500
Subject: [PATCH 1/2] [mlir] Do not bufferize parallel_insert_slice dest to
read for full slices
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../BufferizableOpInterfaceImpl.cpp | 60 ++++++++++---------
.../Dialect/Tensor/one-shot-bufferize.mlir | 15 +++++
2 files changed, 48 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 87464ccb71720d..def4ee93854a1a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@@ -636,6 +637,34 @@ struct InsertOpInterface
}
};
+template <typename InsertOpTy>
+static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
+ OpOperand &opOperand) {
+ RankedTensorType destType = insertSliceOp.getDestType();
+
+ // The source is always read.
+ if (opOperand == insertSliceOp.getSourceMutable())
+ return true;
+
+ // For the destination, it depends...
+ assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
+
+ // Dest is not read if it is entirely overwritten. E.g.:
+ // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+ bool allOffsetsZero =
+ llvm::all_of(insertSliceOp.getMixedOffsets(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+ bool sizesMatchDestSizes = llvm::all_of(
+ llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
+ return getConstantIntValue(it.value()) ==
+ destType.getDimSize(it.index());
+ });
+ bool allStridesOne =
+ llvm::all_of(insertSliceOp.getMixedStrides(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+ return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+}
+
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
@@ -646,32 +675,8 @@ struct InsertSliceOpInterface
tensor::InsertSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- RankedTensorType destType = insertSliceOp.getDestType();
-
- // The source is always read.
- if (opOperand == insertSliceOp.getSourceMutable())
- return true;
-
- // For the destination, it depends...
- assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
-
- // Dest is not read if it is entirely overwritten. E.g.:
- // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
- bool allOffsetsZero =
- llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
- return isConstantIntValue(ofr, 0);
- });
- bool sizesMatchDestSizes = llvm::all_of(
- llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
- return getConstantIntValue(it.value()) ==
- destType.getDimSize(it.index());
- });
- bool allStridesOne =
- llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return isConstantIntValue(ofr, 1);
- });
- return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+ return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
+ opOperand);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +936,8 @@ struct ParallelInsertSliceOpInterface
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- return true;
+ return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
+ opOperand);
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index e2169fe1404c82..dc4306b8316ab7 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
// -----
+// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
+// CHECK-NOT: memref.alloc()
+func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
+ %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ return %3 : tensor<2xf32>
+}
+
+// -----
+
// This test case could bufferize in-place with a better analysis. However, it
// is simpler to let the canonicalizer fold away the tensor.insert_slice.
>From 5a0fa0ee9d06785a50f0876b0fdfbab41a1ad41b Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 18 Oct 2024 10:35:44 -0400
Subject: [PATCH 2/2] Add additional static value utils
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../mlir/Dialect/Utils/StaticValueUtils.h | 6 ++++++
.../Transforms/BufferizableOpInterfaceImpl.cpp | 16 +++++-----------
.../Tensor/Transforms/PackAndUnpackPatterns.cpp | 5 -----
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 17 +++++++++++++++++
4 files changed, 28 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index ba4f084d3efd1a..4d7aa1ae17fdb1 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
/// Return true if `ofr` is constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
+/// Return true if all of `ofrs` are constant integers equal to `value`.
+bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
+/// Return true if all of `ofrs` are constant integers equal to the
+/// corresponding value in `values`.
+bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
+ ArrayRef<int64_t> values);
/// Return true if ofr1 and ofr2 are the same integer constant attribute
/// values or the same SSA value. Ignore integer bitwitdh and type mismatch
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index def4ee93854a1a..c2b8614148bf25 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -640,8 +640,6 @@ struct InsertOpInterface
template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
OpOperand &opOperand) {
- RankedTensorType destType = insertSliceOp.getDestType();
-
// The source is always read.
if (opOperand == insertSliceOp.getSourceMutable())
return true;
@@ -652,16 +650,12 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
bool allOffsetsZero =
- llvm::all_of(insertSliceOp.getMixedOffsets(),
- [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
- bool sizesMatchDestSizes = llvm::all_of(
- llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
- return getConstantIntValue(it.value()) ==
- destType.getDimSize(it.index());
- });
+ llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
+ RankedTensorType destType = insertSliceOp.getDestType();
+ bool sizesMatchDestSizes =
+ areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
bool allStridesOne =
- llvm::all_of(insertSliceOp.getMixedStrides(),
- [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+ areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 995486c87771a3..3566714c6529e3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -16,11 +16,6 @@ namespace mlir {
namespace tensor {
namespace {
-static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
- return llvm::all_of(
- ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
-}
-
/// Returns the number of shape sizes that is either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
return llvm::count_if(
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 547d120404aba3..3e8e63938a42bb 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
@@ -137,6 +138,22 @@ bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
return val && *val == value;
}
+/// Return true if all of `ofrs` are constant integers equal to `value`.
+bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
+ return llvm::all_of(
+ ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
+}
+
+/// Return true if all of `ofrs` are constant integers equal to the
+/// corresponding value in `values`.
+bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
+ ArrayRef<int64_t> values) {
+ if (ofrs.size() != values.size())
+ return false;
+ std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
+ return constOfrs && llvm::equal(constOfrs.value(), values);
+}
+
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
/// or the same SSA value.
/// Ignore integer bitwidth and type mismatch that come from the fact there is
More information about the Mlir-commits
mailing list