[Mlir-commits] [mlir] de13eed - [mlir][Vector] Add a Broadcast::createBroadcastOp helper
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 30 05:32:36 PST 2022
Author: Nicolas Vasilache
Date: 2022-11-30T05:32:14-08:00
New Revision: de13eeda11b665d2a5f13e523124cf7c8c9bedd9
URL: https://github.com/llvm/llvm-project/commit/de13eeda11b665d2a5f13e523124cf7c8c9bedd9
DIFF: https://github.com/llvm/llvm-project/commit/de13eeda11b665d2a5f13e523124cf7c8c9bedd9.diff
LOG: [mlir][Vector] Add a Broadcast::createBroadcastOp helper
This helper handles non-trivial cases of broadcast + optional transpose creation
that should not leak to the outside world.
Differential Revision: https://reviews.llvm.org/D139003
Added:
mlir/test/Dialect/Vector/test-create-broadcast.mlir
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 01582570ad2fe..ff7b79b37c200 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -449,6 +449,23 @@ def Vector_BroadcastOp :
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+
+    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+    /// `broadcastedDims` dimensions in the dstShape are broadcasted.
+    /// This requires (and asserts) that the broadcast is free of dim-1
+    /// broadcasting.
+    /// Since vector.broadcast only allows expanding leading dimensions, an
+    /// extra vector.transpose may be inserted to make the broadcast possible;
+    /// the (possibly folded) final value is returned.
+    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
+    /// the helper will assert. This means:
+    ///   1. `dstShape` must not be empty.
+    ///   2. `broadcastedDims` must be confined to [0 .. rank(dstShape)).
+    ///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
+    ///      must match the `value` shape (scalars require all dims broadcast).
+    static Value createOrFoldBroadcastOp(
+        OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
+        const llvm::SetVector<int64_t> &broadcastedDims);
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c4af2d8a19441..b36206c1ae34d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1725,13 +1725,9 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
-llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
- VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
- // Scalar broadcast is without any unit dim broadcast.
- if (!srcVectorType)
- return {};
- ArrayRef<int64_t> srcShape = srcVectorType.getShape();
- ArrayRef<int64_t> dstShape = getVectorType().getShape();
+static llvm::SetVector<int64_t>
+computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
+ ArrayRef<int64_t> dstShape) {
int64_t rankDiff = dstShape.size() - srcShape.size();
int64_t dstDim = rankDiff;
llvm::SetVector<int64_t> res;
@@ -1745,6 +1741,129 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
return res;
}
+/// Member form: forwards the source/result vector shapes to the static helper.
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+  // A scalar source has no dims at all, hence no unit-dim broadcast.
+  if (auto srcVecType = getSourceType().dyn_cast<VectorType>())
+    return ::computeBroadcastedUnitDims(srcVecType.getShape(),
+                                        getVectorType().getShape());
+  return {};
+}
+
+/// Return true iff every bit of `bv` in the half-open range [lb, ub) is set.
+static bool allBitsSet(const llvm::SmallBitVector &bv, int64_t lb, int64_t ub) {
+  while (lb < ub && bv.test(lb))
+    ++lb;
+  return lb >= ub;
+}
+
+/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+/// `broadcastedDims` dimensions in the dstShape are broadcasted.
+/// This requires (and asserts) that the broadcast is free of dim-1
+/// broadcasting.
+/// Since vector.broadcast only allows expanding leading dimensions, an extra
+/// vector.transpose may be inserted to make the broadcast possible.
+/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
+/// the helper will assert. This means:
+///   1. `dstShape` must not be empty.
+///   2. `broadcastedDims` must be confined to [0 .. rank(dstShape)).
+///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
+///      must match the `value` shape (scalars require all dims broadcast).
+Value BroadcastOp::createOrFoldBroadcastOp(
+    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
+    const llvm::SetVector<int64_t> &broadcastedDims) {
+  assert(!dstShape.empty() && "unexpected empty dst shape");
+
+  // Step 1. Well-formedness check.
+  SmallVector<int64_t> checkShape;
+  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i))
+      continue;
+    checkShape.push_back(dstShape[i]);
+  }
+  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
+         "ill-formed broadcastedDims contains values not confined to "
+         "destVectorShape");
+
+  Location loc = value.getLoc();
+  Type elementType = getElementTypeOrSelf(value.getType());
+  VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
+  VectorType dstVectorType = VectorType::get(dstShape, elementType);
+
+  // Step 2. If scalar -> dstShape broadcast, just do it.
+  if (!srcVectorType) {
+    assert(checkShape.empty() &&
+           "ill-formed createOrFoldBroadcastOp arguments");
+    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
+  }
+
+  assert(srcVectorType.getShape().equals(checkShape) &&
+         "ill-formed createOrFoldBroadcastOp arguments");
+
+  // Step 3. Since vector.broadcast only allows creating leading dims,
+  //   vector -> dstShape broadcast may require a transpose.
+  // Traverse the dims in order and construct:
+  //   1. The leading entries of the broadcastShape that are guaranteed to be
+  //      achievable by a simple broadcast.
+  //   2. The induced permutation for the subsequent vector.transpose that will
+  //      bring us from `broadcastShape` back to the desired `dstShape`.
+  // If the induced permutation is not the identity, create a vector.transpose.
+  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
+  broadcastShape.reserve(dstShape.size());
+  // Consider the example:
+  //    srcShape        = 2x4
+  //    dstShape        = 1x2x3x4x5
+  //    broadcastedDims = [0, 2, 4]
+  //
+  // We want to build:
+  //    broadcastShape  = 1x3x5x2x4, i.e. the leading broadcast part (1x3x5)
+  //                      followed by the src shape part (2x4);
+  //    permutation     = [0, 3, 1, 4, 2],
+  // where dst dim `i` is read from `broadcastShape` dim `permutation[i]`.
+  //
+  // Note that the trailing dims of broadcastShape are exactly the srcShape
+  // by construction.
+  // nextSrcShapeDim is the index in `broadcastShape` of the next dim of the
+  // "src shape part" to be placed in the permutation.
+  int64_t nextSrcShapeDim = broadcastedDims.size();
+  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i)) {
+      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
+      // append it to the leading (broadcast) part of `broadcastShape`.
+      // It will need to be permuted back from `broadcastShape.size() - 1` into
+      // position `i`.
+      broadcastShape.push_back(dstShape[i]);
+      permutation[i] = broadcastShape.size() - 1;
+    } else {
+      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
+      // shape and needs to be permuted into position `i`.
+      // Don't touch `broadcastShape` here, the whole srcShape will be
+      // appended after.
+      permutation[i] = nextSrcShapeDim++;
+    }
+  }
+  // 3.c. Append the srcShape.
+  llvm::append_range(broadcastShape, srcVectorType.getShape());
+
+  // Ensure there are no dim-1 broadcasts.
+  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
+             .empty() &&
+         "unexpected dim-1 broadcast");
+
+  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
+  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
+             vector::BroadcastableToResult::Success &&
+         "must be broadcastable");
+  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
+  // Step 4. If we find any dimension that indeed needs to be permuted,
+  // immediately return a new vector.transpose.
+  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
+    if (permutation[i] != i)
+      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
+  // Otherwise return res.
+  return res;
+}
+
BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims) {
diff --git a/mlir/test/Dialect/Vector/test-create-broadcast.mlir b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
new file mode 100644
index 0000000000000..f7af184da4f51
--- /dev/null
+++ b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --test-create-vector-broadcast --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @broadcast_scalar(
+func.func @broadcast_scalar(%a : f32) -> vector<1x2xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0, 1>} : (f32) -> vector<1x2xf32>
+  // CHECK: vector.broadcast {{.*}} : f32 to vector<1x2xf32>
+  // CHECK-NOT: vector.transpose
+  return %0: vector<1x2xf32>
+}
+// -----
+
+// CHECK-LABEL: func @broadcast_trailing_dim(
+func.func @broadcast_trailing_dim(%a : vector<2x2xf32>) -> vector<2x2x3xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 2>} : (vector<2x2xf32>) -> vector<2x2x3xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<2x2xf32> to vector<3x2x2xf32>
+  // CHECK: vector.transpose {{.*}}, [1, 2, 0] : vector<3x2x2xf32> to vector<2x2x3xf32>
+  return %0: vector<2x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_leading_dim(
+func.func @broadcast_leading_dim(%a : vector<3x3xf32>) -> vector<4x3x3xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0>} : (vector<3x3xf32>) -> vector<4x3x3xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<3x3xf32> to vector<4x3x3xf32>
+  // CHECK-NOT: vector.transpose
+  return %0: vector<4x3x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_interleaved_dims(
+func.func @broadcast_interleaved_dims(%a : vector<2x4xf32>) -> vector<1x2x3x4x5xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0, 2, 4>} : (vector<2x4xf32>) -> vector<1x2x3x4x5xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<2x4xf32> to vector<1x3x5x2x4xf32>
+  // CHECK: vector.transpose {{.*}}, [0, 3, 1, 4, 2] : vector<1x3x5x2x4xf32> to vector<1x2x3x4x5xf32>
+  return %0: vector<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 1bd40e7cde7e3..00bd07a034337 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -820,6 +820,38 @@ struct TestVectorExtractStridedSliceLowering
}
};
+/// Test pass rewriting each "test_create_broadcast" op via the broadcast helper.
+struct TestCreateVectorBroadcast
+    : public PassWrapper<TestCreateVectorBroadcast,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
+
+  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
+  StringRef getDescription() const final {
+    return "Test vector::BroadcastOp::createOrFoldBroadcastOp";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    getOperation()->walk([](Operation *op) {
+      if (op->getName().getStringRef() != "test_create_broadcast")
+        return;
+      auto dstShape = op->getResult(0).getType().cast<VectorType>().getShape();
+      auto arrayAttr =
+          op->getAttr("broadcast_dims").cast<DenseI64ArrayAttr>().asArrayRef();
+      llvm::SetVector<int64_t> broadcastedDims;
+      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
+      OpBuilder b(op);
+      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
+          b, op->getOperand(0), dstShape, broadcastedDims);
+      op->getResult(0).replaceAllUsesWith(bcast);
+      op->erase();
+    });
+  }
+};
+
} // namespace
namespace mlir {
@@ -856,6 +888,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorDistribution>();
PassRegistration<TestVectorExtractStridedSliceLowering>();
+
+ PassRegistration<TestCreateVectorBroadcast>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list