[Mlir-commits] [mlir] 186709c - [mlir] [VectorOps] Progressive lowering of vector.broadcast
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 16 21:02:43 PDT 2020
Author: aartbik
Date: 2020-04-16T21:02:27-07:00
New Revision: 186709c6e0bd1025fb578e43911325530cb97f09
URL: https://github.com/llvm/llvm-project/commit/186709c6e0bd1025fb578e43911325530cb97f09
DIFF: https://github.com/llvm/llvm-project/commit/186709c6e0bd1025fb578e43911325530cb97f09.diff
LOG: [mlir] [VectorOps] Progressive lowering of vector.broadcast
Summary:
Rather than having a full, recursive lowering of vector.broadcast
to LLVM IR, it is much more elegant to progressively lower
each vector.broadcast into a lower-dimensional vector.broadcast,
until only elementary vector operations remain. This results
in more elegant, step-wise code that is easier to understand.
It also applies some optimizations to the generated code.
Reviewers: nicolasvasilache, mehdi_amini, andydavis1, grosul1
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, frgossen, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D78071
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 2a8835102d59..c0785b9a2f9f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -55,6 +55,7 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
/// ContractionOpLowering,
/// ShapeCastOp2DDownCastRewritePattern,
/// ShapeCastOp2DUpCastRewritePattern
+/// BroadcastOpLowering,
/// TransposeOpLowering
/// OuterproductOpLowering
/// These transformation express higher level vector ops in terms of more
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b7c4a57a78ba..003e06a87299 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -126,155 +126,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
namespace {
-class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
-public:
- explicit VectorBroadcastOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
- typeConverter) {}
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto broadcastOp = cast<vector::BroadcastOp>(op);
- VectorType dstVectorType = broadcastOp.getVectorType();
- if (typeConverter.convertType(dstVectorType) == nullptr)
- return failure();
- // Rewrite when the full vector type can be lowered (which
- // implies all 'reduced' types can be lowered too).
- auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
- VectorType srcVectorType =
- broadcastOp.getSourceType().dyn_cast<VectorType>();
- rewriter.replaceOp(
- op, expandRanks(adaptor.source(), // source value to be expanded
- op->getLoc(), // location of original broadcast
- srcVectorType, dstVectorType, rewriter));
- return success();
- }
-
-private:
- // Expands the given source value over all the ranks, as defined
- // by the source and destination type (a null source type denotes
- // expansion from a scalar value into a vector).
- //
- // TODO(ajcbik): consider replacing this one-pattern lowering
- // with a two-pattern lowering using other vector
- // ops once all insert/extract/shuffle operations
- // are available with lowering implementation.
- //
- Value expandRanks(Value value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType,
- ConversionPatternRewriter &rewriter) const {
- assert((dstVectorType != nullptr) && "invalid result type in broadcast");
- // Determine rank of source and destination.
- int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
- int64_t dstRank = dstVectorType.getRank();
- int64_t curDim = dstVectorType.getDimSize(0);
- if (srcRank < dstRank)
- // Duplicate this rank.
- return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
- curDim, rewriter);
- // If all trailing dimensions are the same, the broadcast consists of
- // simply passing through the source value and we are done. Otherwise,
- // any non-matching dimension forces a stretch along this rank.
- assert((srcVectorType != nullptr) && (srcRank > 0) &&
- (srcRank == dstRank) && "invalid rank in broadcast");
- for (int64_t r = 0; r < dstRank; r++) {
- if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
- return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
- curDim, rewriter);
- }
- }
- return value;
- }
-
- // Picks the best way to duplicate a single rank. For the 1-D case, a
- // single insert-elt/shuffle is the most efficient expansion. For higher
- // dimensions, however, we need dim x insert-values on a new broadcast
- // with one less leading dimension, which will be lowered "recursively"
- // to matching LLVM IR.
- // For example:
- // v = broadcast s : f32 to vector<4x2xf32>
- // becomes:
- // x = broadcast s : f32 to vector<2xf32>
- // v = [x,x,x,x]
- // becomes:
- // x = [s,s]
- // v = [x,x,x,x]
- Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType, int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
- Type llvmType = typeConverter.convertType(dstVectorType);
- assert((llvmType != nullptr) && "unlowerable vector type");
- if (rank == 1) {
- Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- Value expand = insertOne(rewriter, typeConverter, loc, undef, value,
- llvmType, rank, 0);
- SmallVector<int32_t, 4> zeroValues(dim, 0);
- return rewriter.create<LLVM::ShuffleVectorOp>(
- loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
- }
- Value expand = expandRanks(value, loc, srcVectorType,
- reducedVectorTypeFront(dstVectorType), rewriter);
- Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- for (int64_t d = 0; d < dim; ++d) {
- result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
- rank, d);
- }
- return result;
- }
-
- // Picks the best way to stretch a single rank. For the 1-D case, a
- // single insert-elt/shuffle is the most efficient expansion when at
- // a stretch. Otherwise, every dimension needs to be expanded
- // individually and individually inserted in the resulting vector.
- // For example:
- // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
- // becomes:
- // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
- // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
- // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
- // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
- // v = [a,b,c,d]
- // becomes:
- // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
- // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
- // a = [x, y]
- // etc.
- Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType, int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
- Type llvmType = typeConverter.convertType(dstVectorType);
- assert((llvmType != nullptr) && "unlowerable vector type");
- Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- bool atStretch = dim != srcVectorType.getDimSize(0);
- if (rank == 1) {
- assert(atStretch);
- Type redLlvmType =
- typeConverter.convertType(dstVectorType.getElementType());
- Value one =
- extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0);
- Value expand = insertOne(rewriter, typeConverter, loc, result, one,
- llvmType, rank, 0);
- SmallVector<int32_t, 4> zeroValues(dim, 0);
- return rewriter.create<LLVM::ShuffleVectorOp>(
- loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
- }
- VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
- VectorType redDstType = reducedVectorTypeFront(dstVectorType);
- Type redLlvmType = typeConverter.convertType(redSrcType);
- for (int64_t d = 0; d < dim; ++d) {
- int64_t pos = atStretch ? 0 : d;
- Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType,
- rank, pos);
- Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
- result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
- rank, d);
- }
- return result;
- }
-};
-
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
@@ -1209,8 +1060,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorStridedSliceOpConversion>(ctx);
patterns
- .insert<VectorBroadcastOpConversion,
- VectorReductionOpConversion,
+ .insert<VectorReductionOpConversion,
VectorShuffleOpConversion,
VectorExtractElementOpConversion,
VectorExtractOpConversion,
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index e888c5cdfd2f..c0d6ce931d10 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -979,7 +979,114 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
}
};
-/// Progressive lowering of OuterProductOp.
+/// Progressive lowering of BroadcastOp into splats, extracts, inserts, and lower-rank broadcasts.
+class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
+public:
+ using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ VectorType dstType = op.getVectorType();
+ VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
+ Type eltType = dstType.getElementType();
+
+ // Determine rank of source and destination (a scalar source has rank 0).
+ int64_t srcRank = srcType ? srcType.getRank() : 0;
+ int64_t dstRank = dstType.getRank();
+
+ // Source rank is lower: duplicate along the leading destination rank.
+ // For example:
+ // %x = broadcast %y : k-D to n-D, k < n
+ // becomes:
+ // %b = broadcast %y : k-D to (n-1)-D
+ // %x = [%b,%b,%b,%b] : n-D
+ // becomes (e.g. when k == n-2):
+ // %b = [%y,%y] : (n-1)-D
+ // %x = [%b,%b,%b,%b] : n-D
+ if (srcRank < dstRank) {
+ // Scalar to any vector can use splat.
+ if (srcRank == 0) {
+ rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
+ return success();
+ }
+ // Duplication: broadcast to (n-1)-D once, then insert it at every leading index.
+ VectorType resType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ Value bcst =
+ rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
+ Value zero = rewriter.create<ConstantOp>(loc, eltType,
+ rewriter.getZeroAttr(eltType));
+ Value result = rewriter.create<SplatOp>(loc, dstType, zero);
+ for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
+ result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+ // Ranks are equal here; find the first non-matching dimension, if any.
+ assert(srcRank == dstRank);
+ int64_t m = -1;
+ for (int64_t r = 0; r < dstRank; r++)
+ if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
+ m = r;
+ break;
+ }
+
+ // All dimensions are the same. Simply pass through the source.
+ if (m == -1) {
+ rewriter.replaceOp(op, op.source());
+ return success();
+ }
+
+ // Stretching the single element of a 1-D vector (e.g. vector<1xf32>) can use splat.
+ if (srcRank == 1) {
+ assert(m == 0);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
+ rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
+ return success();
+ }
+
+ // Otherwise, recurse on the leading rank; the first mismatch m decides how.
+ // For example:
+ // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
+ // becomes:
+ // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
+ // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
+ // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
+ // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
+ // %x = [%a,%b,%c,%d]
+ // becomes:
+ // %u = %y[0][0] : vector<2xf32>
+ // %a = [%u,%u] : vector<2x2xf32>
+ // %v = %y[1][0] : vector<2xf32>
+ // %b = [%v,%v] ..
+ // %x = [%a,%b,%c,%d]
+ VectorType resType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ Value zero = rewriter.create<ConstantOp>(loc, eltType,
+ rewriter.getZeroAttr(eltType));
+ Value result = rewriter.create<SplatOp>(loc, dstType, zero);
+ if (m == 0) {
+ // Stretch at start: the single source slice is broadcast once and reused.
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
+ Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+ for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
+ result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ } else {
+ // Stretch not at start: broadcast each leading slice separately.
+ for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
+ Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+ result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ }
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/// Progressive lowering of TransposeOp.
/// One:
/// %x = vector.transpose %y, [1, 0]
/// is replaced by:
@@ -1518,7 +1625,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context,
VectorTransformsOptions parameters) {
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
- OuterProductOpLowering>(context);
+ ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
+ TransposeOpLowering, OuterProductOpLowering>(context);
patterns.insert<ContractionOpLowering>(parameters, context);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 6a65b219b632..96d4343b1a4b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4,201 +4,199 @@ func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2xf32>
return %0 : vector<2xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec1d_from_scalar
-// CHECK: llvm.mlir.undef : !llvm<"<2 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
-// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>">
+// CHECK-LABEL: llvm.func @broadcast_vec1d_from_scalar(
+// CHECK-SAME: %[[A:.*]]: !llvm.float)
+// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm<"<2 x float>">
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T2:.*]] = llvm.insertelement %[[A]], %[[T0]][%[[T1]] : !llvm.i32] : !llvm<"<2 x float>">
+// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T0]] [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.return %[[T3]] : !llvm<"<2 x float>">
func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2x3xf32>
return %0 : vector<2x3xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec2d_from_scalar
-// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
-// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <3 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <3 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
+// CHECK-LABEL: llvm.func @broadcast_vec2d_from_scalar(
+// CHECK-SAME: %[[A:.*]]: !llvm.float)
+// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
+// CHECK: %[[T1:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm<"<3 x float>">
+// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
+// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][1] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.return %[[T6]] : !llvm<"[2 x <3 x float>]">
func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32>
return %0 : vector<2x3x4xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec3d_from_scalar
-// CHECK: llvm.mlir.undef : !llvm<"<4 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x [3 x <4 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x [3 x <4 x float>]]">
-// CHECK: llvm.return {{.*}} : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK-LABEL: llvm.func @broadcast_vec3d_from_scalar(
+// CHECK-SAME: %[[A:.*]]: !llvm.float)
+// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: %[[T1:.*]] = llvm.mlir.undef : !llvm<"<4 x float>">
+// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm<"<4 x float>">
+// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T3]] [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T0]][0, 0] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0, 1] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T4]], %[[T6]][0, 2] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T4]], %[[T7]][1, 0] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T4]], %[[T8]][1, 1] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T4]], %[[T9]][1, 2] : !llvm<"[2 x [3 x <4 x float>]]">
+// CHECK: llvm.return %[[T10]] : !llvm<"[2 x [3 x <4 x float>]]">
func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec1d_from_vec1d
-// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>">
+// CHECK-LABEL: llvm.func @broadcast_vec1d_from_vec1d(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">)
+// CHECK: llvm.return %[[A]] : !llvm<"<2 x float>">
func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
return %0 : vector<3x2xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec2d_from_vec1d
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"[3 x <2 x float>]">
+// CHECK-LABEL: llvm.func @broadcast_vec2d_from_vec1d(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T1:.*]] = llvm.insertvalue %[[A]], %[[T0]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: llvm.return %[[T3]] : !llvm<"[3 x <2 x float>]">
func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec1d
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec1d(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T0]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T1]][0] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][1] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T4]], %[[T6]][2] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T4]], %[[T7]][3] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.return %[[T8]] : !llvm<"[4 x [3 x <2 x float>]]">
func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
%0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec2d
-// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec2d(
+// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <2 x float>]">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T1:.*]] = llvm.insertvalue %[[A]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][1] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][2] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][3] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.return %[[T4]] : !llvm<"[4 x [3 x <2 x float>]]">
func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_stretch
-// CHECK: llvm.mlir.undef : !llvm<"<4 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>">
+// CHECK-LABEL: llvm.func @broadcast_stretch(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<1 x float>">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[T2:.*]] = llvm.mlir.undef : !llvm<"<4 x float>">
+// CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T4:.*]] = llvm.insertelement %[[T1]], %[[T2]][%[[T3]] : !llvm.i32] : !llvm<"<4 x float>">
+// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T4]], %[[T2]] [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.return %[[T5]] : !llvm<"<4 x float>">
func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
%0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_stretch_at_start
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"[3 x <4 x float>]">
+// CHECK-LABEL: llvm.func @broadcast_stretch_at_start(
+// CHECK-SAME: %[[A:.*]]: !llvm<"[1 x <4 x float>]">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x4xf32>) : !llvm<"[3 x <4 x float>]">
+// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[1 x <4 x float>]">
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm<"[3 x <4 x float>]">
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][1] : !llvm<"[3 x <4 x float>]">
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T1]], %[[T3]][2] : !llvm<"[3 x <4 x float>]">
+// CHECK: llvm.return %[[T4]] : !llvm<"[3 x <4 x float>]">
func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
%0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
return %0 : vector<4x3xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_stretch_at_end
-// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]">
-// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <3 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x <1 x float>]">
-// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x <3 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <1 x float>]">
-// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <3 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <1 x float>]">
-// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
-// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
-// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <3 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"[4 x <3 x float>]">
+// CHECK-LABEL: llvm.func @broadcast_stretch_at_end(
+// CHECK-SAME: %[[A:.*]]: !llvm<"[4 x <1 x float>]">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3xf32>) : !llvm<"[4 x <3 x float>]">
+// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x <1 x float>]">
+// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]][%[[T2]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[T4:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: %[[T5:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T6:.*]] = llvm.insertelement %[[T3]], %[[T4]][%[[T5]] : !llvm.i32] : !llvm<"<3 x float>">
+// CHECK: %[[T7:.*]] = llvm.shufflevector %[[T6]], %[[T4]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T0]][0] : !llvm<"[4 x <3 x float>]">
+// CHECK: %[[T9:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x <1 x float>]">
+// CHECK: %[[T10:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+// CHECK: %[[T11:.*]] = llvm.extractelement %[[T9]][%[[T10]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[T12:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>">
+// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T8]][1] : !llvm<"[4 x <3 x float>]">
+// CHECK: %[[T17:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x <1 x float>]">
+// CHECK: %[[T18:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+// CHECK: %[[T19:.*]] = llvm.extractelement %[[T17]][%[[T18]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[T20:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: %[[T21:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T22:.*]] = llvm.insertelement %[[T19]], %[[T20]][%[[T21]] : !llvm.i32] : !llvm<"<3 x float>">
+// CHECK: %[[T23:.*]] = llvm.shufflevector %[[T22]], %[[T20]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T23]], %[[T16]][2] : !llvm<"[4 x <3 x float>]">
+// CHECK: %[[T25:.*]] = llvm.extractvalue %[[A]][3] : !llvm<"[4 x <1 x float>]">
+// CHECK: %[[T26:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+// CHECK: %[[T27:.*]] = llvm.extractelement %[[T25]][%[[T26]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[T28:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
+// CHECK: %[[T29:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T30:.*]] = llvm.insertelement %[[T27]], %[[T28]][%[[T29]] : !llvm.i32] : !llvm<"<3 x float>">
+// CHECK: %[[T31:.*]] = llvm.shufflevector %[[T30]], %[[T28]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
+// CHECK: %[[T32:.*]] = llvm.insertvalue %[[T31]], %[[T24]][3] : !llvm<"[4 x <3 x float>]">
+// CHECK: llvm.return %[[T32]] : !llvm<"[4 x <3 x float>]">
func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
%0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
}
-// CHECK-LABEL: llvm.func @broadcast_stretch_in_middle
-// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x [1 x <2 x float>]]">
-// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]">
-// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]">
+// Middle unit dimension stretched: each [1 x <2 x float>] slice is unpacked once and its single row is inserted three times with llvm.insertvalue (the expected output contains no shufflevector).
+// CHECK-LABEL: llvm.func @broadcast_stretch_in_middle(
+// CHECK-SAME: %[[A:.*]]: !llvm<"[4 x [1 x <2 x float>]]">)
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm<"[1 x <2 x float>]">
+// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T8:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T9:.*]] = llvm.extractvalue %[[T8]][0] : !llvm<"[1 x <2 x float>]">
+// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T9]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T12:.*]] = llvm.insertvalue %[[T9]], %[[T11]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T7]][1] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T14:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T15:.*]] = llvm.extractvalue %[[T14]][0] : !llvm<"[1 x <2 x float>]">
+// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T15]], %[[T16]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T15]], %[[T17]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T19:.*]] = llvm.insertvalue %[[T18]], %[[T13]][2] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: %[[T20:.*]] = llvm.extractvalue %[[A]][3] : !llvm<"[4 x [1 x <2 x float>]]">
+// CHECK: %[[T21:.*]] = llvm.extractvalue %[[T20]][0] : !llvm<"[1 x <2 x float>]">
+// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T21]], %[[T1]][0] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T21]], %[[T22]][1] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T21]], %[[T23]][2] : !llvm<"[3 x <2 x float>]">
+// CHECK: %[[T25:.*]] = llvm.insertvalue %[[T24]], %[[T19]][3] : !llvm<"[4 x [3 x <2 x float>]]">
+// CHECK: llvm.return %[[T25]] : !llvm<"[4 x [3 x <2 x float>]]">
func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
%2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
@@ -211,16 +209,16 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
// CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%4 : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>">
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
// CHECK: %[[T7:.*]] = llvm.fmul %[[T6]], %[[B]] : !llvm<"<3 x float>">
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
// CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: %[[T10:.*]] = llvm.extractelement %[[A]][%9 : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[T11:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: %[[T12:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: %[[T13:.*]] = llvm.insertelement %[[T10]], %[[T11]][%12 : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: %[[T12:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T13:.*]] = llvm.insertelement %[[T10]], %[[T11]][%[[T12]] : !llvm.i32] : !llvm<"<3 x float>">
// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T13]], %[[T11]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
// CHECK: %[[T15:.*]] = llvm.fmul %[[T14]], %[[B]] : !llvm<"<3 x float>">
// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T8]][1] : !llvm<"[2 x <3 x float>]">
@@ -238,8 +236,8 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
// CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>">
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
// CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]">
// CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
@@ -247,8 +245,8 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[T12:.*]] = llvm.mlir.undef : !llvm<"<3 x float>">
-// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>">
// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]">
// CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 08140b4ae065..8354677b797c 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -257,11 +257,11 @@ func @full_contract2(%arg0: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
+// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
+// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -278,12 +278,12 @@ func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
+// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
+// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -389,3 +389,173 @@ func @matmul(%arg0: vector<2x4xf32>,
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
+
+// CHECK-LABEL: func @broadcast_vec1d_from_scalar
+// CHECK-SAME: %[[A:.*0]]: f32
+// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32>
+// CHECK: return %[[T0]] : vector<2xf32>
+// Scalar to 1-D vector: expected to lower to a single splat of the scalar.
+func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @broadcast_vec2d_from_scalar
+// CHECK-SAME: %[[A:.*0]]: f32
+// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
+// CHECK: return %[[T0]] : vector<2x3xf32>
+// Scalar to 2-D vector: still a single splat of the whole vector<2x3xf32>, with no per-row inserts.
+func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @broadcast_vec3d_from_scalar
+// CHECK-SAME: %[[A:.*0]]: f32
+// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
+// CHECK: return %[[T0]] : vector<2x3x4xf32>
+// Scalar to 3-D vector: a single splat of the whole vector<2x3x4xf32>, regardless of rank.
+func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32>
+ return %0 : vector<2x3x4xf32>
+}
+
+// CHECK-LABEL: func @broadcast_vec1d_from_vec1d
+// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
+// CHECK: return %[[A]] : vector<2xf32>
+// Identical source and result types: the argument is expected to be returned directly, with no remaining broadcast.
+func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @broadcast_vec2d_from_vec1d
+// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: return %[[T2]] : vector<3x2xf32>
+// 1-D to 2-D: the source row is inserted into each of the 3 rows of a zero-initialized vector<3x2xf32>.
+func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+
+// CHECK-LABEL: func @broadcast_vec3d_from_vec1d
+// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[C1:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: return %[[T6]] : vector<4x3x2xf32>
+// 1-D to 3-D: the 3x2 slice is built once from the source row, then that same slice is inserted at all 4 leading positions.
+func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
+ return %0 : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func @broadcast_vec3d_from_vec2d
+// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: return %[[T3]] : vector<4x3x2xf32>
+// 2-D to 3-D: the unchanged 3x2 source is inserted at each of the 4 leading positions.
+func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32>
+ return %0 : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func @broadcast_stretch
+// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32>
+// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<4xf32>
+// CHECK: return %[[T1]] : vector<4xf32>
+// Stretching a size-1 vector: its only element is extracted and then splatted to vector<4xf32>.
+func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @broadcast_stretch_at_start
+// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x4xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32>
+// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32>
+// CHECK: return %[[T3]] : vector<3x4xf32>
+// Leading unit dimension stretched: the single 4-element row is extracted once and inserted into each of the 3 result rows.
+func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
+
+// CHECK-LABEL: func @broadcast_stretch_at_end
+// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1xf32>
+// CHECK: %[[T2:.*]] = splat %[[T1]] : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<4x1xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<1xf32>
+// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2] : vector<4x1xf32>
+// CHECK: %[[T9:.*]] = vector.extract %[[T8]][0] : vector<1xf32>
+// CHECK: %[[T10:.*]] = splat %[[T9]] : vector<3xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[A]][3] : vector<4x1xf32>
+// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1xf32>
+// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
+// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
+// CHECK: return %[[T15]] : vector<4x3xf32>
+// Trailing unit dimension stretched: for each of the 4 rows, the lone scalar is extracted, splatted to vector<3xf32>, and inserted into the zero-initialized result.
+func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
+ %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
+ return %0 : vector<4x3xf32>
+}
+
+// CHECK-LABEL: func @broadcast_stretch_in_middle
+// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK: %[[C1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1x2xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1x2xf32>
+// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T1]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<4x1x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T6]][0] : vector<1x2xf32>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T10:.*]] = vector.insert %[[T7]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[A]][2] : vector<4x1x2xf32>
+// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1x2xf32>
+// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T16:.*]] = vector.insert %[[T13]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: %[[T18:.*]] = vector.extract %[[A]][3] : vector<4x1x2xf32>
+// CHECK: %[[T19:.*]] = vector.extract %[[T18]][0] : vector<1x2xf32>
+// CHECK: %[[T20:.*]] = vector.insert %[[T19]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T22:.*]] = vector.insert %[[T19]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32>
+// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
+// CHECK: return %[[T23]] : vector<4x3x2xf32>
+// Middle unit dimension stretched: for each of the 4 leading indices, the single 2-element row is inserted 3 times into a zero 3x2 slice, which is then inserted into the 4x3x2 result.
+func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
+ return %0 : vector<4x3x2xf32>
+}
More information about the Mlir-commits
mailing list