[Mlir-commits] [mlir] 8f1650c - [mlir][Linalg] NFC - Refactor vector.broadcast op verification logic and make it available as a precondition in Linalg vectorization.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 12 04:35:39 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-12T11:35:34Z
New Revision: 8f1650cb6501408f9ad03c526af3bcd1f57ef48f
URL: https://github.com/llvm/llvm-project/commit/8f1650cb6501408f9ad03c526af3bcd1f57ef48f
DIFF: https://github.com/llvm/llvm-project/commit/8f1650cb6501408f9ad03c526af3bcd1f57ef48f.diff
LOG: [mlir][Linalg] NFC - Refactor vector.broadcast op verification logic and make it available as a precondition in Linalg vectorization.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D111558
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index a6fbf93f29a08..694875f5e143e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,6 +40,18 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail
+/// Return whether `srcType` can be broadcast to `dstVectorType` under the
+/// semantics of the `vector.broadcast` op.
+enum class BroadcastableToResult {
+ Success = 0,
+ SourceRankHigher = 1,
+ DimensionMismatch = 2,
+ SourceTypeNotAVector = 3
+};
+BroadcastableToResult
+isBroadcastableTo(Type srcType, VectorType dstVectorType,
+ std::pair<int, int> *mismatchingDims = nullptr);
+
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index e6df2dbf9b1ac..60a9e67e476a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -147,24 +147,20 @@ matchLinalgReduction(OpOperand *outputOperand) {
return getKindForOp(combinerOps[0]);
}
-/// If `value` of assumed VectorType has a shape different than `shape`, try to
-/// build and return a new vector.broadcast to `shape`.
-/// Otherwise, just return `value`.
-// TODO: this is best effort atm and there is currently no guarantee of
-// correctness for the broadcast semantics.
+/// Broadcast `value` to a vector of `shape` if possible. Return value
+/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value,
ArrayRef<int64_t> shape) {
- unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
- [](int64_t val) { return val > 1; });
- auto vecType = value.getType().dyn_cast<VectorType>();
- if (shape.empty() ||
- (vecType != nullptr &&
- (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+ // If no shape to broadcast to, just return `value`.
+ if (shape.empty())
+ return value;
+ VectorType targetVectorType =
+ VectorType::get(shape, getElementTypeOrSelf(value));
+ if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+ vector::BroadcastableToResult::Success)
return value;
- auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
- : value.getType());
- return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
- newVecType, value);
+ Location loc = b.getInsertionPoint()->getLoc();
+ return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// If value of assumed VectorType has a shape different than `shape`, build and
@@ -688,7 +684,8 @@ struct GenericPadTensorOpVectorizationPattern
// by TransferReadOp, but TransferReadOp supports only constant padding.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) {
- if (!sourceType.hasStaticShape()) return failure();
+ if (!sourceType.hasStaticShape())
+ return failure();
// Create dummy padding value.
auto elemType = sourceType.getElementType();
padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
@@ -733,14 +730,14 @@ struct GenericPadTensorOpVectorizationPattern
// If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
// tensor, write directly to the FillOp's operand.
- if (llvm::equal(vecShape, resultType.getShape())
- && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+ if (llvm::equal(vecShape, resultType.getShape()) &&
+ llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
dest = fill.output();
// Generate TransferWriteOp.
- auto writeIndices = ofrToIndexValues(
- rewriter, padOp.getLoc(), padOp.getMixedLowPad());
+ auto writeIndices =
+ ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
padOp, read, dest, writeIndices, writeInBounds);
@@ -764,9 +761,9 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
return success(changed);
}
- protected:
- virtual LogicalResult rewriteUser(
- PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+protected:
+ virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+ PadTensorOp padOp, OpTy op) const = 0;
};
/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
@@ -790,18 +787,21 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferReadPattern
: public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
- using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
- ::VectorizePadTensorOpUserPattern;
+ using VectorizePadTensorOpUserPattern<
+ vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferReadOp xferOp) const override {
// Low padding must be static 0.
- if (!padOp.hasZeroLowPad()) return failure();
+ if (!padOp.hasZeroLowPad())
+ return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
- if (!padValue) return failure();
+ if (!padValue)
+ return failure();
// Padding value of existing `xferOp` is unused.
- if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+ if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+ return failure();
rewriter.updateRootInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
@@ -847,24 +847,30 @@ struct PadTensorOpVectorizationWithTransferReadPattern
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferWritePattern
: public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
- using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
- ::VectorizePadTensorOpUserPattern;
+ using VectorizePadTensorOpUserPattern<
+ vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferWriteOp xferOp) const override {
// Low padding must be static 0.
- if (!padOp.hasZeroLowPad()) return failure();
+ if (!padOp.hasZeroLowPad())
+ return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
- if (!padValue) return failure();
+ if (!padValue)
+ return failure();
// TransferWriteOp result must be directly consumed by an ExtractSliceOp.
- if (!xferOp->hasOneUse()) return failure();
+ if (!xferOp->hasOneUse())
+ return failure();
auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
- if (!trimPadding) return failure();
+ if (!trimPadding)
+ return failure();
// Only static zero offsets supported when trimming padding.
- if (!trimPadding.hasZeroOffset()) return failure();
+ if (!trimPadding.hasZeroOffset())
+ return failure();
// trimPadding must remove the amount of padding that was added earlier.
- if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+ if (!hasSameTensorSize(padOp.source(), trimPadding))
+ return failure();
// Insert the new TransferWriteOp at position of the old TransferWriteOp.
rewriter.setInsertionPoint(xferOp);
@@ -894,14 +900,17 @@ struct PadTensorOpVectorizationWithTransferWritePattern
// If the input to PadTensorOp is a CastOp, try with with both CastOp result
// and CastOp operand.
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
- if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+ if (hasSameTensorSize(castOp.source(), afterTrimming))
+ return true;
auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
// Only RankedTensorType supported.
- if (!t1 || !t2) return false;
+ if (!t1 || !t2)
+ return false;
// Rank of both values must be the same.
- if (t1.getRank() != t2.getRank()) return false;
+ if (t1.getRank() != t2.getRank())
+ return false;
// All static dimensions must be the same. Mixed cases (e.g., dimension
// static in `t1` but dynamic in `t2`) are not supported.
@@ -913,7 +922,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
}
// Nothing more to check if all dimensions are static.
- if (t1.getNumDynamicDims() == 0) return true;
+ if (t1.getNumDynamicDims() == 0)
+ return true;
// All dynamic sizes must be the same. The only supported case at the moment
// is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
@@ -925,29 +935,33 @@ struct PadTensorOpVectorizationWithTransferWritePattern
assert(static_cast<size_t>(t1.getRank()) ==
beforeSlice.getMixedSizes().size());
- assert(static_cast<size_t>(t2.getRank())
- == afterTrimming.getMixedSizes().size());
+ assert(static_cast<size_t>(t2.getRank()) ==
+ afterTrimming.getMixedSizes().size());
for (unsigned i = 0; i < t1.getRank(); ++i) {
// Skip static dimensions.
- if (!t1.isDynamicDim(i)) continue;
+ if (!t1.isDynamicDim(i))
+ continue;
auto size1 = beforeSlice.getMixedSizes()[i];
auto size2 = afterTrimming.getMixedSizes()[i];
// Case 1: Same value or same constant int.
- if (isEqualConstantIntOrValue(size1, size2)) continue;
+ if (isEqualConstantIntOrValue(size1, size2))
+ continue;
// Other cases: Take a deeper look at defining ops of values.
auto v1 = size1.dyn_cast<Value>();
auto v2 = size2.dyn_cast<Value>();
- if (!v1 || !v2) return false;
+ if (!v1 || !v2)
+ return false;
// Case 2: Both values are identical AffineMinOps. (Should not happen if
// CSE is run.)
auto minOp1 = v1.getDefiningOp<AffineMinOp>();
auto minOp2 = v2.getDefiningOp<AffineMinOp>();
- if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
- && minOp1.operands() == minOp2.operands()) continue;
+ if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+ minOp1.operands() == minOp2.operands())
+ continue;
// Add additional cases as needed.
}
@@ -987,9 +1001,11 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
tensor::InsertSliceOp insertOp) const override {
// Low padding must be static 0.
- if (!padOp.hasZeroLowPad()) return failure();
+ if (!padOp.hasZeroLowPad())
+ return failure();
// Only unit stride supported.
- if (!insertOp.hasUnitStride()) return failure();
+ if (!insertOp.hasUnitStride())
+ return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue)
@@ -1038,8 +1054,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
- patterns.add<GenericPadTensorOpVectorizationPattern>(
- patterns.getContext(), baseBenefit);
+ patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
+ baseBenefit);
// Try these specialized patterns first before resorting to the generic one.
patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
PadTensorOpVectorizationWithTransferWritePattern,
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 36898a44bf273..879996a041bf9 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1321,31 +1321,59 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
// BroadcastOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(BroadcastOp op) {
- VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
- VectorType dstVectorType = op.getVectorType();
- // Scalar to vector broadcast is always valid. A vector
- // to vector broadcast needs some additional checking.
- if (srcVectorType) {
- int64_t srcRank = srcVectorType.getRank();
- int64_t dstRank = dstVectorType.getRank();
- if (srcRank > dstRank)
- return op.emitOpError("source rank higher than destination rank");
- // Source has an exact match or singleton value for all trailing dimensions
- // (all leading dimensions are simply duplicated).
- int64_t lead = dstRank - srcRank;
- for (int64_t r = 0; r < srcRank; ++r) {
- int64_t srcDim = srcVectorType.getDimSize(r);
- int64_t dstDim = dstVectorType.getDimSize(lead + r);
- if (srcDim != 1 && srcDim != dstDim)
- return op.emitOpError("dimension mismatch (")
- << srcDim << " vs. " << dstDim << ")";
+BroadcastableToResult
+mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
+ std::pair<int, int> *mismatchingDims) {
+ // Broadcast scalar to vector of the same element type.
+ if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
+ getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+ return BroadcastableToResult::Success;
+ // From now on, only vectors broadcast.
+ VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+ if (!srcVectorType)
+ return BroadcastableToResult::SourceTypeNotAVector;
+
+ int64_t srcRank = srcVectorType.getRank();
+ int64_t dstRank = dstVectorType.getRank();
+ if (srcRank > dstRank)
+ return BroadcastableToResult::SourceRankHigher;
+ // Source has an exact match or singleton value for all trailing dimensions
+ // (all leading dimensions are simply duplicated).
+ int64_t lead = dstRank - srcRank;
+ for (int64_t r = 0; r < srcRank; ++r) {
+ int64_t srcDim = srcVectorType.getDimSize(r);
+ int64_t dstDim = dstVectorType.getDimSize(lead + r);
+ if (srcDim != 1 && srcDim != dstDim) {
+ if (mismatchingDims) {
+ mismatchingDims->first = srcDim;
+ mismatchingDims->second = dstDim;
+ }
+ return BroadcastableToResult::DimensionMismatch;
}
}
- return success();
+
+ return BroadcastableToResult::Success;
+}
+
+static LogicalResult verify(BroadcastOp op) {
+ std::pair<int, int> mismatchingDims;
+ BroadcastableToResult res = isBroadcastableTo(
+ op.getSourceType(), op.getVectorType(), &mismatchingDims);
+ if (res == BroadcastableToResult::Success)
+ return success();
+ if (res == BroadcastableToResult::SourceRankHigher)
+ return op.emitOpError("source rank higher than destination rank");
+ if (res == BroadcastableToResult::DimensionMismatch)
+ return op.emitOpError("dimension mismatch (")
+ << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+ if (res == BroadcastableToResult::SourceTypeNotAVector)
+ return op.emitOpError("source type is not a vector");
+ llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+ if (getSourceType() == getVectorType())
+ return source();
if (!operands[0])
return {};
auto vectorType = getVectorType();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 53c244716759c..26845172e1a6d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -30,6 +30,13 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
// -----
+func @broadcast_unknown(%arg0: memref<4x8xf32>) {
+ // expected-error at +1 {{'vector.broadcast' op source type is not a vector}}
+ %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
+}
+
+// -----
+
func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
// expected-error at +1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
%1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 288f2f6d0a4a7..c925b8e1c76e4 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -493,7 +493,6 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
func @cast_away_broadcast_leading_one_dims(
%arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
(vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
- // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
// CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
%0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
More information about the Mlir-commits
mailing list