[Mlir-commits] [mlir] [mlir] Do not bufferize parallel_insert_slice dest to read for full slices (PR #112761)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Oct 17 15:19:47 PDT 2024
================
@@ -636,6 +637,34 @@ struct InsertOpInterface
}
};
+template <typename InsertOpTy>
+static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
+ OpOperand &opOperand) {
+ RankedTensorType destType = insertSliceOp.getDestType();
+
+ // The source is always read.
+ if (opOperand == insertSliceOp.getSourceMutable())
+ return true;
+
+ // For the destination, it depends...
+ assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
+
+ // Dest is not read if it is entirely overwritten. E.g.:
+ // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+ bool allOffsetsZero =
+ llvm::all_of(insertSliceOp.getMixedOffsets(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+ bool sizesMatchDestSizes = llvm::all_of(
+ llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
+ return getConstantIntValue(it.value()) ==
+ destType.getDimSize(it.index());
+ });
+ bool allStridesOne =
+ llvm::all_of(insertSliceOp.getMixedStrides(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
----------------
hanhanW wrote:
We can consider adding a new method to `StaticValueUtils.h` that takes an `ArrayRef` and checks whether all of its values are the constant `value`.
```cpp
bool isConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
// or name it isConstantIntValueArray
```
https://github.com/llvm/llvm-project/blob/8c62bf54df76e37d0978f4901c6be6554e978b53/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h#L93-L94
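To make the suggestion concrete, here is a minimal standalone C++ sketch that does not depend on MLIR. It assumes a hypothetical `FoldedIndex` type (`std::optional<int64_t>`, where `std::nullopt` stands for a dynamic SSA value) in place of `OpFoldResult`, and a static destination shape. It shows how the proposed `isConstantIntValueArray` helper would collapse the offset and stride checks in `insertSliceOpRequiresRead`. `overwritesEntireDest` is likewise a hypothetical name, not an existing MLIR function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: an engaged optional holds a
// constant index; std::nullopt models a dynamic SSA value.
using FoldedIndex = std::optional<int64_t>;

// Sketch of the proposed helper: returns true iff every entry is the
// constant `value`. Dynamic entries never match, and an empty list is
// vacuously true (same semantics as llvm::all_of).
bool isConstantIntValueArray(const std::vector<FoldedIndex> &ofrs,
                             int64_t value) {
  return std::all_of(ofrs.begin(), ofrs.end(), [&](const FoldedIndex &ofr) {
    return ofr.has_value() && *ofr == value;
  });
}

// Hypothetical mirror of the full-overwrite check in insertSliceOpRequiresRead:
// the destination is entirely overwritten (so it need not be read) when all
// offsets are 0, all strides are 1, and each size equals the dest dim size.
// Assumes destShape is fully static.
bool overwritesEntireDest(const std::vector<FoldedIndex> &offsets,
                          const std::vector<FoldedIndex> &sizes,
                          const std::vector<FoldedIndex> &strides,
                          const std::vector<int64_t> &destShape) {
  if (sizes.size() != destShape.size())
    return false;
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    // A dynamic size (std::nullopt) compares unequal, so the check stays
    // conservative, like comparing getConstantIntValue() to getDimSize().
    if (sizes[i] != destShape[i])
      return false;
  }
  return isConstantIntValueArray(offsets, 0) &&
         isConstantIntValueArray(strides, 1);
}
```

For the example in the diff, `tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>`, the call `overwritesEntireDest({0}, {10}, {1}, {10})` returns `true`, so the destination would not need to be read. Any dynamic offset, size, or stride makes the check conservatively return `false`, which keeps the dest operand classified as read.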
https://github.com/llvm/llvm-project/pull/112761