[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)
James Newling
llvmlistbot at llvm.org
Tue Jun 3 14:31:51 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140800
>From 49ebfa731b9168cad0946b6e43229ae0ce064bd7 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 13:07:23 -0700
Subject: [PATCH 1/5] extract as large a chunk as possible in shape_cast
lowering
---
.../Transforms/LowerVectorShapeCast.cpp | 80 +++++++++++--------
...vector-shape-cast-lowering-transforms.mlir | 51 ++++++++++++
2 files changed, 99 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..d0085bffca23c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
using namespace mlir::vector;
/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
int step = 1) {
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- assert(indices[dim] < vecType.getDimSize(dim) &&
- "Indices are out of bound");
+ int64_t dimSize = shape[dim];
+ assert(indices[dim] < dimSize && "Indices are out of bound");
+
indices[dim] += step;
- if (indices[dim] < vecType.getDimSize(dim))
+
+ int64_t spill = indices[dim] / dimSize;
+ if (spill == 0)
break;
- indices[dim] = 0;
- step = 1;
+ indices[dim] %= dimSize;
+ step = spill;
}
}
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
// and destination slice insertion and generate such instructions.
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/1);
- incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
}
Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
- incIdx(resIdx, resultVectorType, /*step=*/1);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
}
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType resultType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+ if (sourceType.isScalable() || resultType.isScalable())
return failure();
- // Special case for n-D / 1-D lowerings with better implementations.
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+ // Special case for n-D / 1-D lowerings with implementations that use
+ // extract_strided_slice / insert_strided_slice.
+ int64_t sourceRank = sourceType.getRank();
+ int64_t resultRank = resultType.getRank();
+ if ((sourceRank > 1 && resultRank == 1) ||
+ (sourceRank == 1 && resultRank > 1))
return failure();
- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
- // extract/insert chains.
- int64_t numElts = 1;
- for (int64_t r = 0; r < srcRank; r++)
- numElts *= sourceVectorType.getDimSize(r);
+ int64_t numExtracts = sourceType.getNumElements();
+ int64_t nbCommonInnerDims = 0;
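+ // Count the number of trailing dimensions that the source and result shapes
+ // have in common. For example, 2x1x3 -> 2x3 has nbCommonInnerDims = 1 (the
+ // shared trailing 3), so whole vector<3xf32> values are moved at a time.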
+ while (true) {
+ int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+ int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+ if (sourceDim < 0 || resultDim < 0)
+ break;
+ int64_t dimSize = sourceType.getDimSize(sourceDim);
+ if (dimSize != resultType.getDimSize(resultDim))
+ break;
+ numExtracts /= dimSize;
+ ++nbCommonInnerDims;
+ }
+
// Replace with data movement operations:
// x[0,0,0] = y[0,0]
// x[0,0,1] = y[0,1]
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; i++) {
+ SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+ SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+ for (int64_t i = 0; i < numExtracts; i++) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType);
- incIdx(resIdx, resultVectorType);
+ incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+ incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
}
Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+ result =
+ rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
}
rewriter.replaceOp(op, result);
return success();
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
- incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
- incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..044a73d71e665 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -140,6 +140,57 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
return %s : vector<f32>
}
+
+// CHECK-LABEL: func.func @shape_cast_squeeze_leading_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_squeeze_middle_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: return %[[I1]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
+func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+ return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK: return %[[I1]] : vector<2x1x3xf32>
+func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+ return %s : vector<2x1x3xf32>
+}
+
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
>From 1a4b759e753c5649aceb0a12cfd65f06a3d35a82 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 22 May 2025 11:54:33 -0700
Subject: [PATCH 2/5] test name improvements
---
.../vector-shape-cast-lowering-transforms.mlir | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 044a73d71e665..2875f159a2df9 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -141,17 +141,19 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
}
-// CHECK-LABEL: func.func @shape_cast_squeeze_leading_one(
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
-func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_squeeze_middle_one(
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
@@ -161,23 +163,23 @@ func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
// CHECK-SAME: into vector<2x3xf32>
// CHECK: return %[[I1]] : vector<2x3xf32>
-func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-LABEL: func.func @prepend_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
-func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
return %s : vector<1x2x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
@@ -185,7 +187,7 @@ func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
// CHECK: return %[[I1]] : vector<2x1x3xf32>
-func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}
>From 31ec6c6bf384da3235dc2c93f8c8bc151e8fb261 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 29 May 2025 09:23:43 -0700
Subject: [PATCH 3/5] unify the 3 shape cast lowering patterns, major refactor
of testing
---
.../Transforms/LowerVectorShapeCast.cpp | 388 ++++++++-------
...vector-shape-cast-lowering-transforms.mlir | 456 +++++++++++++-----
2 files changed, 550 insertions(+), 294 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index d0085bffca23c..e84886e285ba9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -21,138 +21,103 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include <numeric>
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
-using namespace mlir::vector;
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
- int step = 1) {
- for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- int64_t dimSize = shape[dim];
- assert(indices[dim] < dimSize && "Indices are out of bound");
-
- indices[dim] += step;
-
- int64_t spill = indices[dim] / dimSize;
- if (spill == 0)
+/// Perform the inplace update
+/// rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with most significant
+/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is
+/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflow is not handled; the carry out of the most significant dimension
+/// is silently dropped:
+/// rhs = {4,2,1}, lhs = 1 --> rhs = {0,0,0} (the true sum, 30, is lost)
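+///
+/// A scalar model of the update (illustrative only, using base {5,3,2} so
+/// that the strides are 6, 2, 1): linearize, add, then delinearize:
+/// value = a*6 + b*2 + c + lhs;
+/// a = (value / 6) % 5; b = (value / 2) % 3; c = value % 2;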
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+ MutableArrayRef<int64_t> rhs) {
+
+ // For dimensions in [rhs.size() - 1, ..., 3, 2, 1, 0]:
+ for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+ int64_t dimBase = base[dim];
+ assert(rhs[dim] < dimBase && "rhs not in base");
+
+ int64_t incremented = rhs[dim] + lhs;
+
+ // If the incremented value equals or exceeds the dimension base, we must
+ // spill to the next most significant dimension and repeat (the carry may
+ // propagate across several dimensions).
+ lhs = incremented / dimBase;
+ rhs[dim] = incremented % dimBase;
+ if (lhs == 0)
break;
-
- indices[dim] %= dimSize;
- step = spill;
}
}
namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
-
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if (srcRank < 2 || resRank != 1)
- return failure();
-
- // Compute the number of 1-D vector elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t dim = 0; dim < srcRank - 1; ++dim)
- numElts *= sourceVectorType.getDimSize(dim);
-
- auto loc = op.getLoc();
- SmallVector<int64_t> srcIdx(srcRank - 1, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- int64_t extractSize = sourceVectorType.getShape().back();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-
- // Compute the indices of each 1-D vector element of the source extraction
- // and destination slice insertion and generate such instructions.
- for (int64_t i = 0; i < numElts; ++i) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
- incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
- }
-
- Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, extract, result,
- /*offsets=*/resIdx, /*strides=*/1);
- }
-
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
-/// vectors progressively. This iterates over the n-1 major dimension of the n-D
-/// vector and performs rewrites into:
-/// vector.extract_strided_slice from 1-D + vector.insert into n-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOpNDUpCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
-
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if (srcRank != 1 || resRank < 2)
- return failure();
-
- // Compute the number of 1-D vector elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t dim = 0; dim < resRank - 1; ++dim)
- numElts *= resultVectorType.getDimSize(dim);
-
- // Compute the indices of each 1-D vector element of the source slice
- // extraction and destination insertion and generate such instructions.
- auto loc = op.getLoc();
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank - 1, 0);
- int64_t extractSize = resultVectorType.getShape().back();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; ++i) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
- incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
- }
-
- Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
- /*strides=*/1);
- result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-// We typically should not lower general shape cast operations into data
-// movement instructions, since the assumption is that these casts are
-// optimized away during progressive lowering. For completeness, however,
-// we fall back to a reference implementation that moves all elements
-// into the right place if we get here.
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the first dimensions preceding the
+/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
+/// on vectors with shapes that are `multiples` of (what we define as) the
+/// 'atomic size', 2x7x11. The atomic size is `gcd` x `common-suffix`.
+///
+/// vector<2x2x3x4x7x11xi8> to
+/// vector<8x6x7x11xi8>
+/// ^^^^ ---> common suffix of 7x11
+/// ^ ---> gcd(4,6) is 2 | |
+/// | | |
+/// v v v
+/// atomic size <----- 2x7x11
+///
+///
+///
+/// The decomposition implemented in this patterns consists of a sequence of
+/// repeated steps:
+///
+/// (1) Extract vectors from the suffix of the source.
+/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
+///
+/// (2) Do extract_strided_slice down to the atomic size.
+/// In our example this is 4x7x11 -> 2x7x11.
+///
+/// (3) Do insert_strided_slice to the suffix of the result.
+/// In our example this is 2x7x11 -> 6x7x11.
+///
+/// (4) insert these vectors into the result vector.
+/// In our example this is 6x7x11 -> 8x6x7x11.
+///
+/// These steps occur with different periods. In this example
+/// (1) occurs 12 times,
+/// (2) and (3) occur 24 times, and
+/// (4) occurs 8 times.
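+/// Deriving these counts (illustrative arithmetic only): the atomic shape
+/// 2x7x11 has 154 elements and the vectors have 3696, so there are
+/// 3696/154 = 24 atomic slices; (1) fires every 4/2 = 2 slices and (4) fires
+/// every 6/2 = 3 slices.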
+///
+/// Two special cases are handled seperately:
+/// (1) A shape_cast that just does leading 1 insertion/removal
+/// (2) A shape_cast where the gcd is 1.
+///
+/// These 2 cases can have more compact IR generated by not using the generic
+/// algorithm described above.
+///
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -164,50 +129,149 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
VectorType resultType = op.getResultVectorType();
if (sourceType.isScalable() || resultType.isScalable())
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "shape_cast lowering not handled by this pattern");
+
+ const ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ const ArrayRef<int64_t> resultShape = resultType.getShape();
+ const int64_t sourceRank = sourceType.getRank();
+ const int64_t resultRank = resultType.getRank();
+ const int64_t numElms = sourceType.getNumElements();
+ const Value source = op.getSource();
+
+ // Set the first dimension (starting at the end) in the source and result
+ // respectively where the dimension sizes differ. Using the running example:
+ //
+ // dimensions: [0 1 2 3 4 5 ] [0 1 2 3 ]
+ // shapes: (2,2,3,4,7,11) -> (8,6,7,11)
+ // ^ ^
+ // | |
+ // sourceSuffixStartDim is 3 |
+ // |
+ // resultSuffixStartDim is 1
+ int64_t sourceSuffixStartDim = sourceRank - 1;
+ int64_t resultSuffixStartDim = resultRank - 1;
+ while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
+ (sourceType.getDimSize(sourceSuffixStartDim) ==
+ resultType.getDimSize(resultSuffixStartDim))) {
+ --sourceSuffixStartDim;
+ --resultSuffixStartDim;
+ }
- // Special case for n-D / 1-D lowerings with implementations that use
- // extract_strided_slice / insert_strided_slice.
- int64_t sourceRank = sourceType.getRank();
- int64_t resultRank = resultType.getRank();
- if ((sourceRank > 1 && resultRank == 1) ||
- (sourceRank == 1 && resultRank > 1))
- return failure();
+ // This is the case where there are just some leading ones to contend with
+ // in the source or result. It can be handled with a single extract/insert
+ // pair.
+ if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
+ const int64_t delta = sourceRank - resultRank;
+ const int64_t sourceLeading = delta > 0 ? delta : 0;
+ const int64_t resultLeading = delta > 0 ? 0 : -delta;
+ const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
+ const Value extracted = rewriter.create<vector::ExtractOp>(
+ loc, source, SmallVector<int64_t>(sourceLeading, 0));
+ const Value result = rewriter.create<vector::InsertOp>(
+ loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
+ rewriter.replaceOp(op, result);
+ return success();
+ }
- int64_t numExtracts = sourceType.getNumElements();
- int64_t nbCommonInnerDims = 0;
- while (true) {
- int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
- int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
- if (sourceDim < 0 || resultDim < 0)
- break;
- int64_t dimSize = sourceType.getDimSize(sourceDim);
- if (dimSize != resultType.getDimSize(resultDim))
- break;
- numExtracts /= dimSize;
- ++nbCommonInnerDims;
+ const int64_t sourceSuffixStartDimSize =
+ sourceType.getDimSize(sourceSuffixStartDim);
+ const int64_t resultSuffixStartDimSize =
+ resultType.getDimSize(resultSuffixStartDim);
+ const int64_t greatestCommonDivisor =
+ std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
+ const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
+ const size_t extractPeriod =
+ sourceSuffixStartDimSize / greatestCommonDivisor;
+ const size_t insertPeriod =
+ resultSuffixStartDimSize / greatestCommonDivisor;
+
+ SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
+ sourceShape.end());
+ atomicShape[0] = greatestCommonDivisor;
+
+ const int64_t numAtomicElms = std::accumulate(
+ atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
+ const size_t nAtomicSlices = numElms / numAtomicElms;
+
+ // This is the case where the strided dimension size is 1. More compact IR
+ // is generated in this case if we just extract and insert the elements
+ // directly. In other words, we don't use extract_strided_slice and
+ // insert_strided_slice.
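+ // For example, 3x2 -> 2x3 has gcd(2,3) = 1, and is lowered to 6 scalar
+ // extract/insert pairs (see shape_cast_2d2d in the lowering tests).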
+ if (greatestCommonDivisor == 1) {
+ sourceSuffixStartDim += 1;
+ resultSuffixStartDim += 1;
+ SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
+ SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ for (size_t i = 0; i < nAtomicSlices; ++i) {
+ Value extracted =
+ rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+
+ result = rewriter.create<vector::InsertOp>(loc, extracted, result,
+ insertIndex);
+
+ inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
+ extractIndex);
+ inplaceAdd(1, resultShape.take_front(resultSuffixStartDim),
+ insertIndex);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
}
- // Replace with data movement operations:
- // x[0,0,0] = y[0,0]
- // x[0,0,1] = y[0,1]
- // x[0,1,0] = y[0,2]
- // etc., incrementing the two index vectors "row-major"
- // within the source and result shape.
- SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
- SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+ // The insert_strided_slice result's type
+ const ArrayRef<int64_t> insertStridedShape =
+ resultShape.drop_front(resultSuffixStartDim);
+ const VectorType insertStridedType =
+ VectorType::get(insertStridedShape, resultType.getElementType());
+
+ SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
+ SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
+ SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
+ SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
+ const SmallVector<int64_t> sizes(stridedSliceRank, 1);
+
+ Value extracted = {};
+ Value extractedStrided = {};
+ Value insertedSlice = {};
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ const Value partResult =
+ rewriter.create<ub::PoisonOp>(loc, insertStridedType);
+
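+ // Using the running example: with 24 atomic slices, extractPeriod = 2 and
+ // insertPeriod = 3, a fresh vector.extract happens at i = 0, 2, 4, ... and
+ // a completed slice is inserted into the result at i = 2, 5, 8, ...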
+ for (size_t i = 0; i < nAtomicSlices; ++i) {
- for (int64_t i = 0; i < numExtracts; i++) {
- if (i != 0) {
- incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
- incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
+ const size_t extractStridedPhase = i % extractPeriod;
+ const size_t insertStridedPhase = i % insertPeriod;
+
+ // vector.extract
+ if (extractStridedPhase == 0) {
+ extracted =
+ rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+ inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
+ extractIndex);
}
- Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
- result =
- rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
+ // vector.extract_strided_slice
+ extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
+ extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, extracted, extractOffsets, atomicShape, sizes);
+
+ // vector.insert_strided_slice
+ if (insertStridedPhase == 0) {
+ insertedSlice = partResult;
+ }
+ insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
+ insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, extractedStrided, insertedSlice, insertOffsets, sizes);
+
+ // vector.insert
+ if (insertStridedPhase + 1 == insertPeriod) {
+ result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
+ insertIndex);
+ inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
+ insertIndex);
+ }
}
rewriter.replaceOp(op, result);
return success();
@@ -345,8 +409,8 @@ class ScalableShapeCastOpRewritePattern
// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
- incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
- incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
+ inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
+ inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
}
rewriter.replaceOp(op, result);
@@ -363,8 +427,6 @@ class ScalableShapeCastOpRewritePattern
void mlir::vector::populateVectorShapeCastLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<ShapeCastOpNDDownCastRewritePattern,
- ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
- ScalableShapeCastOpRewritePattern>(patterns.getContext(),
- benefit);
+ patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 2875f159a2df9..7b843750169b8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,197 +1,391 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-LABEL: func @nop_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK: return %[[A]] : vector<16xf32>
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @cancel_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK: return %[[A]] : vector<16xf32>
-
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
%1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
return %1 : vector<16xf32>
}
-// Shape up and downcasts for 2-D vectors, for supporting conversion to
-// llvm.matrix operations
-// CHECK-LABEL: func @shape_casts
-func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
- // CHECK-DAG: %[[ub22:.*]] = ub.poison : vector<2x2xf32>
- // CHECK-DAG: %[[ub:.*]] = ub.poison : vector<4xf32>
- // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
- //
- // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[ub]]
- // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
- //
- // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
- //
- // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
- // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
- //
+// Collapse 2-D to 1-D.
+// CHECK-LABEL: func @shape_cast_2d1d
+// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<4xf32>
+//
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UB]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//
+// CHECK: %[[EX1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[IN2:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[IN2]] : vector<4xf32>
+func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
%0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
- // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
- %r0 = arith.addf %0, %0: vector<4xf32>
- //
- // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
- // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
- // CHECK-SAME: vector<4xf32> to vector<2xf32>
- //
- // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[ub22]] [0] :
- // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
- //
- // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
- // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
- // CHECK-SAME: vector<4xf32> to vector<2xf32>
- //
- // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
- // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
- //
- %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32>
- // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
- return %r0, %1 : vector<4xf32>, vector<2x2xf32>
-}
-
-// CHECK-LABEL: func @shape_cast_2d2d
-// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
-// CHECK: return %[[T11]] : vector<2x3xf32>
-
-func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
- %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
- return %s : vector<2x3xf32>
+ return %0 : vector<4xf32>
}
+// Collapse 3-D to 1-D.
// CHECK-LABEL: func @shape_cast_3d1d
-// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
-// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
-// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
-// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: return %[[T5]] : vector<6xf32>
-
+// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
+//
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+//
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+//
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
return %s : vector<6xf32>
}
-// CHECK-LABEL: func @shape_cast_1d3d
-// CHECK-SAME: %[[A:.*]]: vector<6xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
-// CHECK: return %[[T3]] : vector<2x1x3xf32>
+// Expand 1-D to 2-D.
+// CHECK-LABEL: func.func @shape_cast_1d2d(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x2xf32>
+//
+// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
+// CHECK-SAME: vector<4xf32> to vector<2xf32>
+// CHECK: %[[res0:.*]] = vector.insert %[[SS0]], %[[UB]] [0] :
+// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+//
+// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
+// CHECK-SAME: vector<4xf32> to vector<2xf32>
+// CHECK: %[[res1:.*]] = vector.insert %[[SS2]], %[[res0]] [1] :
+// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+// CHECK: return %[[res1]] : vector<2x2xf32>
+func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
+ %1 = vector.shape_cast %a: vector<4xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+// Expand 1-D to 3-D.
+// CHECK-LABEL: func @shape_cast_1d3d
+// CHECK-SAME: %[[A:.*]]: vector<6xf32>
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
+// CHECK-SAME: vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
+// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+//
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
+// CHECK-SAME: vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] :
+// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_0d1d(
-// CHECK-SAME: %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
-// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
-// CHECK: return %[[RES]] : vector<1xf32>
-// CHECK: }
+// 2-D to 2-D where the inner-most dimensions have no common factors. This
+// case requires scalar element by element extraction and insertion.
+// CHECK-LABEL: func @shape_cast_2d2d
+// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
+//
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] :
+//
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] :
+//
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] :
+//
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] :
+//
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] :
+//
+// CHECK: return %[[T11]] : vector<2x3xf32>
+func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+// CHECK-LABEL: func.func @shape_cast_0d1d(
+// CHECK-SAME: %[[A:.*]]: vector<f32>) -> vector<1xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
+//
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][] : f32 from vector<f32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] :
+// CHECK: return %[[RES]] : vector<1xf32>
func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
%s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %s : vector<1xf32>
}
-// CHECK-LABEL: func.func @shape_cast_1d0d(
-// CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
-// CHECK: return %[[RES]] : vector<f32>
-// CHECK: }
-
+// CHECK-LABEL: func.func @shape_cast_1d0d(
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector<f32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
+//
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] :
+// CHECK: return %[[RES]] : vector<f32>
func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
%s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
return %s : vector<f32>
}
-
// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
-// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
-// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
-// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
-// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
-// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+//
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[A]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
-// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
-// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
-// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
-// CHECK-SAME: into vector<2x3xf32>
-// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
-// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
-// CHECK-SAME: into vector<2x3xf32>
-// CHECK: return %[[I1]] : vector<2x3xf32>
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] :
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] :
+// CHECK: return %[[I1]] : vector<2x3xf32>
func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
-// CHECK-LABEL: func.func @prepend_unit_dim(
-// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
-// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
-// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
-// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
+// CHECK-LABEL: func.func @prepend_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+//
+// CHECK: %[[I0:.*]] = vector.insert %[[A]], %[[UB]] [0]
+// CHECK: return %[[I0]] : vector<1x2x3xf32>
func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
return %s : vector<1x2x3xf32>
}
-// CHECK-LABEL: func.func @insert_middle_unit_dim(
-// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
-// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
-// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
-// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
-// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
-// CHECK: return %[[I1]] : vector<2x1x3xf32>
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK: return %[[I1]] : vector<2x1x3xf32>
func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}
+// CHECK-LABEL: func.func @postpend_unit_dims(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<4x1x1xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<4xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0, 0] : f32 into vector<4x1x1xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0, 0] : f32 into vector<4x1x1xf32>
+// CHECK: vector.extract %[[A]][2]
+// CHECK: vector.insert {{.*}} [2, 0, 0]
+// CHECK: vector.extract %[[A]][3]
+// CHECK: vector.insert {{.*}} [3, 0, 0]
+// CHECK: return
+func.func @postpend_unit_dims(%arg0 : vector<4xf32>) -> vector<4x1x1xf32> {
+ %s = vector.shape_cast %arg0 : vector<4xf32> to vector<4x1x1xf32>
+ return %s : vector<4x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @expand_inner_dims(
+// CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x2x5xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<10xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[E0]]
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[S0]], %[[UB]] [0, 0]
+//
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[E0]]
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [0, 1]
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<10xf32>
+// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[E1]]
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [1, 0]
+//
+// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[E1]]
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [1, 1]
+// CHECK: return %[[I3]] : vector<2x2x5xf32>
+func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x10xf32> to vector<2x2x5xf32>
+ return %s : vector<2x2x5xf32>
+}
+
+
+// Some pseudocode describing how this function is lowered:
+//
+// func collapse_inner_dims(A : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// v0 = empty of shape (10)
+// v1 = empty of shape (1,2,1,10)
+// v0[0:5] = A[0,0,:]
+// v0[5:10] = A[0,1,:]
+// v1[0,0,0,:] = v0
+// v0[0:5] = A[1,0,:]
+// v0[5:10] = A[1,1,:]
+// v1[0,1,0,:] = v0
+// return v1;
+// }
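+//
+// In terms of the pattern's bookkeeping (illustrative arithmetic): the common
+// suffix is empty and gcd(5,10) = 5, so the atomic shape is 5; extractPeriod
+// is 5/5 = 1 (one extract per atomic slice) and insertPeriod is 10/5 = 2
+// (two strided inserts fill each row of 10).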
+// CHECK-LABEL: func.func @collapse_inner_dims(
+// CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// CHECK-DAG: %[[UBSMALL:.*]] = ub.poison : vector<10xi8>
+// CHECK-DAG: %[[UBLARGE:.*]] = ub.poison : vector<1x2x1x10xi8>
+//
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0, 0]
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UBSMALL]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX1:.*]] = vector.extract %[[A]][0, 1]
+// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UBLARGE]] [0, 0, 0]
+//
+// CHECK: %[[EX2:.*]] = vector.extract %[[A]][1, 0]
+// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[EX2]], %[[UBSMALL]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX3:.*]] = vector.extract %[[A]][1, 1]
+// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[EX3]], %[[IN3]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [0, 1, 0]
+// CHECK: return %[[IN5]] : vector<1x2x1x10xi8>
+func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+ %s = vector.shape_cast %arg0 : vector<2x2x5xi8> to vector<1x2x1x10xi8>
+ return %s : vector<1x2x1x10xi8>
+}
+
+// Some alternative pseudocode describing how this function is lowered:
+//
+// func non_dividing_gcd_decreasing(A : vector<2x15xi8>) -> vector<3x10xi8> {
+// v0 = empty of shape (10)
+// v1 = empty of shape (3,10)
+// e0 = A[0,:]
+// v0[0:5] = e0[0:5]
+// v0[5:10] = e0[5:10]
+// v1[0,:] = v0
+// v0[0:5] = e0[10:15]
+// e1 = A[1,:]
+// v0[5:10] = e1[0:5]
+// v1[1,:] = v0
+// v0[0:5] = e1[5:10]
+// v0[5:10] = e1[10:15]
+// v1[2,:] = v0
+// return v1;
+// }
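+//
+// In terms of the pattern's bookkeeping (illustrative arithmetic): the atomic
+// shape is gcd(15,10) = 5, giving 30/5 = 6 atomic slices; extractPeriod is
+// 15/5 = 3 (2 extracts) and insertPeriod is 10/5 = 2 (3 inserts), which is
+// why the extracts and inserts above interleave as they do.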
+// CHECK-LABEL: func.func @non_dividing_gcd_decreasing(
+// CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
+// CHECK-DAG: %[[UB0:.*]] = ub.poison : vector<10xi8>
+// CHECK-DAG: %[[UB1:.*]] = ub.poison : vector<3x10xi8>
+//
+// First 10 elements:
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<15xi8> from vector<2x15xi8>
+// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[SS0]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[SS1:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[SS1]], %[[IN0]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UB1]] [0] : vector<10xi8> into vector<3x10xi8>
+//
+// Next 10 elements:
+// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[SS2]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX1:.*]] = vector.extract %[[A]][1] : vector<15xi8> from vector<2x15xi8>
+// CHECK: %[[SS3:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[SS3]], %[[IN3]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [1] : vector<10xi8> into vector<3x10xi8>
+//
+// Final 10 elements:
+// CHECK: %[[SS4:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK: %[[IN6:.*]] = vector.insert_strided_slice %[[SS4]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[SS5:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK: %[[IN7:.*]] = vector.insert_strided_slice %[[SS5]], %[[IN6]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN8:.*]] = vector.insert %[[IN7]], %[[IN5]] [2] : vector<10xi8> into vector<3x10xi8>
+// CHECK: return %[[IN8]] : vector<3x10xi8>
+func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi8> {
+ %0 = vector.shape_cast %arg0 : vector<2x15xi8> to vector<3x10xi8>
+ return %0 : vector<3x10xi8>
+}
+
+// CHECK-LABEL: func.func @non_dividing_gcd_increasing(
+// CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
+//
+// CHECK-DAG: ub.poison : vector<15xi8>
+// CHECK-DAG: ub.poison : vector<2x15xi8>
+//
+// Collect the first 15 elements, and insert into the first row of the result.
+// CHECK: vector.extract %[[A]][0]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.extract %[[A]][1]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.insert {{.*}} [0] : vector<15xi8> into vector<2x15xi8>
+//
+// Collect the next 15 elements, and insert into the second row of the result.
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.extract %[[A]][2]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.insert {{.*}} [1] : vector<15xi8> into vector<2x15xi8>
+func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi8> {
+ %0 = vector.shape_cast %arg0 : vector<3x10xi8> to vector<2x15xi8>
+ return %0 : vector<2x15xi8>
+}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
>From 5ad484e1a37ea21dd0c16e16e33686693c7f2991 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 29 May 2025 09:42:47 -0700
Subject: [PATCH 4/5] cosmetic fixes and better failure notifications
---
.../Vector/Transforms/LowerVectorShapeCast.cpp | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index e84886e285ba9..79048c5aab67a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -79,7 +79,7 @@ namespace {
/// The greatest common divisor (gcd) of the first dimensions preceding the
/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
/// on vectors with shapes that are `multiples` of (what we define as) the
-/// 'atomic size', 2x7x11. The atomic size is `gcd` x `common-suffix`.
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
///
/// vector<2x2x3x4x7x11xi8> to
/// vector<8x6x7x11xi8>
@@ -87,17 +87,17 @@ namespace {
/// ^ ---> gcd(4,6) is 2 | |
/// | | |
/// v v v
-/// atomic size <----- 2x7x11
+/// atomic shape <----- 2x7x11
///
///
///
-/// The decomposition implemented in this patterns consists of a sequence of
+/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
/// (1) Extract vectors from the suffix of the source.
/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
-/// (2) Do extract_strided_slice down to the atomic size.
+/// (2) Do extract_strided_slice down to the atomic shape.
/// In our example this is 4x7x11 -> 2x7x11.
///
/// (3) Do insert_strided_slice to the suffix of the result.
@@ -130,7 +130,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
if (sourceType.isScalable() || resultType.isScalable())
return rewriter.notifyMatchFailure(
- op, "shape_cast lowering not handled by this pattern");
+ op,
+ "shape_cast where vectors are scalable not handled by this pattern");
const ArrayRef<int64_t> sourceShape = sourceType.getShape();
const ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -332,7 +333,8 @@ class ScalableShapeCastOpRewritePattern
// from >= 2-D scalable vectors or scalable vectors of fixed vectors.
if (!isTrailingDimScalable(sourceVectorType) ||
!isTrailingDimScalable(resultVectorType)) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "trailing dims are not scalable, not handled by this pattern");
}
// The sizes of the trailing dimension of the source and result vectors, the
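For context on the change above, here is a minimal sketch of the notifyMatchFailure idiom the patch adopts. This fragment assumes an MLIR source tree (it does not compile standalone), and the pattern body is illustrative, not the patch's actual logic.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
namespace {
struct ExamplePattern : mlir::OpRewritePattern<mlir::vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;
  mlir::LogicalResult
  matchAndRewrite(mlir::vector::ShapeCastOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Returning notifyMatchFailure instead of failure() records *why* the
    // pattern bailed, which debug tooling can then surface.
    if (op.getSourceVectorType().isScalable())
      return rewriter.notifyMatchFailure(
          op, "scalable vectors not handled by this pattern");
    // ... the actual rewrite would go here ...
    return mlir::failure();
  }
};
} // namespace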
From 125573b1f889aa791e77935ca1ed0ffa634307be Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 14:32:24 -0700
Subject: [PATCH 5/5] cosmetic improvements and comment clarification
---
.../Transforms/LowerVectorShapeCast.cpp | 21 ++--
...vector-shape-cast-lowering-transforms.mlir | 112 +++++++++---------
2 files changed, 67 insertions(+), 66 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 79048c5aab67a..b10cfa9932464 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -83,8 +83,9 @@ namespace {
///
/// vector<2x2x3x4x7x11xi8> to
/// vector<8x6x7x11xi8>
-/// ^^^^ ---> common suffix of 7x11
-/// ^ ---> gcd(4,6) is 2 | |
+/// | ||||
+/// | ++++------------> common suffix of 7x11
+/// +-----------------> gcd(4,6) is 2 | |
/// | | |
/// v v v
/// atomic shape <----- 2x7x11
@@ -111,9 +112,9 @@ namespace {
/// (2) and (3) occur 24 times, and
/// (4) occurs 8 times.
///
-/// Two special cases are handled seperately:
-/// (1) A shape_cast that just does leading 1 insertion/removal
-/// (2) A shape_cast where the gcd is 1.
+/// Two special cases are handled independently in this pattern:
+/// (i) A shape_cast that just does leading 1 insertion/removal
+/// (ii) A shape_cast where the gcd is 1.
///
/// These 2 cases can have more compact IR generated by not using the generic
/// algorithm described above.
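As a cross-check of the worked example in this comment block, a standalone sketch (not part of the patch) that reproduces the counts: 24 atomic slices for steps (2) and (3), and 8 inserts for step (4).
#include <cstdio>
#include <numeric>
#include <vector>
int main() {
  std::vector<long> src = {2, 2, 3, 4, 7, 11}; // vector<2x2x3x4x7x11xi8>
  std::vector<long> res = {8, 6, 7, 11};       // vector<8x6x7x11xi8>
  // Strip the common suffix (7x11 here).
  size_t s = src.size(), r = res.size();
  while (s > 0 && r > 0 && src[s - 1] == res[r - 1]) {
    --s;
    --r;
  }
  long suffix = 1;
  for (size_t i = s; i < src.size(); ++i)
    suffix *= src[i]; // 7 * 11 = 77
  long gcd = std::gcd(src[s - 1], res[r - 1]); // gcd(4, 6) = 2
  long numElms = 1;
  for (long d : src)
    numElms *= d; // 3696
  std::printf("steps (2),(3): %ld times\n", numElms / (gcd * suffix));    // 24
  std::printf("step (4): %ld times\n", numElms / (res[r - 1] * suffix));  // 8
  return 0;
}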
@@ -159,9 +160,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
--resultSuffixStartDim;
}
- // This is the case where there are just some leading ones to contend with
- // in the source or result. It can be handled with a single extract/insert
- // pair.
+ // This is the case (i) where there are just some leading ones to contend
+ // with in the source or result. It can be handled with a single
+ // extract/insert pair.
if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
const int64_t delta = sourceRank - resultRank;
const int64_t sourceLeading = delta > 0 ? delta : 0;
@@ -195,8 +196,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
const size_t nAtomicSlices = numElms / numAtomicElms;
- // This is the case where the strided dimension size is 1. More compact IR
- // is generated in this case if we just extract and insert the elements
+ // This is the case (ii) where the strided dimension size is 1. More compact
+ // IR is generated in this case if we just extract and insert the elements
// directly. In other words, we don't use extract_strided_slice and
// insert_strided_slice.
if (greatestCommonDivisor == 1) {
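A standalone sketch (not part of the patch) of the slice-count arithmetic visible in the hunk above, evaluated for vector<2x15xi8> to vector<3x10xi8>; the variable names echo the pattern's, but the program is illustrative only.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>
int main() {
  std::vector<int64_t> atomicShape = {5}; // gcd(15, 10)
  const int64_t numElms = 2 * 15;         // elements in either vector
  const int64_t numAtomicElms =
      std::accumulate(atomicShape.begin(), atomicShape.end(), int64_t(1),
                      std::multiplies<int64_t>());
  const size_t nAtomicSlices = numElms / numAtomicElms;
  // 6 slices of 5 elements; since gcd(15, 10) = 5 > 1, the strided-slice
  // path (not the scalar path guarded by the `if` above) is taken.
  std::printf("%zu atomic slices of %lld elements\n", nAtomicSlices,
              static_cast<long long>(numAtomicElms));
  return 0;
}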
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 7b843750169b8..5011d8b2b2ef6 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-LABEL: func @nop_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
@@ -9,7 +9,7 @@ func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
}
// CHECK-LABEL: func @cancel_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
@@ -19,16 +19,16 @@ func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
// Collapse 2-D to 1-D.
// CHECK-LABEL: func @shape_cast_2d1d
-// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<4xf32>
//
// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UB]]
-// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//
// CHECK: %[[EX1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[IN2:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
-// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[IN2]] : vector<4xf32>
func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
%0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
@@ -37,20 +37,20 @@ func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
// Collapse 3-D to 1-D.
// CHECK-LABEL: func @shape_cast_3d1d
-// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
//
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
-// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
//
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
-// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
//
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
-// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: return %[[T5]] : vector<6xf32>
func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -59,20 +59,20 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
// Expand 1-D to 2-D.
// CHECK-LABEL: func.func @shape_cast_1d2d(
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x2xf32>
//
// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
-// CHECK-SAME: vector<4xf32> to vector<2xf32>
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
+// CHECK-SAME: vector<4xf32> to vector<2xf32>
// CHECK: %[[res0:.*]] = vector.insert %[[SS0]], %[[UB]] [0] :
-// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
//
// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
-// CHECK-SAME: vector<4xf32> to vector<2xf32>
+// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
+// CHECK-SAME: vector<4xf32> to vector<2xf32>
// CHECK: %[[res1:.*]] = vector.insert %[[SS2]], %[[res0]] [1] :
-// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
// CHECK: return %[[res1]] : vector<2x2xf32>
func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
%1 = vector.shape_cast %a: vector<4xf32> to vector<2x2xf32>
@@ -81,20 +81,20 @@ func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
// Expand 1-D to 3-D.
// CHECK-LABEL: func @shape_cast_1d3d
-// CHECK-SAME: %[[A:.*]]: vector<6xf32>
+// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
//
// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
-// CHECK-SAME: vector<6xf32> to vector<3xf32>
+// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
+// CHECK-SAME: vector<6xf32> to vector<3xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
-// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
//
// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
-// CHECK-SAME: vector<6xf32> to vector<3xf32>
+// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
+// CHECK-SAME: vector<6xf32> to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] :
-// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
// CHECK: return %[[T3]] : vector<2x1x3xf32>
func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
@@ -104,7 +104,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
// 2-D to 2-D where the inner-most dimensions have no common factors. This
// case requires scalar element by element extraction and insertion.
// CHECK-LABEL: func @shape_cast_2d2d
-// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
//
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
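A standalone sketch (not part of the patch) of the scalar schedule for this test: gcd(2, 3) = 1, so all 6 elements move individually. The printed index pairs follow the pattern of this test's CHECK lines (most of which are elided in the diff context here).
#include <cstdio>
int main() {
  const int srcCols = 2, dstCols = 3, numElms = 6;
  for (int i = 0; i < numElms; ++i)
    std::printf("extract [%d, %d] -> insert [%d, %d]\n",
                i / srcCols, i % srcCols, i / dstCols, i % dstCols);
  return 0;
}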
@@ -132,7 +132,7 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
}
// CHECK-LABEL: func.func @shape_cast_0d1d(
-// CHECK-SAME: %[[A:.*]]: vector<f32>) -> vector<1xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<f32>) -> vector<1xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
//
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][] : f32 from vector<f32>
@@ -144,7 +144,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
}
// CHECK-LABEL: func.func @shape_cast_1d0d(
-// CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector<f32> {
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector<f32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
//
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
@@ -157,10 +157,10 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
-// CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
//
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[A]][0] :
-// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
@@ -169,7 +169,7 @@ func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3x
// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
-// CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
//
// CHECK: %[[E0:.*]] = vector.extract %[[A]][0, 0] : vector<3xf32>
@@ -184,7 +184,7 @@ func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3x
}
// CHECK-LABEL: func.func @prepend_unit_dim(
-// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
//
// CHECK: %[[I0:.*]] = vector.insert %[[A]], %[[UB]] [0]
@@ -195,7 +195,7 @@ func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
}
// CHECK-LABEL: func.func @insert_middle_unit_dim(
-// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
//
// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<3xf32>
@@ -210,7 +210,7 @@ func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32>
}
// CHECK-LABEL: func.func @postpend_unit_dims(
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<4x1x1xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<4xf32>
// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0, 0] : f32 into vector<4x1x1xf32>
@@ -227,25 +227,25 @@ func.func @postpend_unit_dims(%arg0 : vector<4xf32>) -> vector<4x1x1xf32> {
}
// CHECK-LABEL: func.func @expand_inner_dims(
-// CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
+// CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x2x5xf32>
//
// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<10xf32>
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[E0]]
-// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
// CHECK: %[[I0:.*]] = vector.insert %[[S0]], %[[UB]] [0, 0]
//
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[E0]]
-// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
// CHECK: %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [0, 1]
//
// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<10xf32>
// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[E1]]
-// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
// CHECK: %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [1, 0]
//
// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[E1]]
-// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
// CHECK: %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [1, 1]
// CHECK: return %[[I3]] : vector<2x2x5xf32>
func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
@@ -268,24 +268,24 @@ func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
// return v1;
// }
// CHECK-LABEL: func.func @collapse_inner_dims(
-// CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
// CHECK-DAG: %[[UBSMALL:.*]] = ub.poison : vector<10xi8>
// CHECK-DAG: %[[UBLARGE:.*]] = ub.poison : vector<1x2x1x10xi8>
//
// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0, 0]
// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UBSMALL]]
-// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK-SAME: {offsets = [0], {{.*}}
// CHECK: %[[EX1:.*]] = vector.extract %[[A]][0, 1]
// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
-// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK-SAME: {offsets = [5], {{.*}}
// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UBLARGE]] [0, 0, 0]
//
// CHECK: %[[EX2:.*]] = vector.extract %[[A]][1, 0]
// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[EX2]], %[[UBSMALL]]
-// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK-SAME: {offsets = [0], {{.*}}
// CHECK: %[[EX3:.*]] = vector.extract %[[A]][1, 1]
// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[EX3]], %[[IN3]]
-// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK-SAME: {offsets = [5], {{.*}}
// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [0, 1, 0]
// CHECK: return %[[IN5]] : vector<1x2x1x10xi8>
func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
@@ -312,43 +312,43 @@ func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8>
// return v1;
// }
// CHECK-LABEL: func.func @non_dividing_gcd_decreasing(
-// CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
+// CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
// CHECK-DAG: %[[UB0:.*]] = ub.poison : vector<10xi8>
// CHECK-DAG: %[[UB1:.*]] = ub.poison : vector<3x10xi8>
//
// First 10 elements:
// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<15xi8> from vector<2x15xi8>
// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[EX0]]
-// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[SS0]], %[[UB0]]
-// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK-SAME: {offsets = [0], {{.*}}
// CHECK: %[[SS1:.*]] = vector.extract_strided_slice %[[EX0]]
-// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[SS1]], %[[IN0]]
-// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK-SAME: {offsets = [5], {{.*}}
// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UB1]] [0] : vector<10xi8> into vector<3x10xi8>
//
// Next 10 elements:
// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[EX0]]
-// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[SS2]], %[[UB0]]
-// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK-SAME: {offsets = [0], {{.*}}
// CHECK: %[[EX1:.*]] = vector.extract %[[A]][1] : vector<15xi8> from vector<2x15xi8>
// CHECK: %[[SS3:.*]] = vector.extract_strided_slice %[[EX1]]
-// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[SS3]], %[[IN3]]
-// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK-SAME: {offsets = [5], {{.*}}
// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [1] : vector<10xi8> into vector<3x10xi8>
//
// Final 10 elements:
// CHECK: %[[SS4:.*]] = vector.extract_strided_slice %[[EX1]]
-// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
// CHECK: %[[IN6:.*]] = vector.insert_strided_slice %[[SS4]], %[[UB0]]
-// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK-SAME: {offsets = [0], {{.*}}
// CHECK: %[[SS5:.*]] = vector.extract_strided_slice %[[EX1]]
-// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
// CHECK: %[[IN7:.*]] = vector.insert_strided_slice %[[SS5]], %[[IN6]]
-// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK-SAME: {offsets = [5], {{.*}}
// CHECK: %[[IN8:.*]] = vector.insert %[[IN7]], %[[IN5]] [2] : vector<10xi8> into vector<3x10xi8>
// CHECK: return %[[IN8]] : vector<3x10xi8>
func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi8> {
@@ -357,7 +357,7 @@ func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi
}
// CHECK-LABEL: func.func @non_dividing_gcd_increasing(
-// CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
+// CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
//
// CHECK-DAG: ub.poison : vector<15xi8>
// CHECK-DAG: ub.poison : vector<2x15xi8>