[Mlir-commits] [mlir] c1c4c8e - Revert "[mlir][vector] Migrate drop-lead-unit-dim to shape_cast #196206" (#199546)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 25 09:36:41 PDT 2026
Author: Omair Javaid
Date: 2026-05-25T21:36:36+05:00
New Revision: c1c4c8e23d099c199ea90b050742c3d6c5efcfaf
URL: https://github.com/llvm/llvm-project/commit/c1c4c8e23d099c199ea90b050742c3d6c5efcfaf
DIFF: https://github.com/llvm/llvm-project/commit/c1c4c8e23d099c199ea90b050742c3d6c5efcfaf.diff
LOG: Revert "[mlir][vector] Migrate drop-lead-unit-dim to shape_cast #196206" (#199546)
This reverts commit 24b8bb18f3417419cbd16fcd31f4e2842df952a1 from
#196206
This broke the AArch64 SVE Linux buildbots; however, it was not reported due to
a glitch in the buildbot infrastructure. The following bots are failing:
https://lab.llvm.org/buildbot/#/builders/121
https://lab.llvm.org/buildbot/#/builders/41
https://lab.llvm.org/buildbot/#/builders/4
https://lab.llvm.org/buildbot/#/builders/199
https://lab.llvm.org/buildbot/#/builders/17
https://lab.llvm.org/buildbot/#/builders/198
https://lab.llvm.org/buildbot/#/builders/143
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index aad42039300e3..26a702ef0f512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -7,9 +7,7 @@
//===----------------------------------------------------------------------===//
#include <numeric>
-#include <utility>
-#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -17,7 +15,6 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/STLExtras.h"
#define DEBUG_TYPE "vector-drop-unit-dim"
@@ -25,9 +22,9 @@
using namespace mlir;
using namespace mlir::vector;
-// Trims leading unit dimensions from `oldType` and returns the result type.
-static VectorType trimLeadingUnitDims(VectorType oldType,
- bool zeroDimsAllowed) {
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if `oldType` only has one element.
+static VectorType trimLeadingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
ArrayRef<int64_t> newShape = oldShape;
@@ -40,117 +37,22 @@ static VectorType trimLeadingUnitDims(VectorType oldType,
newScalableDims = newScalableDims.drop_front(1);
}
- // Some vector ops forbid 0-D vectors.
- if (!zeroDimsAllowed && newShape.empty()) {
+ // Make sure we have at least 1 dimension per vector type requirements.
+ if (newShape.empty()) {
newShape = oldShape.take_back();
newScalableDims = oldType.getScalableDims().take_back();
}
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
-static bool isNonScalableUnitDim(VectorType type, int64_t dim) {
- assert(dim >= 0 && dim < type.getRank() &&
- "expected a valid vector dimension");
- return type.getShape()[dim] == 1 && !type.getScalableDims()[dim];
+/// Return a smallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+ return SmallVector<int64_t>(rank, 0);
}
-
-/// Returns true if the first `k` dimensions of `type` are non-scalable unit
-/// dimensions.
-static bool areLeadingDimsUnit(VectorType type, int64_t k) {
- assert(k >= 0 && k <= type.getRank() &&
- "expected a valid leading dimension count");
- return llvm::all_of(llvm::seq<int64_t>(0, k), [&](int64_t dim) {
- return isNonScalableUnitDim(type, dim);
- });
-}
-
-static bool areLeadingDimsUnitAfterPermutation(VectorType type,
- ArrayRef<int64_t> permutation,
- int64_t k) {
- assert(k >= 0 && k <= static_cast<int64_t>(permutation.size()) &&
- "expected a valid leading dimension count");
- return llvm::all_of(permutation.take_front(k), [&](int64_t dim) {
- return isNonScalableUnitDim(type, dim);
- });
-}
-
-/// Shape-casts `operand` to the vector type obtained by dropping dimension
-/// `dim`, which must be non-scalable and unit-sized.
-static Value dropUnitDim(OpBuilder &b, Location loc, Value operand,
- int64_t dimToDrop, bool zeroDimsAllowed) {
- auto oldType = cast<VectorType>(operand.getType());
- assert(isNonScalableUnitDim(oldType, dimToDrop) &&
- "expected a non-scalable unit dim to drop");
- int64_t rank = oldType.getRank();
- assert((zeroDimsAllowed || rank > 1) &&
- "target op does not allow 0-D vectors");
-
- SmallVector<int64_t> newShape;
- SmallVector<bool> newScalableDims;
- newShape.reserve(rank - 1);
- newScalableDims.reserve(rank - 1);
- for (auto [i, size, scalable] :
- llvm::enumerate(oldType.getShape(), oldType.getScalableDims())) {
- if (static_cast<int64_t>(i) == dimToDrop)
- continue;
- newShape.push_back(size);
- newScalableDims.push_back(scalable);
- }
-
- return b.createOrFold<vector::ShapeCastOp>(
- loc, VectorType::get(newShape, oldType.getElementType(), newScalableDims),
- operand);
-}
-
-/// Shape-casts `operand` to the vector type obtained by dropping the first
-/// `k` non-scalable unit dimensions.
-static Value dropLeadingUnitDims(OpBuilder &b, Location loc, Value operand,
- int64_t k, bool zeroDimsAllowed) {
- auto oldType = cast<VectorType>(operand.getType());
- assert(areLeadingDimsUnit(oldType, k) &&
- "expected non-scalable leading unit dims to drop");
- assert((zeroDimsAllowed || k < oldType.getRank()) &&
- "target op does not allow 0-D vectors");
- VectorType newType = VectorType::get(oldType.getShape().drop_front(k),
- oldType.getElementType(),
- oldType.getScalableDims().drop_front(k));
- return b.createOrFold<vector::ShapeCastOp>(loc, newType, operand);
-}
-
-/// Returns the vector type obtained by applying `permutation` to `type`.
-static VectorType permuteVectorType(VectorType type,
- ArrayRef<int64_t> permutation) {
- assert(static_cast<int64_t>(permutation.size()) == type.getRank() &&
- "expected a permutation matching the operand rank");
- SmallVector<int64_t> permutedShape =
- applyPermutation(type.getShape(), permutation);
- SmallVector<bool> permutedScalableDims =
- applyPermutation(type.getScalableDims(), permutation);
- return VectorType::get(permutedShape, type.getElementType(),
- permutedScalableDims);
-}
-
-/// Like `dropLeadingUnitDims` except that if all dimensions would be dropped,
-/// the single element inside that vector is extracted and returned.
-static Value dropLeadingUnitDims0DIsScalar(OpBuilder &b, Location loc,
- Value operand, int64_t k) {
- auto oldType = cast<VectorType>(operand.getType());
- assert(areLeadingDimsUnit(oldType, k) &&
- "expected non-scalable leading unit dims to drop");
-
- if (k == oldType.getRank()) {
- SmallVector<int64_t> zeros(k, static_cast<int64_t>(0));
- return vector::ExtractOp::create(b, loc, operand, zeros);
- }
-
- return dropLeadingUnitDims(b, loc, operand, k,
- /*zeroDimsAllowed=*/true);
-}
-
namespace {
// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
+// input by inserting vector.broadcast.
struct CastAwayExtractStridedSliceLeadingOneDim
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
using Base::Base;
@@ -161,8 +63,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
// the same rank. Here we drop leading one dimensions from the input vector
// type to make sure we don't cause mismatch.
VectorType oldSrcType = extractOp.getSourceVectorType();
- VectorType newSrcType =
- trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false);
+ VectorType newSrcType = trimLeadingOneDims(oldSrcType);
if (newSrcType.getRank() == oldSrcType.getRank())
return failure();
@@ -177,8 +78,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
Location loc = extractOp.getLoc();
- Value newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, newSrcType, extractOp.getSource());
+ Value newSrcVector = vector::ExtractOp::create(
+ rewriter, loc, extractOp.getSource(), splatZero(dropCount));
// The offsets/sizes/strides attribute can have a less number of elements
// than the input vector's rank: it is meant for the leading dimensions.
@@ -193,7 +94,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
newStrides);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
newExtractOp);
return success();
@@ -201,7 +102,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
};
// Casts away leading one dimensions in vector.insert_strided_slice's vector
-// inputs by inserting vector.shape_cast.
+// inputs by inserting vector.broadcast.
struct CastAwayInsertStridedSliceLeadingOneDim
: public OpRewritePattern<vector::InsertStridedSliceOp> {
using Base::Base;
@@ -209,11 +110,9 @@ struct CastAwayInsertStridedSliceLeadingOneDim
LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
PatternRewriter &rewriter) const override {
VectorType oldSrcType = insertOp.getSourceVectorType();
- VectorType newSrcType =
- trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false);
+ VectorType newSrcType = trimLeadingOneDims(oldSrcType);
VectorType oldDstType = insertOp.getDestVectorType();
- VectorType newDstType =
- trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/false);
+ VectorType newDstType = trimLeadingOneDims(oldDstType);
int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
@@ -223,10 +122,10 @@ struct CastAwayInsertStridedSliceLeadingOneDim
// Trim leading one dimensions from both operands.
Location loc = insertOp.getLoc();
- Value newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, newSrcType, insertOp.getValueToStore());
- Value newDstVector = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, newDstType, insertOp.getDest());
+ Value newSrcVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+ Value newDstVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
auto newOffsets = rewriter.getArrayAttr(
insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
@@ -237,7 +136,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
newStrides);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
return success();
@@ -245,7 +144,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
};
// Casts away leading one dimensions in vector.insert's vector inputs by
-// inserting vector.shape_cast.
+// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
using Base::Base;
@@ -255,14 +154,13 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
Type newSrcType = oldSrcType;
int64_t oldSrcRank = 0, newSrcRank = 0;
if (auto type = dyn_cast<VectorType>(oldSrcType)) {
- newSrcType = trimLeadingUnitDims(type, /*zeroDimsAllowed=*/false);
+ newSrcType = trimLeadingOneDims(type);
oldSrcRank = type.getRank();
newSrcRank = cast<VectorType>(newSrcType).getRank();
}
VectorType oldDstType = insertOp.getDestVectorType();
- VectorType newDstType =
- trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/oldSrcRank == 0);
+ VectorType newDstType = trimLeadingOneDims(oldDstType);
int64_t srcDropCount = oldSrcRank - newSrcRank;
int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
@@ -273,11 +171,12 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
Location loc = insertOp.getLoc();
Value newSrcVector = insertOp.getValueToStore();
- if (oldSrcRank != 0)
- newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, cast<VectorType>(newSrcType), insertOp.getValueToStore());
- Value newDstVector = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, newDstType, insertOp.getDest());
+ if (oldSrcRank != 0) {
+ newSrcVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+ }
+ Value newDstVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
// New position rank needs to be computed in two steps: (1) if destination
// type has leading unit dims, we also trim the position array accordingly,
@@ -294,7 +193,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
newDstVector, newPosition);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
return success();
@@ -302,10 +201,20 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
};
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
- VectorType newType, AffineMap newMap) {
+ VectorType newType, AffineMap newMap,
+ VectorType oldMaskType) {
// Infer the type of the new mask from the new map.
VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
- return b.createOrFold<vector::ShapeCastOp>(loc, newMaskType, mask);
+
+ // If the new mask is broadcastable to the old result type, we can safely
+ // use a `vector.extract` to get the new mask. Otherwise the best we can
+ // do is shape cast.
+ if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+ BroadcastableToResult::Success) {
+ int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+ return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
+ }
+ return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
}
// Turns vector.transfer_read on vector with leading 1 dimensions into
@@ -320,7 +229,7 @@ struct CastAwayTransferReadLeadingOneDim
// TODO(#78787): Not supported masked op yet.
if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
return failure();
- // Nothing to trim when the transfer itself has rank zero.
+ // TODO: support 0-d corner case.
if (read.getTransferRank() == 0)
return failure();
@@ -329,7 +238,7 @@ struct CastAwayTransferReadLeadingOneDim
return failure();
VectorType oldType = read.getVectorType();
- VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true);
+ VectorType newType = trimLeadingOneDims(oldType);
if (newType == oldType)
return failure();
@@ -347,14 +256,16 @@ struct CastAwayTransferReadLeadingOneDim
read.getInBoundsAttr().getValue().take_back(newType.getRank()));
Value mask = Value();
- if (read.getMask())
+ if (read.getMask()) {
+ VectorType maskType = read.getMaskType();
mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
- newType, newMap);
+ newType, newMap, maskType);
+ }
auto newRead = vector::TransferReadOp::create(
rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
}
@@ -372,7 +283,7 @@ struct CastAwayTransferWriteLeadingOneDim
// TODO(#78787): Not supported masked op yet.
if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
return failure();
- // Nothing to trim when the transfer itself has rank zero.
+ // TODO: support 0-d corner case.
if (write.getTransferRank() == 0)
return failure();
@@ -381,9 +292,11 @@ struct CastAwayTransferWriteLeadingOneDim
return failure();
VectorType oldType = write.getVectorType();
- VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true);
+ VectorType newType = trimLeadingOneDims(oldType);
if (newType == oldType)
return failure();
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+
AffineMap oldMap = write.getPermutationMap();
ArrayRef<AffineExpr> newResults =
oldMap.getResults().take_back(newType.getRank());
@@ -396,12 +309,13 @@ struct CastAwayTransferWriteLeadingOneDim
inBoundsAttr = rewriter.getArrayAttr(
write.getInBoundsAttr().getValue().take_back(newType.getRank()));
- auto newVector = rewriter.createOrFold<vector::ShapeCastOp>(
- write.getLoc(), newType, write.getVector());
+ auto newVector = vector::ExtractOp::create(
+ rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));
if (write.getMask()) {
- Value newMask = dropUnitDimsFromMask(rewriter, write.getLoc(),
- write.getMask(), newType, newMap);
+ VectorType maskType = write.getMaskType();
+ Value newMask = dropUnitDimsFromMask(
+ rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getBase(), write.getIndices(),
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
@@ -417,15 +331,6 @@ struct CastAwayTransferWriteLeadingOneDim
} // namespace
-namespace {
-struct VectorContractOperandCastPlan {
- AffineMap map;
- SmallVector<int64_t> permutation;
- bool dropLeadingUnitDim = false;
- bool permuteOperand = false;
-};
-} // namespace
-
FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
@@ -435,7 +340,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
return failure();
if (oldAccType.getRank() < 1)
return failure();
- if (!isNonScalableUnitDim(oldAccType, 0))
+ if (oldAccType.getShape()[0] != 1)
return failure();
// currently we support only dropping one dim but the pattern can be applied
// greedily to drop more.
@@ -462,70 +367,74 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
contractOp.getAcc()};
- SmallVector<VectorContractOperandCastPlan> operandCastPlans;
SmallVector<Value> newOperands;
auto loc = contractOp.getLoc();
- if (maskingOp) {
- auto oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
- if (oldMaskType.getRank() <= 1 || dimToDrop >= oldMaskType.getRank() ||
- !isNonScalableUnitDim(oldMaskType, dimToDrop))
- return failure();
- }
-
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
// Check if the dim to be dropped exists as a leading dim in the operand
- // if it does then we use vector.shape_cast to drop it.
- VectorContractOperandCastPlan plan;
+ // if it does then we use vector.extract to drop it.
+ bool validExtract = false;
SmallVector<AffineExpr> results;
- plan.map = it.value();
- int64_t originalZeroDim = plan.map.getDimPosition(0);
- if (originalZeroDim != dimToDrop) {
+ auto map = it.value();
+ int64_t orginalZeroDim = it.value().getDimPosition(0);
+ if (orginalZeroDim != dimToDrop) {
// There are two reasons to be in this path, 1. We need to
- // permute the operand type to make the dim to be dropped
+ // transpose the operand to make the dim to be dropped
// leading. 2. The dim to be dropped does not exist and in
- // that case we dont want to add a unit permutation but we must
+ // that case we dont want to add a unit transpose but we must
// check all the indices to make sure this is the case.
- SmallVector<AffineExpr> permutedResults;
+ bool transposeNeeded = false;
+ SmallVector<int64_t> perm;
+ SmallVector<AffineExpr> transposeResults;
- for (int64_t i = 0, e = plan.map.getNumResults(); i < e; ++i) {
- int64_t currDim = plan.map.getDimPosition(i);
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t currDim = map.getDimPosition(i);
if (currDim == dimToDrop) {
- plan.permuteOperand = true;
- plan.permutation.insert(plan.permutation.begin(), i);
+ transposeNeeded = true;
+ perm.insert(perm.begin(), i);
auto targetExpr = rewriter.getAffineDimExpr(currDim);
- permutedResults.insert(permutedResults.begin(), targetExpr);
+ transposeResults.insert(transposeResults.begin(), targetExpr);
} else {
- plan.permutation.push_back(i);
+ perm.push_back(i);
auto targetExpr = rewriter.getAffineDimExpr(currDim);
- permutedResults.push_back(targetExpr);
+ transposeResults.push_back(targetExpr);
}
}
- // Update the map now so that the later shape_cast drops the correct dim.
- if (plan.permuteOperand) {
- plan.map = AffineMap::get(plan.map.getNumDims(), 0, permutedResults,
- contractOp.getContext());
- if (plan.map.getDimPosition(0) == dimToDrop) {
- auto operandType = cast<VectorType>(operands[it.index()].getType());
- if (!areLeadingDimsUnitAfterPermutation(operandType, plan.permutation,
- dropDim))
- return failure();
+ // Checks if only the outer, unit dimensions (of size 1) are permuted.
+ // Such transposes do not materially effect the underlying vector and can
+ // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
+ bool transposeNonOuterUnitDims = false;
+ auto operandShape = cast<ShapedType>(operands[it.index()].getType());
+ for (auto [index, dim] :
+ llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+ if (dim != static_cast<int64_t>(index) &&
+ operandShape.getDimSize(index) != 1) {
+ transposeNonOuterUnitDims = true;
+ break;
+ }
+ }
+
+ // Do the transpose now if needed so that we can drop the
+ // correct dim using extract later.
+ if (transposeNeeded) {
+ map = AffineMap::get(map.getNumDims(), 0, transposeResults,
+ contractOp.getContext());
+ if (transposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ loc, operands[it.index()], perm);
}
}
}
// We have taken care to have the dim to be dropped be
// the leading dim. If its still not leading that means it
- // does not exist in this operand and hence we do not need a shape_cast.
- if (plan.map.getDimPosition(0) == dimToDrop)
- plan.dropLeadingUnitDim = true;
- if (plan.dropLeadingUnitDim && originalZeroDim == dimToDrop &&
- !areLeadingDimsUnit(cast<VectorType>(operands[it.index()].getType()),
- dropDim))
- return failure();
+ // does not exist in this operand and hence we do not need
+ // an extract.
+ if (map.getDimPosition(0) == dimToDrop)
+ validExtract = true;
- for (int64_t i = 0, e = plan.map.getNumResults(); i < e; ++i) {
- int64_t currDim = plan.map.getDimPosition(i);
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t currDim = map.getDimPosition(i);
if (currDim == dimToDrop)
// This is the dim we are dropping.
continue;
@@ -533,23 +442,15 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
currDim < dimToDrop ? currDim : currDim - 1);
results.push_back(targetExpr);
}
- newIndexingMaps.push_back(AffineMap::get(plan.map.getNumDims() - 1, 0,
- results, contractOp.getContext()));
- operandCastPlans.push_back(std::move(plan));
- }
-
- for (auto [plan, operand] : llvm::zip_equal(operandCastPlans, operands)) {
- Value newOperand = operand;
- if (plan.permuteOperand)
- newOperand = rewriter.createOrFold<vector::ShapeCastOp>(
- loc,
- permuteVectorType(cast<VectorType>(newOperand.getType()),
- plan.permutation),
- newOperand);
- if (plan.dropLeadingUnitDim)
- newOperand =
- dropLeadingUnitDims0DIsScalar(rewriter, loc, newOperand, dropDim);
- newOperands.push_back(newOperand);
+ newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
+ contractOp.getContext()));
+ // Extract if its a valid extraction, otherwise use the operand
+ // without extraction.
+ newOperands.push_back(validExtract
+ ? vector::ExtractOp::create(rewriter, loc,
+ operands[it.index()],
+ splatZero(dropDim))
+ : operands[it.index()]);
}
// Depending on whether this vector.contract is masked, the replacing Op
@@ -560,19 +461,13 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
if (maskingOp) {
- Value newMask = dropUnitDim(rewriter, loc, maskingOp.getMask(), dimToDrop,
- /*zeroDimsAllowed=*/false);
+ auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
+ splatZero(dropDim));
newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
}
- if (!isa<VectorType>(newOp->getResults()[0].getType()))
- return vector::BroadcastOp::create(rewriter, loc,
- contractOp->getResultTypes()[0],
- newOp->getResults()[0])
- .getResult();
-
- return vector::ShapeCastOp::create(rewriter, loc,
+ return vector::BroadcastOp::create(rewriter, loc,
contractOp->getResultTypes()[0],
newOp->getResults()[0])
.getResult();
@@ -581,9 +476,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
namespace {
/// Turns vector.contract on vector with leading 1 dimensions into
-/// vector.shape_cast followed by vector.contract on vector without leading
-/// 1 dimensions. Non-leading unit dimensions are dropped via direct
-/// shape_casts.
+/// vector.extract followed by vector.contract on vector without leading
+/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
+/// prior to extract.
struct CastAwayContractionLeadingOneDim
: public MaskableOpRewritePattern<vector::ContractionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
@@ -598,15 +493,14 @@ struct CastAwayContractionLeadingOneDim
/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
-/// and casts away the leading one dimensions (_plural_) with shape_cast.
+/// and cast aways the leading one dimensions (_plural_) and then broadcasts
+/// the results.
///
/// Example before:
/// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
/// Example after:
-/// %2 = vector.shape_cast %arg0 : vector<1x4x1xf32> to vector<4x1xf32>
-/// %3 = vector.shape_cast %arg1 : vector<1x4x1xf32> to vector<4x1xf32>
-/// %4 = arith.mulf %2, %3 : vector<4x1xf32>
-/// %5 = vector.shape_cast %4 : vector<4x1xf32> to vector<1x4x1xf32>
+/// %2 = arith.mulf %0, %1 : vector<4x1xf32>
+/// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
///
/// Does support scalable vectors.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
@@ -622,34 +516,55 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
if (!vecType)
return failure();
- VectorType newVecType =
- trimLeadingUnitDims(vecType, /*zeroDimsAllowed=*/true);
+ VectorType newVecType = trimLeadingOneDims(vecType);
if (newVecType == vecType)
return failure();
+ int64_t dropDim = vecType.getRank() - newVecType.getRank();
SmallVector<Value, 4> newOperands;
for (Value operand : op->getOperands()) {
- if (auto opVecType = dyn_cast<VectorType>(operand.getType()))
- newOperands.push_back(rewriter.createOrFold<vector::ShapeCastOp>(
- op->getLoc(),
- trimLeadingUnitDims(opVecType, /*zeroDimsAllowed=*/true), operand));
- else
+ if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
+ newOperands.push_back(vector::ExtractOp::create(
+ rewriter, op->getLoc(), operand, splatZero(dropDim)));
+ } else {
newOperands.push_back(operand);
+ }
}
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
newOperands, newVecType, op->getAttrs());
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
newOp->getResult(0));
return success();
}
};
} // namespace
+// Drops `dropDim` leading dimensions from `operand` using vector.extract when
+// those dims are all non-scalable units (the cheap, structural rewrite); falls
+// back to vector.shape_cast otherwise.
+static Value dropLeadingOneDimsFromOperand(OpBuilder &b, Location loc,
+ Value operand, int64_t nDropped) {
+ auto oldType = cast<VectorType>(operand.getType());
+ ArrayRef<int64_t> leadingShape = oldType.getShape().take_front(nDropped);
+ ArrayRef<bool> leadingScalable =
+ oldType.getScalableDims().take_front(nDropped);
+ bool extractable =
+ llvm::all_of(leadingShape, [](int64_t d) { return d == 1; }) &&
+ llvm::none_of(leadingScalable, [](bool s) { return s; });
+ if (extractable)
+ return vector::ExtractOp::create(b, loc, operand, splatZero(nDropped));
+ VectorType newType = VectorType::get(
+ oldType.getShape().drop_front(nDropped), oldType.getElementType(),
+ oldType.getScalableDims().drop_front(nDropped));
+ return vector::ShapeCastOp::create(b, loc, newType, operand);
+}
+
namespace {
-// Drops leading unit dimensions from load-like memory operations by
-// shape_casting each vector operand and shape_casting the result back to the
-// original type.
+// Drops leading 1 dimensions from load-like memory operaitons. REmoves leading
+// unit dimensions from the result types and then broadcasts back in those 1s,
+// while also extracting (or shape_cast-ing) any leading unit dimensions on
+// the input operands.
template <typename OpTy>
struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -657,10 +572,7 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
VectorType oldResultType = op.getVectorType();
- constexpr bool zeroDimsAllowed =
- llvm::is_one_of<OpTy, vector::LoadOp>::value;
- VectorType newResultType =
- trimLeadingUnitDims(oldResultType, zeroDimsAllowed);
+ VectorType newResultType = trimLeadingOneDims(oldResultType);
if (newResultType == oldResultType)
return failure();
int64_t nDropped = oldResultType.getRank() - newResultType.getRank();
@@ -670,8 +582,8 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
newOperands.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
if (isa<VectorType>(operand.getType())) {
- newOperands.push_back(dropLeadingUnitDims(rewriter, loc, operand,
- nDropped, zeroDimsAllowed));
+ newOperands.push_back(
+ dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
} else {
newOperands.push_back(operand);
}
@@ -680,14 +592,15 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
Operation *newOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
TypeRange{newResultType}, op->getAttrs());
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, oldResultType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, oldResultType,
newOp->getResult(0));
return success();
}
};
-// Drops leading unit dimensions from store-like memory operations by
-// shape_casting each vector operand and leaving any scalar operands alone.
+// Drops leading 1 dimensions from store-like memory ops. Extracts or
+// `shape_cast`s away those leading unit dimensions and leaves any scalar
+// operands alone.
template <typename OpTy>
struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -695,9 +608,7 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
VectorType oldVecType = op.getVectorType();
- constexpr bool zeroDimsAllowed =
- llvm::is_one_of<OpTy, vector::StoreOp>::value;
- VectorType newVecType = trimLeadingUnitDims(oldVecType, zeroDimsAllowed);
+ VectorType newVecType = trimLeadingOneDims(oldVecType);
if (newVecType == oldVecType)
return failure();
int64_t nDropped = oldVecType.getRank() - newVecType.getRank();
@@ -707,8 +618,8 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
newOperands.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
if (isa<VectorType>(operand.getType())) {
- newOperands.push_back(dropLeadingUnitDims(rewriter, loc, operand,
- nDropped, zeroDimsAllowed));
+ newOperands.push_back(
+ dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
} else {
newOperands.push_back(operand);
}
@@ -722,8 +633,8 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
}
};
-// Drops leading 1 dimensions from vector.constant_mask and shape_casts back to
-// the original shape.
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
struct CastAwayConstantMaskLeadingOneDim
: public OpRewritePattern<vector::ConstantMaskOp> {
using Base::Base;
@@ -731,8 +642,7 @@ struct CastAwayConstantMaskLeadingOneDim
LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
PatternRewriter &rewriter) const override {
VectorType oldType = mask.getType();
- VectorType newType = trimLeadingUnitDims(oldType,
- /*zeroDimsAllowed=*/true);
+ VectorType newType = trimLeadingOneDims(oldType);
if (newType == oldType)
return failure();
@@ -740,22 +650,16 @@ struct CastAwayConstantMaskLeadingOneDim
int64_t dropDim = oldType.getRank() - newType.getRank();
ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
- // If any of the folded unit dims has a size of `0`, the entire leading
- // mask region is zero. Otherwise the folded unit dims have no effect on
- // the mask.
- SmallVector<int64_t> newDimSizes;
- if (newType.getRank() == 0) {
- newDimSizes.push_back(llvm::product_of(dimSizes));
- } else {
- int64_t flatLeadingSize =
- llvm::product_of(dimSizes.take_front(dropDim + 1));
- newDimSizes.push_back(flatLeadingSize);
- newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
- }
+ // If any of the dropped unit dims has a size of `0`, the entire mask is a
+ // zero mask, else the unit dim has no effect on the mask.
+ int64_t flatLeadingSize =
+ llvm::product_of(dimSizes.take_front(dropDim + 1));
+ SmallVector<int64_t> newDimSizes = {flatLeadingSize};
+ newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
newType, newDimSizes);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(mask, oldType, newMask);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2575c9e4a85b9..752610efc6992 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1931,12 +1931,12 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-// Helper function dropping unit non-scalable dimension from a VectorType.
-// Scalable unit dimensions are not dropped. Folding such dimensions would
-// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
-// vector<[1]x4xf32> -> vector<[4]xf32>).
-static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy,
- bool zeroDimsAllowed) {
+// Helper function dropping unit non-scalable dimension from a VectorType
+// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
+// dimensions are not dropped. Folding such dimensions would require "shifting"
+// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
+// vector<[4]xf32>). This could be implemented in the future.
+static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
auto inVecShape = inVecTy.getShape();
SmallVector<int64_t> newShape;
SmallVector<bool> newScalableDims;
@@ -1948,8 +1948,8 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy,
newShape.push_back(dim);
newScalableDims.push_back(isScalable);
}
- // Some vector ops forbid 0-D vectors.
- if (!zeroDimsAllowed && newShape.empty()) {
+ // All dims have been dropped, return vector<1xeType>.
+ if (newShape.empty()) {
newShape.push_back(1);
newScalableDims.push_back(false);
}
@@ -2000,12 +2000,14 @@ struct DropUnitDimFromElementwiseOps final
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
if (!sourceVectorType)
return failure();
+ if (sourceVectorType.getRank() < 2)
+ return failure();
+
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- auto newVType = dropNonScalableUnitDimFromType(opVectorType,
- /*zeroDimsAllowed=*/true);
+ auto newVType = dropNonScalableUnitDimFromType(opVectorType);
if (newVType == opVectorType)
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
@@ -2014,8 +2016,7 @@ struct DropUnitDimFromElementwiseOps final
}
VectorType newResultVectorType =
- dropNonScalableUnitDimFromType(resultVectorType,
- /*zeroDimsAllowed=*/true);
+ dropNonScalableUnitDimFromType(resultVectorType);
// Create an updated elementwise Op without unit dim.
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
@@ -2056,8 +2057,7 @@ struct DropUnitDimsFromTransposeOp final
PatternRewriter &rewriter) const override {
VectorType sourceType = op.getSourceVectorType();
VectorType sourceTypeWithoutUnitDims =
- dropNonScalableUnitDimFromType(sourceType,
- /*zeroDimsAllowed=*/true);
+ dropNonScalableUnitDimFromType(sourceType);
if (sourceType == sourceTypeWithoutUnitDims)
return failure();
@@ -2082,9 +2082,9 @@ struct DropUnitDimsFromTransposeOp final
}
// Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
- // type when the dimensions are unit dimensions and 0-D vectors are not
- // allowed. In this case, the newPerm should be [0].
- if (newPerm.empty() && sourceTypeWithoutUnitDims.getRank() > 0) {
+ // type when the dimensions are unit dimensions. In this case, the newPerm
+ // should be [0].
+ if (newPerm.empty()) {
newPerm.push_back(0);
}
@@ -2139,9 +2139,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
if (!vectorType)
continue;
- VectorType newVectorType =
- dropNonScalableUnitDimFromType(vectorType,
- /*zeroDimsAllowed=*/true);
+ VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
if (vectorType == newVectorType)
continue;
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index 4e800ab169bf6..34a155fbf2fc1 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -150,26 +150,10 @@ func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> {
// CHECK-LABEL: func.func @fold_all_unit_dims(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
-// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<f32>
-// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<f32>
-// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<f32>
-// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<f32> to vector<1xf32>
-// CHECK: return %[[VAL_4]] : vector<1xf32>
-
-// -----
-
-func.func @fold_rank1_unit_dim(%vec: vector<1xf32>) -> vector<1xf32> {
- %res = arith.addf %vec, %vec : vector<1xf32>
- return %res : vector<1xf32>
-}
-
-// CHECK-LABEL: func.func @fold_rank1_unit_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1xf32>
-// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
-// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
-// CHECK: %[[VAL_3:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : vector<f32>
-// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<f32> to vector<1xf32>
-// CHECK: return %[[VAL_4]] : vector<1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
+// CHECK: return %[[VAL_3]] : vector<1xf32>
///----------------------------------------------------------------------------------------
/// [Pattern: DropUnitDimsFromTransposeOp]
@@ -265,11 +249,11 @@ func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32>
// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<f32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
-// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<f32>
+// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
// CHECK: scf.yield %[[SQRT]]
-// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<f32> to vector<1x1xf32>
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[CASTBACK]]
// -----
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index cd1ecec455896..bf01c8a8589d9 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -5,13 +5,13 @@
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: cast_away_contraction_leading_one_dims
-// CHECK-NEXT: %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32>
-// CHECK-NEXT: %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-// CHECK-NEXT: %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32>
+// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
-// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK-NEXT: return %[[R4]] : vector<1x16x16xf32>
#contraction_accesses0 = [
@@ -36,14 +36,14 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
-// CHECK: %[[R0:.*]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32>
-// CHECK: %[[R1:.*]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-// CHECK: %[[R2:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32>
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
-// CHECK: %[[RES:.*]] = vector.shape_cast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK: return %[[RES]] : vector<1x16x16xf32>
#contraction_accesses0 = [
@@ -70,15 +70,15 @@ func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
-// CHECK: %[[R0:.*]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32>
-// CHECK: %[[R1:.*]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-// CHECK: %[[R2:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32>
-// CHECK: %[[M:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16x8xi1> to vector<16x16x8xi1>
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.shape_cast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>
#contraction_accesses0 = [
@@ -109,14 +109,15 @@ func.func @cast_away_contraction_leading_one_dim_under_mask(
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded
-// CHECK-NEXT: %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-// CHECK-NEXT: %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
-// CHECK-NEXT: %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x1x16xf32> to vector<16xf32>
+// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
+// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0, 0] : vector<16xf32> from vector<1x1x16xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>}
// CHECK-SAME: %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32>
-// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %[[R3]] : vector<16xf32> to vector<1x1x16xf32>
-// CHECK-NEXT: return %[[R4]] : vector<1x1x16xf32>
+// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32>
+// CHECK-NEXT: %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32>
+// CHECK-NEXT: return %[[R5]] : vector<1x1x16xf32>
#contraction_accesses1 = [
affine_map<(l, i, j, k) -> (i, l, k)>,
@@ -140,13 +141,15 @@ func.func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector<
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2
-// CHECK-NEXT: %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<8x1x16xf32> to vector<8x16xf32>
-// CHECK-NEXT: %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<2x8x1xf32> to vector<2x8xf32>
-// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %{{.*}} : vector<1x2x16xf32> to vector<2x16xf32>
+// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}[1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
+// CHECK-NEXT: %[[R1:.+]] = vector.extract %[[R0]][0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK-NEXT: %[[R2:.+]] = vector.transpose %{{.*}}[2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
+// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<2x8xf32> from vector<1x2x8xf32>
+// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0] : vector<2x16xf32> from vector<1x2x16xf32>
// CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
-// CHECK-NEXT: %[[R6:.+]] = vector.shape_cast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
+// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
// CHECK-NEXT: return %[[R6]] : vector<1x2x16xf32>
#contraction_accesses2 = [
@@ -172,14 +175,19 @@ func.func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector
// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4
-// CHECK-NEXT: %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<1x8x1x16xf32> to vector<8x16xf32>
-// CHECK-NEXT: %[[R5:.+]] = vector.shape_cast %{{.*}} : vector<1x2x8x1xf32> to vector<2x8xf32>
-// CHECK-NEXT: %[[R6:.+]] = vector.shape_cast %{{.*}} : vector<1x1x2x16xf32> to vector<2x16xf32>
+// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<8x1x16xf32> from vector<1x8x1x16xf32>
+// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<2x8x1xf32> from vector<1x2x8x1xf32>
+// CHECK-NEXT: %[[R2:.+]] = vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
+// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK-NEXT: %[[R4:.+]] = vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
+// CHECK-NEXT: %[[R5:.+]] = vector.extract %[[R4]][0] : vector<2x8xf32> from vector<1x2x8xf32>
+// CHECK-NEXT: %[[R6:.+]] = vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
// CHECK-NEXT: %[[R7:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
-// CHECK-NEXT: %[[R8:.+]] = vector.shape_cast %[[R7]] : vector<2x16xf32> to vector<1x1x2x16xf32>
-// CHECK-NEXT: return %[[R8]] : vector<1x1x2x16xf32>
+// CHECK-NEXT: %[[R8:.+]] = vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32>
+// CHECK-NEXT: %[[R9:.+]] = vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
+// CHECK-NEXT: return %[[R9]] : vector<1x1x2x16xf32>
#contraction_accesses2 = [
affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
@@ -203,14 +211,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0:
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose
-// CHECK-NEXT: %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x8x1x16xf32> to vector<8x16xf32>
-// CHECK-NEXT: %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<1x2x8x1xf32> to vector<2x8xf32>
-// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %{{.*}} : vector<1x1x2x16xf32> to vector<2x16xf32>
+// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32>
+// CHECK-NEXT: %[[R1:.+]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32>
+// CHECK-NEXT: %[[R2:.+]] = vector.extract %[[R0]][0, 0] : vector<8x16xf32> from vector<1x1x8x16xf32>
+// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R1]][0, 0] : vector<2x8xf32> from vector<1x1x2x8xf32>
+// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
// CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
-// CHECK-NEXT: %[[R6:.+]] = vector.shape_cast %[[R5]] : vector<2x16xf32> to vector<1x1x2x16xf32>
-// CHECK-NEXT: return %[[R6]] : vector<1x1x2x16xf32>
+// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
+// CHECK-NEXT: %[[R7:.+]] = vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
+// CHECK-NEXT: return %[[R7]] : vector<1x1x2x16xf32>
#contraction_accesses3 = [
affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
@@ -245,7 +256,7 @@ func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vect
// CHECK-DAG: #[[$map_dp1:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL: cast_away_contraction_leading_one_dims_to_dot_product
-// CHECK-NEXT: %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x64xf32> to vector<64xf32>
+// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<64xf32> from vector<1x64xf32>
// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : f32 from vector<1xf32>
// CHECK-NEXT: %[[R2:.+]] = vector.contract {indexing_maps = [#[[$map_dp0]], #[[$map_dp0]], #[[$map_dp1]]],
// CHECK-SAME: iterator_types = ["reduction"], kind = #vector.kind<add>}
@@ -259,96 +270,44 @@ func.func @cast_away_contraction_leading_one_dims_to_dot_product(%arg0: vector<6
}
// -----
-
-// CHECK-DAG: #[[$DOT_MAP:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$SCALAR_MAP:.*]] = affine_map<(d0) -> ()>
-
-// CHECK-LABEL: cast_away_masked_contraction_with_rank1_acc
-// CHECK-NEXT: %[[RHS:.+]] = vector.shape_cast %{{.*}} : vector<1x64xf32> to vector<64xf32>
-// CHECK-NEXT: %[[ACC:.+]] = vector.extract %{{.*}}[0] : f32 from vector<1xf32>
-// CHECK-NEXT: %[[MASK:.+]] = vector.shape_cast %{{.*}} : vector<64x1xi1> to vector<64xi1>
-// CHECK-NEXT: %[[DOT:.+]] = vector.mask %[[MASK]] {
-// CHECK-SAME: vector.contract {indexing_maps = [#[[$DOT_MAP]], #[[$DOT_MAP]], #[[$SCALAR_MAP]]], iterator_types = ["reduction"], kind = #vector.kind<add>}
-// CHECK-SAME: %{{.*}}, %[[RHS]], %[[ACC]] : vector<64xf32>, vector<64xf32> into f32
-// CHECK-SAME: } : vector<64xi1> -> f32
-// CHECK-NEXT: %[[RES:.+]] = vector.broadcast %[[DOT]] : f32 to vector<1xf32>
-// CHECK-NEXT: return %[[RES]] : vector<1xf32>
-
-func.func @cast_away_masked_contraction_with_rank1_acc(%arg0: vector<64xf32>, %arg1: vector<1x64xf32>, %arg2: vector<1xf32>, %mask: vector<64x1xi1>) -> vector<1xf32> {
- %0 = vector.mask %mask {
- vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64xf32>, vector<1x64xf32> into vector<1xf32>
- } : vector<64x1xi1> -> vector<1xf32>
- return %0 : vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: negative_cast_away_contraction_with_scalable_rank1_acc
-// CHECK-NOT: vector.shape_cast
-// CHECK-NOT: vector.extract
-// CHECK-NOT: vector.broadcast
-// CHECK-NEXT: vector.contract
-// CHECK-NEXT: return
-
-func.func @negative_cast_away_contraction_with_scalable_rank1_acc(%arg0: vector<64xf32>, %arg1: vector<[1]x64xf32>, %arg2: vector<[1]xf32>) -> vector<[1]xf32> {
- %0 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64xf32>, vector<[1]x64xf32> into vector<[1]xf32>
- return %0 : vector<[1]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: negative_cast_away_contraction_with_scalable_operand_dim
-// CHECK-NOT: vector.shape_cast
-// CHECK-NOT: vector.extract
-// CHECK-NOT: vector.broadcast
-// CHECK-NEXT: vector.contract
-// CHECK-NEXT: return
-
-func.func @negative_cast_away_contraction_with_scalable_operand_dim(%arg0: vector<64xf32>, %arg1: vector<[1]x64xf32>, %arg2: vector<1xf32>) -> vector<1xf32> {
- %0 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64xf32>, vector<[1]x64xf32> into vector<1xf32>
- return %0 : vector<1xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x1x8xf16>
}
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x[8]xf16> to vector<8x[8]xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
// CHECK: return %[[RET]]
return %0: vector<1x1x[8]xf16>
}
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
- // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
+ // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x8x8xf16>
}
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x[8]xf16> to vector<[8]xf16>
- // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x[8]xf16> to vector<8x[8]xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
+ // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
// CHECK: return %[[RET]]
return %0: vector<1x8x[8]xf16>
}
@@ -356,7 +315,8 @@ func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vecto
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
- // CHECK: %[[B:.+]] = vector.shape_cast %{{.*}} : vector<1x1xf16> to vector<1x1x1xf16>
+ // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1xf16> from vector<1x1xf16>
+ // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
// CHECK: return %[[B]]
return %0: vector<1x1x1xf16>
@@ -365,7 +325,8 @@ func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: ve
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable
// CHECK-SAME: %[[ARG0:.+]]: vector<1x[1]xf16>, %{{.+}}: vector<1x1x[1]xf16>
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable(%arg0: vector<1x[1]xf16>, %arg1: vector<1x1x[1]xf16>) -> vector<1x1x[1]xf16> {
- // CHECK: %[[B:.+]] = vector.shape_cast %{{.*}} : vector<1x[1]xf16> to vector<1x1x[1]xf16>
+ // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<[1]xf16> from vector<1x[1]xf16>
+ // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<[1]xf16> to vector<1x1x[1]xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[1]xf16> into vector<1x1x[1]xf16>
// CHECK: return %[[B]]
return %0: vector<1x1x[1]xf16>
@@ -378,7 +339,7 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
%f0 = arith.constant 0. : f16
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x4xf16>
@@ -390,9 +351,9 @@ func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x1
%c0 = arith.constant 0 : index
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
%f0 = arith.constant 0. : f16
- // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x4xf16>
@@ -402,7 +363,7 @@ func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x1
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0. : f16
- // CHECK: vector.shape_cast %{{.+}} : vector<f16> to vector<1x1xf16>
+ // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
return %0: vector<1x1xf16>
}
@@ -419,7 +380,7 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
// CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
// CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
// CHECK: return %[[CAST]]
@@ -430,7 +391,7 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
// CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
// CHECK: %[[MASK:.+]] = vector.constant_mask
-// CHECK: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]]
+// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
// CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
// CHECK: return %[[RET]] : vector<1x4xf16>
@@ -450,7 +411,7 @@ func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -461,8 +422,8 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
- // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -472,7 +433,7 @@ func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
- // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<f16>
+ // CHECK: vector.extract %{{.+}}[0] : vector<1xf16> from vector<1x1xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
return
}
@@ -481,7 +442,7 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
// CHECK: %[[MASK:.+]] = vector.constant_mask
-// CHECK: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]]
+// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
// CHECK: vector.mask %[[CASTED_MASK]] {
// CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
// CHECK: return
@@ -501,7 +462,7 @@ func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x1x4xf16> to vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
// CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
// CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
@@ -518,25 +479,25 @@ func.func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
%arg3: vector<1x4xf32>, %arg4: i1) ->
(vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
- // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
%0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
%1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
%2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: select %arg4, %12, %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
%3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}
@@ -545,10 +506,10 @@ func.func @cast_away_elementwise_leading_one_dims(
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
// CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
-// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1x4xf32> to vector<4xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[DST_CAST]] [0] : f32 into vector<4xf32>
-// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[RESULT_CAST]]
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
return %0: vector<1x1x4xf32>
@@ -556,27 +517,14 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
// -----
-// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar_0d_dest
-// CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1xf32>)
-// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1xf32> to vector<f32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[DST_CAST]] [] : f32 into vector<f32>
-// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<f32> to vector<1x1xf32>
-// CHECK: return %[[RESULT_CAST]]
-func.func @cast_away_insert_leading_one_dims_scalar_0d_dest(%s: f32, %v: vector<1x1xf32>) -> vector<1x1xf32> {
- %0 = vector.insert %s, %v [0, 0] : f32 into vector<1x1xf32>
- return %0: vector<1x1xf32>
-}
-
-// -----
-
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable(
// CHECK-SAME: %[[S:.*]]: f32,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
-// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x1x[4]xf32> to vector<[4]xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[DST_CAST]] [0] : f32 into vector<[4]xf32>
-// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
-// CHECK: return %[[RESULT_CAST]] : vector<1x1x[4]xf32>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<[4]xf32> from vector<1x1x[4]xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
return %0: vector<1x1x[4]xf32>
}
@@ -587,10 +535,10 @@ func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector
// CHECK-SAME: %[[S:.*]]: f32,
// CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
-// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x[1]x4xf32> to vector<[1]x4xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[DST_CAST]] [0, 0] : f32 into vector<[1]x4xf32>
-// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
-// CHECK: return %[[RESULT_CAST]] : vector<1x[1]x4xf32>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<[1]x4xf32> from vector<1x[1]x4xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x[1]x4xf32>
%0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
return %0: vector<1x[1]x4xf32>
}
@@ -599,8 +547,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
// CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
-// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[RESULT_CAST]]
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
%0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
return %0: vector<1x1x4xf32>
@@ -611,8 +559,8 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable(
// CHECK-SAME: %[[S:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
-// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
-// CHECK: return %[[RESULT_CAST]] : vector<1x1x[4]xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
%0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
return %0: vector<1x1x[4]xf32>
@@ -622,8 +570,9 @@ func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>,
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
-// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[SRC_CAST]]
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
%0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
return %0: vector<1x1x4xf32>
@@ -634,8 +583,9 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
-// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<1x1x[4]xf32>
-// CHECK: return %[[SRC_CAST]] : vector<1x1x[4]xf32>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
%0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
return %0: vector<1x1x[4]xf32>
@@ -645,11 +595,11 @@ func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
-// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<4xf32>
-// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x2x1x4xf32> to vector<2x1x4xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
-// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
-// CHECK: return %[[RESULT_CAST]]
+// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<2x1x4xf32> from vector<1x2x1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
+// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
%0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
return %0: vector<1x2x1x4xf32>
@@ -660,11 +610,11 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
-// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<[4]xf32>
-// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x2x1x[4]xf32> to vector<2x1x[4]xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
-// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
-// CHECK: return %[[RESULT_CAST]] : vector<1x2x1x[4]xf32>
+// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
+// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<2x1x[4]xf32> from vector<1x2x1x[4]xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
+// CHECK: return %[[BCAST]] : vector<1x2x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
%0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
return %0: vector<1x2x1x[4]xf32>
@@ -674,8 +624,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
-// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<4xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
// CHECK: return %[[INSERT]]
func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
%0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
@@ -687,8 +637,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
-// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<[4]xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
// CHECK: return %[[INSERT]] : vector<8x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
%0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
@@ -699,11 +649,11 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
-// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x8xi1> to vector<8xi1>
-// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1x8x1x8xi1> to vector<8x1x8xi1>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
-// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
-// CHECK: return %[[RESULT_CAST]]
+// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1>
+// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<8x1x8xi1> from vector<1x1x8x1x8xi1>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
+// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> {
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
return %0: vector<1x1x8x1x8xi1>
@@ -714,11 +664,11 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
-// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[8]xi1> to vector<[8]xi1>
-// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x1x8x1x[8]xi1> to vector<8x1x[8]xi1>
-// CHECK: %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
-// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
-// CHECK: return %[[RESULT_CAST]] : vector<1x1x8x1x[8]xi1>
+// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[8]xi1> from vector<1x[8]xi1>
+// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<8x1x[8]xi1> from vector<1x1x8x1x[8]xi1>
+// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
+// CHECK: return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
return %0: vector<1x1x8x1x[8]xi1>
@@ -728,8 +678,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
-// CHECK: %[[MASK_CAST:.*]] = vector.shape_cast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
-// CHECK: return %[[MASK_CAST]] : vector<1x1x8x2x1xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1>
func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
%0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
return %0: vector<1x1x8x2x1xi1>
@@ -737,16 +687,6 @@ func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
// -----
-// CHECK-LABEL: func.func @cast_away_constant_mask_all_unit_dims() -> vector<1x1xi1> {
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1xi1>
-// CHECK: return %[[MASK]] : vector<1x1xi1>
-func.func @cast_away_constant_mask_all_unit_dims() -> vector<1x1xi1> {
- %0 = vector.constant_mask [1, 1] : vector<1x1xi1>
- return %0: vector<1x1xi1>
-}
-
-// -----
-
// CHECK-LABEL: func.func @drop_unit_dims_scalar_cond_select(
// CHECK: arith.select {{.*}} : vector<16xi1>
func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> {
@@ -758,7 +698,7 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
// CHECK-LABEL: func.func @cast_away_load_leading_one_dims
// CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32>
-// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
// CHECK: return %[[B]] : vector<1x4xf32>
func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %j: index) -> vector<1x4xf32> {
%0 = vector.load %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
@@ -767,33 +707,11 @@ func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %
// -----
-// CHECK-LABEL: func.func @cast_away_load_all_unit_dims
-// CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}] : memref<1xf32>, vector<f32>
-// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<f32> to vector<1xf32>
-// CHECK: return %[[B]] : vector<1xf32>
-func.func @cast_away_load_all_unit_dims(%base: memref<1xf32>, %i: index) -> vector<1xf32> {
- %0 = vector.load %base[%i] : memref<1xf32>, vector<1xf32>
- return %0 : vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @cast_away_load_leading_one_dims_scalable
-// CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, vector<[4]xf32>
-// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<[4]xf32> to vector<1x[4]xf32>
-// CHECK: return %[[B]] : vector<1x[4]xf32>
-func.func @cast_away_load_leading_one_dims_scalable(%base: memref<?x?xf32>, %i: index, %j: index) -> vector<1x[4]xf32> {
- %0 = vector.load %base[%i, %j] : memref<?x?xf32>, vector<1x[4]xf32>
- return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: func.func @cast_away_maskedload_leading_one_dims
-// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK: %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: %[[L:.+]] = vector.maskedload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
// CHECK: return %[[B]] : vector<1x4xf32>
func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
%0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
@@ -803,10 +721,10 @@ func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: inde
// -----
// CHECK-LABEL: func.func @cast_away_expandload_leading_one_dims
-// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK: %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: %[[L:.+]] = vector.expandload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
// CHECK: return %[[B]] : vector<1x4xf32>
func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
%0 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
@@ -816,11 +734,11 @@ func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: inde
// -----
// CHECK-LABEL: func.func @cast_away_gather_leading_one_dims
-// CHECK: %[[I:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi32> to vector<4xi32>
-// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK: %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32>
+// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: %[[G:.+]] = vector.gather %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-// CHECK: %[[B:.+]] = vector.shape_cast %[[G]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[G]] : vector<4xf32> to vector<1x4xf32>
// CHECK: return %[[B]] : vector<1x4xf32>
func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
%0 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
@@ -830,7 +748,7 @@ func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %
// -----
// CHECK-LABEL: func.func @cast_away_store_leading_one_dims
-// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32>
func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref<8x16xf32>, %i: index, %j: index) {
vector.store %val, %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
@@ -839,29 +757,9 @@ func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref
// -----
-// CHECK-LABEL: func.func @cast_away_store_all_unit_dims
-// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<f32>
-// CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}] : memref<1xf32>, vector<f32>
-func.func @cast_away_store_all_unit_dims(%val: vector<1xf32>, %base: memref<1xf32>, %i: index) {
- vector.store %val, %base[%i] : memref<1xf32>, vector<1xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func.func @cast_away_store_leading_one_dims_scalable
-// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]xf32>
-// CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, vector<[4]xf32>
-func.func @cast_away_store_leading_one_dims_scalable(%val: vector<1x[4]xf32>, %base: memref<?x?xf32>, %i: index, %j: index) {
- vector.store %val, %base[%i, %j] : memref<?x?xf32>, vector<1x[4]xf32>
- return
-}
-
-// -----
-
// CHECK-LABEL: func.func @cast_away_maskedstore_leading_one_dims
-// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
@@ -871,8 +769,8 @@ func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: ind
// -----
// CHECK-LABEL: func.func @cast_away_compressstore_leading_one_dims
-// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: vector.compressstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
@@ -882,41 +780,11 @@ func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: i
// -----
// CHECK-LABEL: func.func @cast_away_scatter_leading_one_dims
-// CHECK: %[[I:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi32> to vector<4xi32>
-// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32>
+// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
// CHECK: vector.scatter %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
func.func @cast_away_scatter_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32>
return
}
-
-// -----
-
-// CHECK-LABEL: func.func @negative_cast_memory_ops_to_0d
-// CHECK-NOT: vector.shape_cast
-// CHECK: vector.maskedload {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-// CHECK-NOT: vector.shape_cast
-// CHECK: vector.expandload {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-// CHECK-NOT: vector.shape_cast
-// CHECK: vector.gather {{.*}} : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-// CHECK-NOT: vector.shape_cast
-// CHECK: vector.maskedstore {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32>
-// CHECK-NOT: vector.shape_cast
-// CHECK: vector.compressstore {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32>
-// CHECK-NOT: vector.shape_cast
-// CHECK: vector.scatter {{.*}} : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32>
-// CHECK-NOT: vector.shape_cast
-// CHECK: return
-func.func @negative_cast_memory_ops_to_0d(
- %base: memref<16xf32>, %i: index, %idx: vector<1xi32>,
- %mask: vector<1xi1>, %pass: vector<1xf32>, %val: vector<1xf32>)
- -> (vector<1xf32>, vector<1xf32>, vector<1xf32>) {
- %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
- %1 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
- %2 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
- vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1xi1>, vector<1xf32>
- vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1xi1>, vector<1xf32>
- vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32>
- return %0, %1, %2 : vector<1xf32>, vector<1xf32>, vector<1xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index d0d3a6c0bb976..de12a87253a67 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -36,7 +36,7 @@ func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) ->
// CHECK-LABEL: func.func @cast_away_leading_one_dim(
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
-// CHECK: vector.shape_cast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
+// CHECK: vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
%1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
return %1: vector<1x4x1xf32>
@@ -44,7 +44,7 @@ func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4
// CHECK-LABEL: func.func @cast_away_leading_one_dim_scalable(
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
-// CHECK: vector.shape_cast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
+// CHECK: vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
%1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
return %1: vector<1x[4]x1xf32>
@@ -277,15 +277,13 @@ func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
%0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
// CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : f32 from vector<4xf32>
- // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [] : f32 into vector<f32>
- // CHECK: %[[SHAPE_CAST1:.+]] = vector.shape_cast %[[INSERT1]] : vector<f32> to vector<1xf32>
- // CHECK: %[[CAST1:.+]] = vector.bitcast %[[SHAPE_CAST1]] : vector<1xf32> to vector<2xf16>
+ // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32>
+ // CHECK: %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16>
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : f16 from vector<2xf16>
%1 = vector.extract %0[3] : f16 from vector<8xf16>
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : f32 from vector<4xf32>
- // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [] : f32 into vector<f32>
- // CHECK: %[[SHAPE_CAST2:.+]] = vector.shape_cast %[[INSERT3]] : vector<f32> to vector<1xf32>
- // CHECK: %[[CAST2:.+]] = vector.bitcast %[[SHAPE_CAST2]] : vector<1xf32> to vector<2xf16>
+ // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32>
+ // CHECK: %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16>
// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : f16 from vector<2xf16>
%2 = vector.extract %0[4] : f16 from vector<8xf16>
// CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
More information about the Mlir-commits
mailing list