[Mlir-commits] [mlir] 0650e1b - [mlir][vector] Fix folding of vector.extract from vector.broadcast
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 28 07:17:37 PST 2022
Author: Nicolas Vasilache
Date: 2022-11-28T07:17:31-08:00
New Revision: 0650e1bcc05af64d25ad89f87b2e21b37fd88114
URL: https://github.com/llvm/llvm-project/commit/0650e1bcc05af64d25ad89f87b2e21b37fd88114
DIFF: https://github.com/llvm/llvm-project/commit/0650e1bcc05af64d25ad89f87b2e21b37fd88114.diff
LOG: [mlir][vector] Fix folding of vector.extract from vector.broadcast
This revision fixes a bug in the vector.extract folder, which did not
handle the "dim-1" broadcasting case of vector.broadcast.
Differential Revision: https://reviews.llvm.org/D138804
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5060d8c076f14..edaf78b38b274 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -445,6 +445,10 @@ def Vector_BroadcastOp :
VectorType getVectorType() {
return getVector().getType().cast<VectorType>();
}
+
+ /// Return the dimensions of the result vector that were formerly ones in the
+ /// source tensor and thus correspond to "dim-1" broadcasting.
+ llvm::SetVector<int64_t> computeBroadcastedUnitDims();
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2f9bca6a0564e..9e1b6307b75f2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1351,7 +1351,11 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
auto getRank = [](Type type) {
return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
};
+ // If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
+ if (broadcastSrcRank == 0)
+ return source;
+
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank >= broadcastSrcRank)
return Value();
@@ -1362,13 +1366,25 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
extractVecType.getShape() !=
broadcastVecType.getShape().take_back(extractResultRank))
return Value();
+
+ auto broadcastOp = cast<vector::BroadcastOp>(defOp);
+ int64_t rankDiff = broadcastSrcRank - extractResultRank;
+ // Detect all the positions that come from "dim-1" broadcasting.
+ // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
+ // extract position to `0` when extracting from the source operand.
+ llvm::SetVector<int64_t> broadcastedUnitDims =
+ broadcastOp.computeBroadcastedUnitDims();
auto extractPos = extractVector<int64_t>(extractOp.getPosition());
- unsigned rankDiff = broadcastSrcRank - extractResultRank;
+ for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
+ if (broadcastedUnitDims.contains(i))
+ extractPos[i] = 0;
+ // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
+ // matching extract position when extracting from the source operand.
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
- extractOp.setOperand(source);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
+ extractOp.setOperand(source);
extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(extractPos));
return extractOp.getResult();
@@ -1683,6 +1699,28 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
// BroadcastOp
//===----------------------------------------------------------------------===//
+/// Return the dimensions of the result vector that were formerly ones in the
+/// source vector and thus correspond to "dim-1" broadcasting.
+///
+/// The returned indices are positions in the result (destination) vector
+/// type, inserted in increasing order. Leading result dimensions that have no
+/// counterpart in the source (added purely by rank extension) are never
+/// included. A scalar source yields an empty set.
+/// Example: broadcasting vector<1x1x1xf32> to vector<1x1x32x1xf32> yields {2}.
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+ VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
+ // Scalar broadcast is without any unit dim broadcast.
+ if (!srcVectorType)
+ return {};
+ ArrayRef<int64_t> srcShape = srcVectorType.getShape();
+ ArrayRef<int64_t> dstShape = getVectorType().getShape();
+ // Source dims are aligned with the trailing dims of the result; the first
+ // `rankDiff` result dims have no source counterpart and are skipped.
+ int64_t rankDiff = dstShape.size() - srcShape.size();
+ int64_t dstDim = rankDiff;
+ llvm::SetVector<int64_t> res;
+ for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) {
+ // A size mismatch between aligned dims is expected only when a size-1
+ // source dim was stretched; the assert enforces that expectation.
+ if (s1 != s2) {
+ assert(s1 == 1 && "expected dim-1 broadcasting");
+ res.insert(dstDim);
+ }
+ ++dstDim;
+ }
+ return res;
+}
+
BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7aabcec231976..872767c7f8440 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2020,3 +2020,15 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
%1 = vector.transfer_read %0[%c0, %i4, %c0], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
return %1 : vector<4xf32>
}
+
+// -----
+
+// Result dim 2 (size 32) is a "dim-1" broadcast of source dim 1, so extracting
+// at [0, 0, 31] must fold to an extract from the source at [0, 0]: the leading
+// rank-extension position is dropped and the index 31 is reset to 0.
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+ %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+
+ // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1x1x1xf32>
+ // CHECK-NEXT: return %0 : vector<1xf32>
+ %1 = vector.extract %0[0, 0, 31] : vector<1x1x32x1xf32>
+ return %1: vector<1xf32>
+}
More information about the Mlir-commits
mailing list