[Mlir-commits] [mlir] 98e838a - [mlir] Do not bufferize parallel_insert_slice dest to read for full slices (#112761)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 18 13:02:08 PDT 2024


Author: Max191
Date: 2024-10-18T16:02:03-04:00
New Revision: 98e838a890191b9250ad33741a1c121a9591caa3

URL: https://github.com/llvm/llvm-project/commit/98e838a890191b9250ad33741a1c121a9591caa3
DIFF: https://github.com/llvm/llvm-project/commit/98e838a890191b9250ad33741a1c121a9591caa3.diff

LOG: [mlir] Do not bufferize parallel_insert_slice dest to read for full slices (#112761)

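In the bufferization interface implementation of tensor.insert_slice, the
destination tensor is not considered read if the slice overwrites the entire
tensor (all offsets are zero, all sizes match the destination's static shape,
and all strides are one). This PR adds the same check for
tensor.parallel_insert_slice by moving it into a helper,
`insertSliceOpRequiresRead`, that is shared by both ops.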

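Adds two helpers to StaticValueUtils:
- `areAllConstantIntValue` checks whether every `OpFoldResult` in an array is
  a constant integer equal to a given `int64_t` value. It was previously a
  file-local helper in PackAndUnpackPatterns.cpp.
- `areConstantIntValues` checks whether an array of `OpFoldResult` has the same
  length as a given array of `int64_t` values, and whether each element is a
  constant integer equal to the corresponding value.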

fixes https://github.com/llvm/llvm-project/issues/112435

---------

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index ba4f084d3efd1a..4d7aa1ae17fdb1 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
 
 /// Return true if `ofr` is constant integer equal to `value`.
 bool isConstantIntValue(OpFoldResult ofr, int64_t value);
+/// Return true if all of `ofrs` are constant integers equal to `value`.
+bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
+/// Return true if all of `ofrs` are constant integers equal to the
+/// corresponding value in `values`.
+bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
+                          ArrayRef<int64_t> values);
 
 /// Return true if ofr1 and ofr2 are the same integer constant attribute
 /// values or the same SSA value. Ignore integer bitwitdh and type mismatch

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 87464ccb71720d..c2b8614148bf25 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 
@@ -636,6 +637,28 @@ struct InsertOpInterface
   }
 };
 
+template <typename InsertOpTy>
+static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
+                                      OpOperand &opOperand) {
+  // The source is always read.
+  if (opOperand == insertSliceOp.getSourceMutable())
+    return true;
+
+  // For the destination, it depends...
+  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
+
+  // Dest is not read if it is entirely overwritten. E.g.:
+  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+  bool allOffsetsZero =
+      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
+  RankedTensorType destType = insertSliceOp.getDestType();
+  bool sizesMatchDestSizes =
+      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
+  bool allStridesOne =
+      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
+  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+}
+
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 ///
@@ -646,32 +669,8 @@ struct InsertSliceOpInterface
                                                      tensor::InsertSliceOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    RankedTensorType destType = insertSliceOp.getDestType();
-
-    // The source is always read.
-    if (opOperand == insertSliceOp.getSourceMutable())
-      return true;
-
-    // For the destination, it depends...
-    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
-
-    // Dest is not read if it is entirely overwritten. E.g.:
-    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
-    bool allOffsetsZero =
-        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 0);
-        });
-    bool sizesMatchDestSizes = llvm::all_of(
-        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
-          return getConstantIntValue(it.value()) ==
-                 destType.getDimSize(it.index());
-        });
-    bool allStridesOne =
-        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 1);
-        });
-    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
+                                     opOperand);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +930,8 @@ struct ParallelInsertSliceOpInterface
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
+                                     opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 995486c87771a3..3566714c6529e3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -16,11 +16,6 @@ namespace mlir {
 namespace tensor {
 namespace {
 
-static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
-  return llvm::all_of(
-      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
-}
-
 /// Returns the number of shape sizes that is either dynamic or greater than 1.
 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
   return llvm::count_if(

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 547d120404aba3..3eb6215a7a0b9b 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
@@ -131,12 +132,24 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
   return res;
 }
 
-/// Return true if `ofr` is constant integer equal to `value`.
 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
   auto val = getConstantIntValue(ofr);
   return val && *val == value;
 }
 
+bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
+  return llvm::all_of(
+      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
+}
+
+bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
+                          ArrayRef<int64_t> values) {
+  if (ofrs.size() != values.size())
+    return false;
+  std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
+  return constOfrs && llvm::equal(constOfrs.value(), values);
+}
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
 /// Ignore integer bitwidth and type mismatch that come from the fact there is

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index e2169fe1404c82..dc4306b8316ab7 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
 
 // -----
 
+// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
+// CHECK-NOT:     memref.alloc()
+func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  return %3 : tensor<2xf32>
+}
+
+// -----
+
 // This test case could bufferize in-place with a better analysis. However, it
 // is simpler to let the canonicalizer fold away the tensor.insert_slice.
 


        


More information about the Mlir-commits mailing list