[Mlir-commits] [mlir] 7ce315d - [mlir][vector] Improve shape_cast lowering (#140800)
llvmlistbot at llvm.org
Thu Jun 5 10:18:42 PDT 2025
Author: James Newling
Date: 2025-06-05T10:18:38-07:00
New Revision: 7ce315d14aa5c084574cc3a17552625f322e1d16
URL: https://github.com/llvm/llvm-project/commit/7ce315d14aa5c084574cc3a17552625f322e1d16
DIFF: https://github.com/llvm/llvm-project/commit/7ce315d14aa5c084574cc3a17552625f322e1d16.diff
LOG: [mlir][vector] Improve shape_cast lowering (#140800)
Before this PR, a rank-m -> rank-n vector.shape_cast with m,n>1 was
lowered to extracts/inserts of single elements, so that a shape_cast on
a vector with N elements would always require N extracts/inserts. While
this is necessary in the worst case, it is sometimes possible to use
fewer, larger extracts/inserts. Specifically, slices whose shape is the
largest common suffix of the source and result shapes can be extracted/inserted.
For example:
```mlir
%0 = vector.shape_cast %arg0 : vector<10x2x3xf32> to vector<2x5x2x3xf32>
```
has a common suffix of shape `2x3`. Before this PR, this would be lowered
to 60 extract/insert pairs, with extracts of the form
`vector.extract %arg0 [a, b, c] : f32 from vector<10x2x3xf32>`. With
this PR it is lowered to 10 extract/insert pairs, with extracts of the form
`vector.extract %arg0 [a] : vector<2x3xf32> from vector<10x2x3xf32>`.
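For illustration, here is a hand-written sketch of the kind of IR the new lowering targets for this example. This is not verbatim pass output; the SSA names and insert positions are illustrative only.
```mlir
// Illustrative sketch only; names and insert positions are not taken from the pass.
%ub = ub.poison : vector<2x5x2x3xf32>
%e0 = vector.extract %arg0[0] : vector<2x3xf32> from vector<10x2x3xf32>
%r0 = vector.insert %e0, %ub [0, 0] : vector<2x3xf32> into vector<2x5x2x3xf32>
%e1 = vector.extract %arg0[1] : vector<2x3xf32> from vector<10x2x3xf32>
%r1 = vector.insert %e1, %r0 [0, 1] : vector<2x3xf32> into vector<2x5x2x3xf32>
// ... 8 more extract/insert pairs ...
```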
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..39c16fab21c4e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -21,177 +21,298 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include <numeric>
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
-using namespace mlir::vector;
-
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
- int step = 1) {
- for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- assert(indices[dim] < vecType.getDimSize(dim) &&
- "Indices are out of bound");
- indices[dim] += step;
- if (indices[dim] < vecType.getDimSize(dim))
- break;
- indices[dim] = 0;
- step = 1;
+/// Perform the inplace update
+/// rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with most significant
+/// dimensions on the left. For example, if `rhs` is {a,b,c} and `base` is
+/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflow is not handled: a carry out of the most significant dimension is
+/// silently dropped. For example:
+/// rhs = {4,2,1}, lhs = 1 --> rhs = {0,0,0} (the final carry of 1 is lost)
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+ MutableArrayRef<int64_t> rhs) {
+
+ // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
+ for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+ int64_t dimBase = base[dim];
+ assert(rhs[dim] < dimBase && "rhs not in base");
+
+ int64_t incremented = rhs[dim] + lhs;
+
+ // If the incremented value exceeds the dimension base, we must spill to the
+ // next most significant dimension and repeat (we might need to spill to
+ // more significant dimensions multiple times).
+ lhs = incremented / dimBase;
+ rhs[dim] = incremented % dimBase;
+ if (lhs == 0)
+ break;
}
}
namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the source and result dimensions preceding the
+/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
+/// on vectors with shapes that are `multiples` of (what we define as) the
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
+///
+/// vector<2x2x3x4x7x11xi8> to
+/// vector<8x6x7x11xi8>
+/// | ||||
+/// | ++++------------> common suffix of 7x11
+/// +-----------------> gcd(4,6) is 2 | |
+/// | | |
+/// v v v
+/// atomic shape <----- 2x7x11
+///
+///
+///
+/// The decomposition implemented in this pattern consists of a sequence of
+/// repeated steps:
+///
+/// (1) Extract vectors from the suffix of the source.
+/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
+///
+/// (2) Do extract_strided_slice down to the atomic shape.
+/// In our example this is 4x7x11 -> 2x7x11.
+///
+/// (3) Do insert_strided_slice to the suffix of the result.
+/// In our example this is 2x7x11 -> 6x7x11.
+///
+/// (4) insert these vectors into the result vector.
+/// In our example this is 6x7x11 -> 8x6x7x11.
+///
+/// These steps occur with different periods. In this example
+/// (1) occurs 12 times,
+/// (2) and (3) occur 24 times, and
+/// (4) occurs 8 times.
+///
+/// Two special cases are handled independently in this pattern:
+/// (i) A shape_cast that just does leading 1 insertion/removal
+/// (ii) A shape_cast where the gcd is 1.
+///
+/// For these 2 cases, more compact IR can be generated by not using the
+/// generic algorithm described above.
+///
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
+
+ // Case (i) of description.
+ // Assumes source and result shapes are identical up to some leading ones.
+ static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
+ PatternRewriter &rewriter) {
+
+ const Location loc = shapeCast.getLoc();
+ const VectorType sourceType = shapeCast.getSourceVectorType();
+ const VectorType resultType = shapeCast.getResultVectorType();
+
+ const int64_t sourceRank = sourceType.getRank();
+ const int64_t resultRank = resultType.getRank();
+ const int64_t delta = sourceRank - resultRank;
+ const int64_t sourceLeading = delta > 0 ? delta : 0;
+ const int64_t resultLeading = delta > 0 ? 0 : -delta;
+
+ const Value source = shapeCast.getSource();
+ const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
+ const Value extracted = rewriter.create<vector::ExtractOp>(
+ loc, source, SmallVector<int64_t>(sourceLeading, 0));
+ const Value result = rewriter.create<vector::InsertOp>(
+ loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
+
+ rewriter.replaceOp(shapeCast, result);
+ return success();
+ }
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if (srcRank < 2 || resRank != 1)
- return failure();
+ // Case (ii) of description.
+ // Assumes a shape_cast where the suffix shape of the source starting at
+ // `sourceDim` and the suffix shape of the result starting at `resultDim` are
+ // identical.
+ static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
+ int64_t sourceDim,
+ int64_t resultDim,
+ PatternRewriter &rewriter) {
- // Compute the number of 1-D vector elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t dim = 0; dim < srcRank - 1; ++dim)
- numElts *= sourceVectorType.getDimSize(dim);
+ const Location loc = shapeCast.getLoc();
- auto loc = op.getLoc();
- SmallVector<int64_t> srcIdx(srcRank - 1, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- int64_t extractSize = sourceVectorType.getShape().back();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
+ const Value source = shapeCast.getSource();
+ const ArrayRef<int64_t> sourceShape =
+ shapeCast.getSourceVectorType().getShape();
- // Compute the indices of each 1-D vector element of the source extraction
- // and destination slice insertion and generate such instructions.
- for (int64_t i = 0; i < numElts; ++i) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/1);
- incIdx(resIdx, resultVectorType, /*step=*/extractSize);
- }
+ const VectorType resultType = shapeCast.getResultVectorType();
+ const ArrayRef<int64_t> resultShape = resultType.getShape();
- Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, extract, result,
- /*offsets=*/resIdx, /*strides=*/1);
- }
+ const int64_t nSlices =
+ std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1,
+ std::multiplies<int64_t>());
- rewriter.replaceOp(op, result);
- return success();
- }
-};
+ SmallVector<int64_t> extractIndex(sourceDim, 0);
+ SmallVector<int64_t> insertIndex(resultDim, 0);
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
-/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
-/// vectors progressively. This iterates over the n-1 major dimension of the n-D
-/// vector and performs rewrites into:
-/// vector.extract_strided_slice from 1-D + vector.insert into n-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOpNDUpCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
+ for (int i = 0; i < nSlices; ++i) {
+ Value extracted =
+ rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
-
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if (srcRank != 1 || resRank < 2)
- return failure();
-
- // Compute the number of 1-D vector elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t dim = 0; dim < resRank - 1; ++dim)
- numElts *= resultVectorType.getDimSize(dim);
-
- // Compute the indices of each 1-D vector element of the source slice
- // extraction and destination insertion and generate such instructions.
- auto loc = op.getLoc();
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank - 1, 0);
- int64_t extractSize = resultVectorType.getShape().back();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; ++i) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
- incIdx(resIdx, resultVectorType, /*step=*/1);
- }
+ result = rewriter.create<vector::InsertOp>(loc, extracted, result,
+ insertIndex);
- Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
- /*strides=*/1);
- result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+ inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
+ inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
}
- rewriter.replaceOp(op, result);
+ rewriter.replaceOp(shapeCast, result);
return success();
}
-};
-// We typically should not lower general shape cast operations into data
-// movement instructions, since the assumption is that these casts are
-// optimized away during progressive lowering. For completeness, however,
-// we fall back to a reference implementation that moves all elements
-// into the right place if we get here.
-class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType resultType = op.getResultVectorType();
+
+ if (sourceType.isScalable() || resultType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op,
+ "shape_cast where vectors are scalable not handled by this pattern");
+
+ const ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ const ArrayRef<int64_t> resultShape = resultType.getShape();
+ const int64_t sourceRank = sourceType.getRank();
+ const int64_t resultRank = resultType.getRank();
+ const int64_t numElms = sourceType.getNumElements();
+ const Value source = op.getSource();
+
+ // Set the first dimension (starting at the end) in the source and result
+ // respectively where the dimension sizes differ. Using the running example:
+ //
+ // dimensions: [0 1 2 3 4 5 ] [0 1 2 3 ]
+ // shapes: (2,2,3,4,7,11) -> (8,6,7,11)
+ // ^ ^
+ // | |
+ // sourceSuffixStartDim is 3 |
+ // |
+ // resultSuffixStartDim is 1
+ int64_t sourceSuffixStartDim = sourceRank - 1;
+ int64_t resultSuffixStartDim = resultRank - 1;
+ while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
+ (sourceType.getDimSize(sourceSuffixStartDim) ==
+ resultType.getDimSize(resultSuffixStartDim))) {
+ --sourceSuffixStartDim;
+ --resultSuffixStartDim;
+ }
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
-
- // Special case for n-D / 1-D lowerings with better implementations.
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
- return failure();
-
- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
- // extract/insert chains.
- int64_t numElts = 1;
- for (int64_t r = 0; r < srcRank; r++)
- numElts *= sourceVectorType.getDimSize(r);
- // Replace with data movement operations:
- // x[0,0,0] = y[0,0]
- // x[0,0,1] = y[0,1]
- // x[0,1,0] = y[0,2]
- // etc., incrementing the two index vectors "row-major"
- // within the source and result shape.
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; i++) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType);
- incIdx(resIdx, resultVectorType);
+ // This is the case (i) where there are just some leading ones to contend
+ // with in the source or result. It can be handled with a single
+ // extract/insert pair.
+ if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
+ return leadingOnesLowering(op, rewriter);
+
+ const int64_t sourceSuffixStartDimSize =
+ sourceType.getDimSize(sourceSuffixStartDim);
+ const int64_t resultSuffixStartDimSize =
+ resultType.getDimSize(resultSuffixStartDim);
+ const int64_t greatestCommonDivisor =
+ std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
+ const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
+ const size_t extractPeriod =
+ sourceSuffixStartDimSize / greatestCommonDivisor;
+ const size_t insertPeriod =
+ resultSuffixStartDimSize / greatestCommonDivisor;
+
+ SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
+ sourceShape.end());
+ atomicShape[0] = greatestCommonDivisor;
+
+ const int64_t numAtomicElms = std::accumulate(
+ atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
+ const size_t nAtomicSlices = numElms / numAtomicElms;
+
+ // This is the case (ii) where the strided dimension size is 1. More compact
+ // IR is generated in this case if we just extract and insert the elements
+ // directly. In other words, we don't use extract_strided_slice and
+ // insert_strided_slice.
+ if (greatestCommonDivisor == 1)
+ return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
+ resultSuffixStartDim + 1, rewriter);
+
+ // The insert_strided_slice result's type
+ const ArrayRef<int64_t> insertStridedShape =
+ resultShape.drop_front(resultSuffixStartDim);
+ const VectorType insertStridedType =
+ VectorType::get(insertStridedShape, resultType.getElementType());
+
+ SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
+ SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
+ SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
+ SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
+ const SmallVector<int64_t> sizes(stridedSliceRank, 1);
+
+ Value extracted = {};
+ Value extractedStrided = {};
+ Value insertedSlice = {};
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ const Value partResult =
+ rewriter.create<ub::PoisonOp>(loc, insertStridedType);
+
+ for (size_t i = 0; i < nAtomicSlices; ++i) {
+
+ const size_t extractStridedPhase = i % extractPeriod;
+ const size_t insertStridedPhase = i % insertPeriod;
+
+ // vector.extract
+ if (extractStridedPhase == 0) {
+ extracted =
+ rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+ inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
+ extractIndex);
}
- Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+ // vector.extract_strided_slice
+ extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
+ extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, extracted, extractOffsets, atomicShape, sizes);
+
+ // vector.insert_strided_slice
+ if (insertStridedPhase == 0) {
+ insertedSlice = partResult;
+ }
+ insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
+ insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, extractedStrided, insertedSlice, insertOffsets, sizes);
+
+ // vector.insert
+ if (insertStridedPhase + 1 == insertPeriod) {
+ result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
+ insertIndex);
+ inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
+ insertIndex);
+ }
}
rewriter.replaceOp(op, result);
return success();
@@ -252,7 +373,8 @@ class ScalableShapeCastOpRewritePattern
// from >= 2-D scalable vectors or scalable vectors of fixed vectors.
if (!isTrailingDimScalable(sourceVectorType) ||
!isTrailingDimScalable(resultVectorType)) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "trailing dims are not scalable, not handled by this pattern");
}
// The sizes of the trailing dimension of the source and result vectors, the
@@ -329,8 +451,8 @@ class ScalableShapeCastOpRewritePattern
// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
- incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
- incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+ inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
+ inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
}
rewriter.replaceOp(op, result);
@@ -347,8 +469,6 @@ class ScalableShapeCastOpRewritePattern
void mlir::vector::populateVectorShapeCastLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<ShapeCastOpNDDownCastRewritePattern,
- ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
- ScalableShapeCastOpRewritePattern>(patterns.getContext(),
- benefit);
+ patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..5011d8b2b2ef6 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,145 +1,392 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-LABEL: func @nop_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK: return %[[A]] : vector<16xf32>
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @cancel_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK: return %[[A]] : vector<16xf32>
-
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
%1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
return %1 : vector<16xf32>
}
-// Shape up and downcasts for 2-D vectors, for supporting conversion to
-// llvm.matrix operations
-// CHECK-LABEL: func @shape_casts
-func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
- // CHECK-DAG: %[[ub22:.*]] = ub.poison : vector<2x2xf32>
- // CHECK-DAG: %[[ub:.*]] = ub.poison : vector<4xf32>
- // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
- //
- // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[ub]]
- // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
- //
- // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
- //
- // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
- // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
- //
+// Collapse 2-D to 1-D.
+// CHECK-LABEL: func @shape_cast_2d1d
+// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<4xf32>
+//
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UB]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//
+// CHECK: %[[EX1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[IN2:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[IN2]] : vector<4xf32>
+func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
%0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
- // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
- %r0 = arith.addf %0, %0: vector<4xf32>
- //
- // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
- // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
- // CHECK-SAME: vector<4xf32> to vector<2xf32>
- //
- // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[ub22]] [0] :
- // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
- //
- // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
- // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
- // CHECK-SAME: vector<4xf32> to vector<2xf32>
- //
- // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
- // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
- //
- %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32>
- // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
- return %r0, %1 : vector<4xf32>, vector<2x2xf32>
-}
-
-// CHECK-LABEL: func @shape_cast_2d2d
-// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
-// CHECK: return %[[T11]] : vector<2x3xf32>
-
-func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
- %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
- return %s : vector<2x3xf32>
+ return %0 : vector<4xf32>
}
+// Collapse 3-D to 1-D.
// CHECK-LABEL: func @shape_cast_3d1d
-// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
-// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
-// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
-// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: return %[[T5]] : vector<6xf32>
-
+// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
+//
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+//
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+//
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
return %s : vector<6xf32>
}
-// CHECK-LABEL: func @shape_cast_1d3d
-// CHECK-SAME: %[[A:.*]]: vector<6xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
-// CHECK: return %[[T3]] : vector<2x1x3xf32>
+// Expand 1-D to 2-D.
+// CHECK-LABEL: func.func @shape_cast_1d2d(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x2xf32>
+//
+// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
+// CHECK-SAME: vector<4xf32> to vector<2xf32>
+// CHECK: %[[res0:.*]] = vector.insert %[[SS0]], %[[UB]] [0] :
+// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+//
+// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
+// CHECK-SAME: vector<4xf32> to vector<2xf32>
+// CHECK: %[[res1:.*]] = vector.insert %[[SS2]], %[[res0]] [1] :
+// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+// CHECK: return %[[res1]] : vector<2x2xf32>
+func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
+ %1 = vector.shape_cast %a: vector<4xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+// Expand 1-D to 3-D.
+// CHECK-LABEL: func @shape_cast_1d3d
+// CHECK-SAME: %[[A:.*]]: vector<6xf32>
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
+// CHECK-SAME: vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
+// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+//
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
+// CHECK-SAME: vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] :
+// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_0d1d(
-// CHECK-SAME: %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
-// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
-// CHECK: return %[[RES]] : vector<1xf32>
-// CHECK: }
+// 2-D to 2-D where the inner-most dimensions have no common factors. This
+// case requires scalar element by element extraction and insertion.
+// CHECK-LABEL: func @shape_cast_2d2d
+// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
+//
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] :
+//
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] :
+//
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] :
+//
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] :
+//
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] :
+//
+// CHECK: return %[[T11]] : vector<2x3xf32>
+func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+// CHECK-LABEL: func.func @shape_cast_0d1d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>) -> vector<1xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
+//
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][] : f32 from vector<f32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] :
+// CHECK: return %[[RES]] : vector<1xf32>
func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
%s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %s : vector<1xf32>
}
-// CHECK-LABEL: func.func @shape_cast_1d0d(
-// CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
-// CHECK: return %[[RES]] : vector<f32>
-// CHECK: }
-
+// CHECK-LABEL: func.func @shape_cast_1d0d(
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector<f32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
+//
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] :
+// CHECK: return %[[RES]] : vector<f32>
func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
%s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
return %s : vector<f32>
}
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+//
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[A]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] :
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] :
+// CHECK: return %[[I1]] : vector<2x3xf32>
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @prepend_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+//
+// CHECK: %[[I0:.*]] = vector.insert %[[A]], %[[UB]] [0]
+// CHECK: return %[[I0]] : vector<1x2x3xf32>
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+ return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK: return %[[I1]] : vector<2x1x3xf32>
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+ return %s : vector<2x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @postpend_unit_dims(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<4x1x1xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<4xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0, 0] : f32 into vector<4x1x1xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0, 0] : f32 into vector<4x1x1xf32>
+// CHECK: vector.extract %[[A]][2]
+// CHECK: vector.insert {{.*}} [2, 0, 0]
+// CHECK: vector.extract %[[A]][3]
+// CHECK: vector.insert {{.*}} [3, 0, 0]
+// CHECK: return
+func.func @postpend_unit_dims(%arg0 : vector<4xf32>) -> vector<4x1x1xf32> {
+ %s = vector.shape_cast %arg0 : vector<4xf32> to vector<4x1x1xf32>
+ return %s : vector<4x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @expand_inner_dims(
+// CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x2x5xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<10xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[E0]]
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[S0]], %[[UB]] [0, 0]
+//
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[E0]]
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [0, 1]
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<10xf32>
+// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[E1]]
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [1, 0]
+//
+// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[E1]]
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [1, 1]
+// CHECK: return %[[I3]] : vector<2x2x5xf32>
+func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x10xf32> to vector<2x2x5xf32>
+ return %s : vector<2x2x5xf32>
+}
+
+
+// Some pseudocode describing how this function is lowered:
+//
+// func collapse_inner_dims(A : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// v0 = empty of shape (10)
+// v1 = empty of shape (1,2,1,10)
+// v0[0:5] = A[0,0,:]
+// v0[5:10] = A[0,1,:]
+// v1[0,0,0,:] = v0
+// v0[0:5] = A[1,0,:]
+// v0[5:10] = A[1,1,:]
+// v1[0,1,0,:] = v0
+// return v1;
+// }
+// CHECK-LABEL: func.func @collapse_inner_dims(
+// CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// CHECK-DAG: %[[UBSMALL:.*]] = ub.poison : vector<10xi8>
+// CHECK-DAG: %[[UBLARGE:.*]] = ub.poison : vector<1x2x1x10xi8>
+//
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0, 0]
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UBSMALL]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX1:.*]] = vector.extract %[[A]][0, 1]
+// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UBLARGE]] [0, 0, 0]
+//
+// CHECK: %[[EX2:.*]] = vector.extract %[[A]][1, 0]
+// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[EX2]], %[[UBSMALL]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX3:.*]] = vector.extract %[[A]][1, 1]
+// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[EX3]], %[[IN3]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [0, 1, 0]
+// CHECK: return %[[IN5]] : vector<1x2x1x10xi8>
+func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+ %s = vector.shape_cast %arg0 : vector<2x2x5xi8> to vector<1x2x1x10xi8>
+ return %s : vector<1x2x1x10xi8>
+}
+
+// Some alternative pseudocode describing how this function is lowered:
+//
+// func non_dividing_gcd_decreasing(A : vector<2x15xi8>) -> vector<3x10xi8> {
+// v0 = empty of shape (10)
+// v1 = empty of shape (3,10)
+// e0 = A[0,:]
+// v0[0:5] = e0[0:5]
+// v0[5:10] = e0[5:10]
+// v1[0,:] = v0
+// v0[0:5] = e0[10:15]
+// e1 = A[1,:]
+// v0[5:10] = e1[0:5]
+// v1[1,:] = v0
+// v0[0:5] = e1[5:10]
+// v0[5:10] = e1[10:15]
+// v1[2,:] = v0
+// return v1;
+// }
+// CHECK-LABEL: func.func @non_dividing_gcd_decreasing(
+// CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
+// CHECK-DAG: %[[UB0:.*]] = ub.poison : vector<10xi8>
+// CHECK-DAG: %[[UB1:.*]] = ub.poison : vector<3x10xi8>
+//
+// First 10 elements:
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<15xi8> from vector<2x15xi8>
+// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[SS0]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[SS1:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[SS1]], %[[IN0]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UB1]] [0] : vector<10xi8> into vector<3x10xi8>
+//
+// Next 10 elements:
+// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[SS2]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX1:.*]] = vector.extract %[[A]][1] : vector<15xi8> from vector<2x15xi8>
+// CHECK: %[[SS3:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[SS3]], %[[IN3]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [1] : vector<10xi8> into vector<3x10xi8>
+//
+// Final 10 elements:
+// CHECK: %[[SS4:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK: %[[IN6:.*]] = vector.insert_strided_slice %[[SS4]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[SS5:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK: %[[IN7:.*]] = vector.insert_strided_slice %[[SS5]], %[[IN6]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN8:.*]] = vector.insert %[[IN7]], %[[IN5]] [2] : vector<10xi8> into vector<3x10xi8>
+// CHECK: return %[[IN8]] : vector<3x10xi8>
+func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi8> {
+ %0 = vector.shape_cast %arg0 : vector<2x15xi8> to vector<3x10xi8>
+ return %0 : vector<3x10xi8>
+}
+
+// CHECK-LABEL: func.func @non_dividing_gcd_increasing(
+// CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
+//
+// CHECK-DAG: ub.poison : vector<15xi8>
+// CHECK-DAG: ub.poison : vector<2x15xi8>
+//
+// Collect the first 15 elements, and insert into the first row of the result.
+// CHECK: vector.extract %[[A]][0]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.extract %[[A]][1]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.insert {{.*}} [0] : vector<15xi8> into vector<2x15xi8>
+//
+// Collect the next 15 elements, and insert into the second row of the result.
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.extract %[[A]][2]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.insert {{.*}} [1] : vector<15xi8> into vector<2x15xi8>
+func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi8> {
+ %0 = vector.shape_cast %arg0 : vector<3x10xi8> to vector<2x15xi8>
+ return %0 : vector<2x15xi8>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op