[Mlir-commits] [mlir] 77a5ea2 - [mlir][Vector] Add basic scalable vectorization support to Linalg vectorizer
Diego Caballero
llvmlistbot at llvm.org
Tue Jun 13 17:05:16 PDT 2023
Author: Diego Caballero
Date: 2023-06-13T23:55:15Z
New Revision: 77a5ea2e671265fea8e041c8002dcc53834b9cc0
URL: https://github.com/llvm/llvm-project/commit/77a5ea2e671265fea8e041c8002dcc53834b9cc0
DIFF: https://github.com/llvm/llvm-project/commit/77a5ea2e671265fea8e041c8002dcc53834b9cc0.diff
LOG: [mlir][Vector] Add basic scalable vectorization support to Linalg vectorizer
For now, only elementwise operations are supported. Operations that perform any
kind of data permutation require changes to how VectorType represents scalable
dimensions (it currently tracks only the number of trailing scalable dimensions).
Differential Revision: https://reviews.llvm.org/D152599
Added:
mlir/test/Dialect/Linalg/vectorization-scalable.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Linalg/vectorization-masked.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b2c5bd17a4793..22137ed44beea 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -299,6 +299,7 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
/// Return success if the operation can be vectorized.
LogicalResult vectorizeOpPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
+ ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false);
//===----------------------------------------------------------------------===//
@@ -592,8 +593,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
/// dynamic shapes.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
- bool vectorizeNDExtract = false,
- bool lastVectorSizeScalable = false);
+ ArrayRef<bool> inputScalableVecDims = {},
+ bool vectorizeNDExtract = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3a97d1623e6e3..58a1fa864d430 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3036,7 +3036,7 @@ struct VectorizationPattern : public RewritePattern {
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
- vectorizeNDExtract);
+ /*scalableVecDims=*/{}, vectorizeNDExtract);
}
private:
@@ -3137,16 +3137,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
}
// TODO: Check that the correct number of vectorSizes was provided.
-
+ SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
+ scalableVecDims.back() = getLastVectorSizeScalable();
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
- if (failed(linalg::vectorize(rewriter, target, vectorSizes,
- getVectorizeNdExtract(),
- getLastVectorSizeScalable()))) {
+ if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims,
+ getVectorizeNdExtract()))) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8b70f4255224f..c953a1a1879b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -169,6 +169,21 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
return res;
}
+/// Return true if the scalable vector dimensions are supported. For now, we
+/// only support scalable vectors in the trailing dimension.
+static bool areValidScalableVecDims(ArrayRef<bool> scalableVecDims) {
+ if (scalableVecDims.empty())
+ return true;
+
+ auto isScalable = [](bool isScalableVecSize) { return isScalableVecSize; };
+ if (std::any_of(scalableVecDims.begin(), scalableVecDims.end() - 1,
+ isScalable)) {
+ return false;
+ }
+
+ return true;
+}
+
/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
@@ -177,11 +192,42 @@ struct VectorizationState {
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
- ArrayRef<int64_t> inputVectorSizes);
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims);
/// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
+ /// Returns a vector type of the provided `elementType` with the canonical
+ /// vector shape and the corresponding fixed/scalable dimensions bit. If
+ /// `dimPermutation` is provided, the canonical vector dimensions are permuted
+ /// accordingly.
+ VectorType getCanonicalVecType(
+ Type elementType,
+ std::optional<AffineMap> dimPermutation = std::nullopt) const {
+ SmallVector<int64_t> vectorShape;
+ SmallVector<bool> scalableDims;
+ if (dimPermutation.has_value()) {
+ vectorShape =
+ applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
+ scalableDims =
+ applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
+ } else {
+ vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
+ scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
+ }
+
+ // Make sure we don't end up with unsupported scalable vector dimensions
+ // after the permutation. If so, we should bail out on that operation in the
+ // scalable preconditions.
+ assert(areValidScalableVecDims(scalableDims) &&
+ "Permuted scalable vector dimensions are not supported");
+
+ // TODO: Extend scalable vector type to support a bit map.
+ bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
+ return VectorType::get(vectorShape, elementType, numScalableDims);
+ }
+
/// Masks an operation with the canonical vector mask if the operation needs
/// masking. Returns the masked operation or the original operation if masking
/// is not needed. If provided, the canonical mask for this operation is
@@ -223,6 +269,10 @@ struct VectorizationState {
/// Holds the canonical vector shape used to vectorize the iteration space.
SmallVector<int64_t> canonicalVecShape;
+ /// Holds the vector dimensions that are scalable in the canonical vector
+ /// shape.
+ SmallVector<bool> scalableVecDims;
+
/// Holds the active masks for permutations of the canonical vector iteration
/// space.
DenseMap<AffineMap, Value> activeMaskCache;
@@ -268,7 +318,8 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
- ArrayRef<int64_t> inputVectorSizes) {
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims) {
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);
@@ -277,15 +328,22 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
// path should be taken to vectorize code with dynamic shapes and when using
// vector sizes greater than the iteration space sizes.
canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+ scalableVecDims.append(inputScalableVecDims.begin(),
+ inputScalableVecDims.end());
} else {
// Compute the canonical vector shape from the operation shape. If there are
- // dynamic shapes, the operation won't be vectorized.
+ // dynamic shapes, the operation won't be vectorized. We assume all the
+ // vector dimensions are fixed.
canonicalVecShape = linalgOp.getStaticLoopRanges();
+ scalableVecDims.append(linalgOp.getNumLoops(), false);
}
LDBG("Canonical vector shape: ");
LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG("Scalable vector dims: ");
+ LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
if (ShapedType::isDynamicShape(canonicalVecShape))
return failure();
@@ -343,9 +401,10 @@ Value VectorizationState::getOrCreateMaskFor(
// TODO: Improve this check. Only projected permutation indexing maps are
// supported.
SmallVector<int64_t> permutedStaticSizes =
- applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
- SmallVector<int64_t> maskShape =
- applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));
+ applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
+ auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
+ auto maskShape = maskType.getShape();
+
LDBG("Mask shape: ");
LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
@@ -362,8 +421,7 @@ Value VectorizationState::getOrCreateMaskFor(
assert(!maskShape.empty() && !upperBounds.empty() &&
"Masked 0-d vectors are not supported yet");
- // Create the mask based on the dimension size values.
- auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
+ // Create the mask based on the dimension values.
Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
maskType, upperBounds);
LDBG("Creating new mask: " << mask << "\n");
@@ -504,18 +562,16 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
/// Broadcast `value` to a vector of `shape` if possible. Return value
/// otherwise.
-static Value broadcastIfNeeded(OpBuilder &b, Value value,
- ArrayRef<int64_t> shape) {
+static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
+ auto dstVecType = dyn_cast<VectorType>(dstType);
// If no shape to broadcast to, just return `value`.
- if (shape.empty())
+ if (dstVecType.getRank() == 0)
return value;
- VectorType targetVectorType =
- VectorType::get(shape, getElementTypeOrSelf(value));
- if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+ if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
vector::BroadcastableToResult::Success)
return value;
Location loc = b.getInsertionPoint()->getLoc();
- return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
+ return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -549,16 +605,15 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
- auto vectorType =
- VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
- getElementTypeOrSelf(outputOperand->get().getType()));
+ auto vectorType = state.getCanonicalVecType(
+ getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
- value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
+ value = broadcastIfNeeded(rewriter, value, vectorType);
write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), indices, writeMap);
} else {
@@ -639,10 +694,10 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
- auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
+ auto targetShape = state.getCanonicalVecShape();
// Compute a one-dimensional index vector for the index op dimension.
- SmallVector<int64_t> constantSeq =
- llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
+ auto constantSeq =
+ llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
auto indexSteps = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexVectorAttr(constantSeq));
// Return the one-dimensional index vector if it lives in the trailing
@@ -653,9 +708,15 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
- std::swap(targetShape[indexOp.getDim()], targetShape.back());
+ auto permPattern =
+ llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
+ std::swap(permPattern[indexOp.getDim()], permPattern.back());
+ auto permMap =
+ AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
+
auto broadCastOp = rewriter.create<vector::BroadcastOp>(
- loc, VectorType::get(targetShape, rewriter.getIndexType()), indexSteps);
+ loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
+ indexSteps);
SmallVector<int64_t> transposition =
llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
std::swap(transposition.back(), transposition[indexOp.getDim()]);
@@ -698,15 +759,15 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
static Value calculateGatherOffset(RewriterBase &rewriter,
+ VectorizationState &state,
tensor::ExtractOp extractOp,
- const IRMapping &bvm,
- const ArrayRef<int64_t> targetShape) {
- // The vector of indices for GatherOp should be shaped as the output vector
- auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
+ const IRMapping &bvm) {
+ // The vector of indices for GatherOp should be shaped as the output vector.
+ auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
auto loc = extractOp.getLoc();
Value offset = broadcastIfNeeded(
- rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
+ rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
@@ -715,13 +776,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
auto dimSize = broadcastIfNeeded(
rewriter,
rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
- indexVecType.getShape());
+ indexVecType);
offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
- auto extractOpIndex =
- broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
- indexVecType.getShape());
+ auto extractOpIndex = broadcastIfNeeded(
+ rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
}
@@ -935,14 +995,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto loc = extractOp.getLoc();
// Compute the static loop sizes of the extract op.
- auto targetShape = state.getCanonicalVecShape();
-
- auto resultType =
- VectorType::get(targetShape, extractOp.getResult().getType());
+ auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
auto maskConstantOp = rewriter.create<arith::ConstantOp>(
- loc, DenseIntElementsAttr::get(
- VectorType::get(targetShape, rewriter.getI1Type()),
- /*value=*/true));
+ loc,
+ DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
+ /*value=*/true));
auto passThruConstantOp =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
@@ -957,7 +1014,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
- Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+ Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
// Generate the gather load
Operation *gatherOp = rewriter.create<vector::GatherOp>(
@@ -1090,8 +1147,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
-vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
- const IRMapping &bvm,
+vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
+ LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LDBG("vectorize op " << *op << "\n");
@@ -1139,33 +1196,41 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
}
// 5. Generic vectorization path for ElementwiseMappable ops.
- // a. first get the first max ranked shape.
- SmallVector<int64_t, 4> firstMaxRankedShape;
+ // a. Get the first max ranked shape.
+ VectorType firstMaxRankedType;
for (Value operand : op->getOperands()) {
- auto vt = dyn_cast<VectorType>(bvm.lookup(operand).getType());
- if (vt && firstMaxRankedShape.size() < vt.getShape().size())
- firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
- }
- // rewriter. broadcast each op if needed.
- auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
- return firstMaxRankedShape.empty()
- ? bvm.lookup(v)
- : broadcastIfNeeded(rewriter, bvm.lookup(v),
- firstMaxRankedShape);
- });
+ auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
+ if (vecType && (!firstMaxRankedType ||
+ firstMaxRankedType.getRank() < vecType.getRank()))
+ firstMaxRankedType = vecType;
+ }
+ // b. Broadcast each op if needed.
+ SmallVector<Value> vectorizedOperands;
+ for (Value scalarOperand : op->getOperands()) {
+ Value vectorizedOperand = bvm.lookup(scalarOperand);
+ auto vecType =
+ VectorType::get(firstMaxRankedType.getShape(),
+ getElementTypeOrSelf(vectorizedOperand.getType()),
+ firstMaxRankedType.getNumScalableDims());
+ vectorizedOperands.push_back(
+ !firstMaxRankedType
+ ? vectorizedOperand
+ : broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
+ }
// c. for elementwise, the result is the vector with the firstMaxRankedShape
- auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
- return firstMaxRankedShape.empty()
- ? t
- : VectorType::get(firstMaxRankedShape, t);
- });
-
- // Build and return the new op.
+ SmallVector<Type> resultTypes;
+ for (Type resultType : op->getResultTypes()) {
+ resultTypes.push_back(
+ !firstMaxRankedType
+ ? resultType
+ : VectorType::get(firstMaxRankedType.getShape(), resultType,
+ firstMaxRankedType.getNumScalableDims()));
+ }
+ // d. Build and return the new op.
return VectorizationResult{
VectorizationStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- llvm::to_vector<4>(vectorizedOperands),
- llvm::to_vector<4>(returnTypes), op->getAttrs())};
+ vectorizedOperands, resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
@@ -1232,22 +1297,21 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
AffineMap maskingMap = indexingMap.dropResults(zeroPos);
AffineMap readMap;
- SmallVector<int64_t> readVecShape;
+ VectorType readType;
+ Type elemType = getElementTypeOrSelf(opOperand->get());
if (linalgOp.isDpsInput(opOperand)) {
// 3.a.i. For input reads we use the canonical vector shape.
readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
- readVecShape = llvm::to_vector(state.getCanonicalVecShape());
+ readType = state.getCanonicalVecType(elemType);
} else {
// 3.a.ii. For output reads (iteration-carried dependence, e.g.,
// reductions), the vector shape is computed by mapping the canonical
// vector shape to the output domain and back to the canonical domain.
readMap = inversePermutation(reindexIndexingMap(indexingMap));
- readVecShape =
- readMap.compose(indexingMap.compose(state.getCanonicalVecShape()));
+ readType =
+ state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
}
- auto readType =
- VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
@@ -1265,7 +1329,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
- if (cast<VectorType>(readValue.getType()).getRank() == 0)
+ if (readType.getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
@@ -1299,7 +1363,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
VectorizationResult result =
- vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
+ vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
return failure();
@@ -1526,10 +1590,38 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}
-LogicalResult
-mlir::linalg::vectorizeOpPrecondition(Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- bool vectorizeNDExtract) {
+/// Preconditions for scalable vectors.
+static LogicalResult
+vectorizeScalableVectorPrecondition(Operation *op,
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims) {
+ assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
+ "Number of input vector sizes and scalable dims doesn't match");
+
+ if (inputVectorSizes.empty())
+ return success();
+
+ if (!areValidScalableVecDims(inputScalableVecDims)) {
+ LDBG("Non-trailing scalable vector dimensions are not supported\n");
+ return failure();
+ }
+
+ bool isScalable = inputScalableVecDims.back();
+ if (!isScalable)
+ return success();
+
+ // Only element-wise ops supported in the presence of scalable dims.
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ return success(linalgOp && isElementwise(linalgOp));
+}
+
+LogicalResult mlir::linalg::vectorizeOpPrecondition(
+ Operation *op, ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+ if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
+ inputScalableVecDims)))
+ return failure();
+
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
@@ -1564,19 +1656,18 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
/// operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
- bool vectorizeNDExtract,
- bool lastVectorSizeScalable) {
+ ArrayRef<bool> inputScalableVecDims,
+ bool vectorizeNDExtract) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
- LDBG("Scalable vectorisation: " << lastVectorSizeScalable << "\n");
-
- if (lastVectorSizeScalable)
- op->emitWarning("Scalable vectorization is not supported yet");
+ LDBG("Input scalable vector dims: ");
+ LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
- if (failed(
- vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) {
+ if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
+ vectorizeNDExtract))) {
LDBG("Vectorization pre-conditions failed\n");
return failure();
}
@@ -1584,7 +1675,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
// Initialize vectorization state.
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
- if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+ if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
+ inputScalableVecDims))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index acccd66f7c03f..a3220ef85b6f9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -346,7 +346,8 @@ LogicalResult MultiDimReductionOp::verify() {
Type MultiDimReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1));
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getNumScalableDims());
}
namespace {
@@ -483,8 +484,9 @@ void ReductionOp::print(OpAsmPrinter &p) {
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return vecType.cloneWith(std::nullopt,
- IntegerType::get(vecType.getContext(), /*width=*/1));
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getNumScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -926,6 +928,10 @@ Type ContractionOp::getExpectedMaskType() {
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
+ // TODO: Extend the scalable vector type representation with a bit map.
+ assert(lhsType.getNumScalableDims() == 0 &&
+ rhsType.getNumScalableDims() == 0 &&
+ "Scalable vectors are not supported yet");
return VectorType::get(maskShape,
IntegerType::get(lhsType.getContext(), /*width=*/1));
@@ -2856,7 +2862,8 @@ LogicalResult OuterProductOp::verify() {
Type OuterProductOp::getExpectedMaskType() {
auto vecType = this->getResultVectorType();
return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1));
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getNumScalableDims());
}
//===----------------------------------------------------------------------===//
@@ -3509,9 +3516,12 @@ static VectorType inferTransferOpMaskType(VectorType vecType,
AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
+ // TODO: Extend the scalable vector type representation with a bit map.
+ assert((permMap.isMinorIdentity() || vecType.getNumScalableDims() == 0) &&
+ "Scalable vectors are not supported yet");
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
- return VectorType::get(maskShape, i1Type);
+ return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims());
}
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4470,7 +4480,8 @@ LogicalResult GatherOp::verify() {
Type GatherOp::getExpectedMaskType() {
auto vecType = this->getIndexVectorType();
return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1));
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getNumScalableDims());
}
std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
index 1b1202532ca27..985dd054c25eb 100644
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
%arg1: tensor<?xf32>,
@@ -485,17 +485,3 @@ transform.sequence failures(propagate) {
transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
}
-// -----
-
-func.func @vectorize_dynamic_matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
- // expected-warning @+1 {{Scalable vectorization is not supported yet}}
- linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
- outs(%C: memref<?x?xf32>)
- return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !transform.any_op):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]] : !transform.any_op
-}
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
new file mode 100644
index 0000000000000..957313b43d4b3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
+ %arg1: tensor<?xf32>,
+ %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"] }
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_identity
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<[4]xi1>
+// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<[4]xf32>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [[4]] : !transform.any_op
+}
+
+// -----
+
+func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>,
+ %arg1: tensor<8x?xf32>,
+ %arg2: tensor<8x?xf32>) -> tensor<8x?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<8x?xf32>, tensor<8x?xf32>)
+ outs(%arg2 : tensor<8x?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<8x?xf32>
+ return %0 : tensor<8x?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_partial_dynamic_identity(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x?xf32>, %[[VAL_1:.*]]: tensor<8x?xf32>, %[[VAL_2:.*]]: tensor<8x?xf32>) -> tensor<8x?xf32> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_4]] : vector<8x[32]xi1>
+// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_6]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_1]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_2]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_12]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] : vector<8x[32]xf32>
+// CHECK: %[[VAL_15:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_16:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write %[[VAL_14]], %[[VAL_2]][%[[VAL_15]], %[[VAL_15]]] {in_bounds = [true, true]} : vector<8x[32]xf32>, tensor<8x?xf32> } : vector<8x[32]xi1> -> tensor<8x?xf32>
+
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [8, [32]] : !transform.any_op
+}
+
+// -----
+
+func.func @vectorize_static_shape_with_mask(%arg0: tensor<8x30xf32>,
+ %arg1: tensor<8x30xf32>,
+ %arg2: tensor<8x30xf32>) -> tensor<8x30xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<8x30xf32>, tensor<8x30xf32>)
+ outs(%arg2 : tensor<8x30xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<8x30xf32>
+ return %0 : tensor<8x30xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_static_shape_with_mask(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x30xf32>, %[[VAL_1:.*]]: tensor<8x30xf32>, %[[VAL_2:.*]]: tensor<8x30xf32>) -> tensor<8x30xf32> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 30 : index
+// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_6]] : vector<8x[32]xi1>
+// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_0]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
+// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_1]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
+// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_2]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_11]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<8x[32]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %[[VAL_13]], %[[VAL_2]][%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true, true]} : vector<8x[32]xf32>, tensor<8x30xf32> } : vector<8x[32]xi1> -> tensor<8x30xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [8, [32]] : !transform.any_op
+}
+
+// -----
+
+func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
+ %0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_fill
+// CHECK: %[[DIM0:.*]] = tensor.dim
+// CHECK: %[[DIM1:.*]] = tensor.dim
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [8, [16]] : !transform.any_op
+}
+
More information about the Mlir-commits
mailing list