[Mlir-commits] [mlir] 8f1650c - [mlir][Linalg] NFC - Refactor vector.broadcast op verification logic and make it available as a precondition in Linalg vectorization.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 12 04:35:39 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-12T11:35:34Z
New Revision: 8f1650cb6501408f9ad03c526af3bcd1f57ef48f

URL: https://github.com/llvm/llvm-project/commit/8f1650cb6501408f9ad03c526af3bcd1f57ef48f
DIFF: https://github.com/llvm/llvm-project/commit/8f1650cb6501408f9ad03c526af3bcd1f57ef48f.diff

LOG: [mlir][Linalg] NFC - Refactor vector.broadcast op verification logic and make it available as a precondition in Linalg vectorization.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111558

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index a6fbf93f29a08..694875f5e143e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,6 +40,18 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Return whether `srcType` can be broadcast to `dstVectorType` under the
+/// semantics of the `vector.broadcast` op.
+enum class BroadcastableToResult {
+  Success = 0,
+  SourceRankHigher = 1,
+  DimensionMismatch = 2,
+  SourceTypeNotAVector = 3
+};
+BroadcastableToResult
+isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                  std::pair<int, int> *mismatchingDims = nullptr);
+
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index e6df2dbf9b1ac..60a9e67e476a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -147,24 +147,20 @@ matchLinalgReduction(OpOperand *outputOperand) {
   return getKindForOp(combinerOps[0]);
 }
 
-/// If `value` of assumed VectorType has a shape different than `shape`, try to
-/// build and return a new vector.broadcast to `shape`.
-/// Otherwise, just return `value`.
-// TODO: this is best effort atm and there is currently no guarantee of
-// correctness for the broadcast semantics.
+/// Broadcast `value` to a vector of `shape` if possible. Return value
+/// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value,
                                ArrayRef<int64_t> shape) {
-  unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
-                                        [](int64_t val) { return val > 1; });
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  if (shape.empty() ||
-      (vecType != nullptr &&
-       (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+  // If no shape to broadcast to, just return `value`.
+  if (shape.empty())
+    return value;
+  VectorType targetVectorType =
+      VectorType::get(shape, getElementTypeOrSelf(value));
+  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+      vector::BroadcastableToResult::Success)
     return value;
-  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                   : value.getType());
-  return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
-                                       newVecType, value);
+  Location loc = b.getInsertionPoint()->getLoc();
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// If value of assumed VectorType has a shape different than `shape`, build and
@@ -688,7 +684,8 @@ struct GenericPadTensorOpVectorizationPattern
     // by TransferReadOp, but TransferReadOp supports only constant padding.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue) {
-      if (!sourceType.hasStaticShape()) return failure();
+      if (!sourceType.hasStaticShape())
+        return failure();
       // Create dummy padding value.
       auto elemType = sourceType.getElementType();
       padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
@@ -733,14 +730,14 @@ struct GenericPadTensorOpVectorizationPattern
 
     // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
     // tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape())
-        && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+    if (llvm::equal(vecShape, resultType.getShape()) &&
+        llvm::all_of(writeInBounds, [](bool b) { return b; }))
       if (auto fill = dest.getDefiningOp<FillOp>())
         dest = fill.output();
 
     // Generate TransferWriteOp.
-    auto writeIndices = ofrToIndexValues(
-        rewriter, padOp.getLoc(), padOp.getMixedLowPad());
+    auto writeIndices =
+        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         padOp, read, dest, writeIndices, writeInBounds);
 
@@ -764,9 +761,9 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
     return success(changed);
   }
 
- protected:
-  virtual LogicalResult rewriteUser(
-      PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+protected:
+  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+                                    PadTensorOp padOp, OpTy op) const = 0;
 };
 
 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
@@ -790,18 +787,21 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferReadPattern
     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferReadOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // Padding value of existing `xferOp` is unused.
-    if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+      return failure();
 
     rewriter.updateRootInPlace(xferOp, [&]() {
       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
@@ -847,24 +847,30 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferWritePattern
     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferWriteOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
-    if (!xferOp->hasOneUse()) return failure();
+    if (!xferOp->hasOneUse())
+      return failure();
     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
-    if (!trimPadding) return failure();
+    if (!trimPadding)
+      return failure();
     // Only static zero offsets supported when trimming padding.
-    if (!trimPadding.hasZeroOffset()) return failure();
+    if (!trimPadding.hasZeroOffset())
+      return failure();
     // trimPadding must remove the amount of padding that was added earlier.
-    if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+    if (!hasSameTensorSize(padOp.source(), trimPadding))
+      return failure();
 
     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
     rewriter.setInsertionPoint(xferOp);
@@ -894,14 +900,17 @@ struct PadTensorOpVectorizationWithTransferWritePattern
     // If the input to PadTensorOp is a CastOp, try with with both CastOp result
     // and CastOp operand.
     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
-      if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+      if (hasSameTensorSize(castOp.source(), afterTrimming))
+        return true;
 
     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
     // Only RankedTensorType supported.
-    if (!t1 || !t2) return false;
+    if (!t1 || !t2)
+      return false;
     // Rank of both values must be the same.
-    if (t1.getRank() != t2.getRank()) return false;
+    if (t1.getRank() != t2.getRank())
+      return false;
 
     // All static dimensions must be the same. Mixed cases (e.g., dimension
     // static in `t1` but dynamic in `t2`) are not supported.
@@ -913,7 +922,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
     }
 
     // Nothing more to check if all dimensions are static.
-    if (t1.getNumDynamicDims() == 0) return true;
+    if (t1.getNumDynamicDims() == 0)
+      return true;
 
     // All dynamic sizes must be the same. The only supported case at the moment
     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
@@ -925,29 +935,33 @@ struct PadTensorOpVectorizationWithTransferWritePattern
 
     assert(static_cast<size_t>(t1.getRank()) ==
            beforeSlice.getMixedSizes().size());
-    assert(static_cast<size_t>(t2.getRank())
-           == afterTrimming.getMixedSizes().size());
+    assert(static_cast<size_t>(t2.getRank()) ==
+           afterTrimming.getMixedSizes().size());
 
     for (unsigned i = 0; i < t1.getRank(); ++i) {
       // Skip static dimensions.
-      if (!t1.isDynamicDim(i)) continue;
+      if (!t1.isDynamicDim(i))
+        continue;
       auto size1 = beforeSlice.getMixedSizes()[i];
       auto size2 = afterTrimming.getMixedSizes()[i];
 
       // Case 1: Same value or same constant int.
-      if (isEqualConstantIntOrValue(size1, size2)) continue;
+      if (isEqualConstantIntOrValue(size1, size2))
+        continue;
 
       // Other cases: Take a deeper look at defining ops of values.
       auto v1 = size1.dyn_cast<Value>();
       auto v2 = size2.dyn_cast<Value>();
-      if (!v1 || !v2) return false;
+      if (!v1 || !v2)
+        return false;
 
       // Case 2: Both values are identical AffineMinOps. (Should not happen if
       // CSE is run.)
       auto minOp1 = v1.getDefiningOp<AffineMinOp>();
       auto minOp2 = v2.getDefiningOp<AffineMinOp>();
-      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
-          && minOp1.operands() == minOp2.operands()) continue;
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
 
       // Add additional cases as needed.
     }
@@ -987,9 +1001,11 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             tensor::InsertSliceOp insertOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Only unit stride supported.
-    if (!insertOp.hasUnitStride()) return failure();
+    if (!insertOp.hasUnitStride())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue)
@@ -1038,8 +1054,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
+                                                       baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 36898a44bf273..879996a041bf9 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1321,31 +1321,59 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(BroadcastOp op) {
-  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
-  VectorType dstVectorType = op.getVectorType();
-  // Scalar to vector broadcast is always valid. A vector
-  // to vector broadcast needs some additional checking.
-  if (srcVectorType) {
-    int64_t srcRank = srcVectorType.getRank();
-    int64_t dstRank = dstVectorType.getRank();
-    if (srcRank > dstRank)
-      return op.emitOpError("source rank higher than destination rank");
-    // Source has an exact match or singleton value for all trailing dimensions
-    // (all leading dimensions are simply duplicated).
-    int64_t lead = dstRank - srcRank;
-    for (int64_t r = 0; r < srcRank; ++r) {
-      int64_t srcDim = srcVectorType.getDimSize(r);
-      int64_t dstDim = dstVectorType.getDimSize(lead + r);
-      if (srcDim != 1 && srcDim != dstDim)
-        return op.emitOpError("dimension mismatch (")
-               << srcDim << " vs. " << dstDim << ")";
+BroadcastableToResult
+mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                                std::pair<int, int> *mismatchingDims) {
+  // Broadcast scalar to vector of the same element type.
+  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
+      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+    return BroadcastableToResult::Success;
+  // From now on, only vectors broadcast.
+  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return BroadcastableToResult::SourceTypeNotAVector;
+
+  int64_t srcRank = srcVectorType.getRank();
+  int64_t dstRank = dstVectorType.getRank();
+  if (srcRank > dstRank)
+    return BroadcastableToResult::SourceRankHigher;
+  // Source has an exact match or singleton value for all trailing dimensions
+  // (all leading dimensions are simply duplicated).
+  int64_t lead = dstRank - srcRank;
+  for (int64_t r = 0; r < srcRank; ++r) {
+    int64_t srcDim = srcVectorType.getDimSize(r);
+    int64_t dstDim = dstVectorType.getDimSize(lead + r);
+    if (srcDim != 1 && srcDim != dstDim) {
+      if (mismatchingDims) {
+        mismatchingDims->first = srcDim;
+        mismatchingDims->second = dstDim;
+      }
+      return BroadcastableToResult::DimensionMismatch;
     }
   }
-  return success();
+
+  return BroadcastableToResult::Success;
+}
+
+static LogicalResult verify(BroadcastOp op) {
+  std::pair<int, int> mismatchingDims;
+  BroadcastableToResult res = isBroadcastableTo(
+      op.getSourceType(), op.getVectorType(), &mismatchingDims);
+  if (res == BroadcastableToResult::Success)
+    return success();
+  if (res == BroadcastableToResult::SourceRankHigher)
+    return op.emitOpError("source rank higher than destination rank");
+  if (res == BroadcastableToResult::DimensionMismatch)
+    return op.emitOpError("dimension mismatch (")
+           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  if (res == BroadcastableToResult::SourceTypeNotAVector)
+    return op.emitOpError("source type is not a vector");
+  llvm_unreachable("unexpected vector.broadcast op error");
 }
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  if (getSourceType() == getVectorType())
+    return source();
   if (!operands[0])
     return {};
   auto vectorType = getVectorType();

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 53c244716759c..26845172e1a6d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -30,6 +30,13 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
 
 // -----
 
+func @broadcast_unknown(%arg0: memref<4x8xf32>) {
+  // expected-error @+1 {{'vector.broadcast' op source type is not a vector}}
+  %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
+}
+
+// -----
+
 func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
   // expected-error @+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 288f2f6d0a4a7..c925b8e1c76e4 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -493,7 +493,6 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
 func @cast_away_broadcast_leading_one_dims(
   %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
   (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
-  // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
   // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
   // CHECK:  vector.broadcast %{{.*}} : f32 to vector<4xf32>


        


More information about the Mlir-commits mailing list