[Mlir-commits] [mlir] db6f8eb - [mlir][Vector] Support 0-D vectors in ShuffleOp
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Aug 29 00:40:06 PDT 2022
Author: Nicolas Vasilache
Date: 2022-08-29T00:39:57-07:00
New Revision: db6f8ebe066f0be13f94418f090c626a228225c4
URL: https://github.com/llvm/llvm-project/commit/db6f8ebe066f0be13f94418f090c626a228225c4
DIFF: https://github.com/llvm/llvm-project/commit/db6f8ebe066f0be13f94418f090c626a228225c4.diff
LOG: [mlir][Vector] Support 0-D vectors in ShuffleOp
Co-authored-by: Michal Terepeta <michalt at google.com>
Reviewed-by: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D115744
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b7783c601a92b..d2fe879dcd2b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -437,22 +437,25 @@ def Vector_ShuffleOp :
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
- Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>,
+ Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
+ I64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
The shuffle operation constructs a permutation (or duplication) of elements
from two input vectors, returning a vector with the same element type as
the input and a length that is the same as the shuffle mask. The two input
- vectors must have the same element type, rank, and trailing dimension sizes
- and shuffles their values in the leading dimension (which may differ in size)
- according to the given mask. The legality rules are:
+ vectors must have the same element type, same rank , and trailing dimension
+ sizes and shuffles their values in the
+ leading dimension (which may differ in size) according to the given mask.
+ The legality rules are:
* the two operands must have the same element type as the result
- * the two operands and the result must have the same rank and trailing
- dimension sizes, viz. given two k-D operands
- v1 : <s_1 x s_2 x .. x s_k x type> and
- v2 : <t_1 x t_2 x .. x t_k x type>
- we have s_i = t_i for all 1 < i <= k
+ - Either, the two operands and the result must have the same
+ rank and trailing dimension sizes, viz. given two k-D operands
+ v1 : <s_1 x s_2 x .. x s_k x type> and
+ v2 : <t_1 x t_2 x .. x t_k x type>
+ we have s_i = t_i for all 1 < i <= k
+ - Or, the two operands must be 0-D vectors and the result is a 1-D vector.
* the mask length equals the leading dimension size of the result
* numbering the input vector indices left to right across the operands, all
mask values must be within range, viz. given two k-D operands v1 and v2
@@ -467,12 +470,15 @@ def Vector_ShuffleOp :
: vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
%2 = vector.shuffle %a, %b[3, 2, 1, 0]
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
+ %3 = vector.shuffle %a, %b[0, 1]
+ : vector<f32>, vector<f32> ; yields vector<2xf32>
```
}];
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let extraClassDeclaration = [{
static StringRef getMaskAttrStrName() { return "mask"; }
VectorType getV1VectorType() {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8804f971ad5d2..ce1168dbcfb33 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -595,13 +595,15 @@ class VectorShuffleOpConversion
// Get rank and dimension sizes.
int64_t rank = vectorType.getRank();
- assert(v1Type.getRank() == rank);
- assert(v2Type.getRank() == rank);
- int64_t v1Dim = v1Type.getDimSize(0);
-
- // For rank 1, where both operands have *exactly* the same vector type,
- // there is direct shuffle support in LLVM. Use it!
- if (rank == 1 && v1Type == v2Type) {
+ bool wellFormed0DCase =
+ v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
+ bool wellFormedNDCase =
+ v1Type.getRank() == rank && v2Type.getRank() == rank;
+ assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
+
+ // For rank 0 and 1, where both operands have *exactly* the same vector
+ // type, there is direct shuffle support in LLVM. Use it!
+ if (rank <= 1 && v1Type == v2Type) {
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.getV1(), adaptor.getV2(),
LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
@@ -610,6 +612,7 @@ class VectorShuffleOpConversion
}
// For all other cases, insert the individual values individually.
+ int64_t v1Dim = v1Type.getDimSize(0);
Type eltType;
if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
eltType = arrayType.getElementType();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ef37005ddc913..3d3b872d61edd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1794,8 +1794,11 @@ LogicalResult ShuffleOp::verify() {
int64_t resRank = resultType.getRank();
int64_t v1Rank = v1Type.getRank();
int64_t v2Rank = v2Type.getRank();
- if (resRank != v1Rank || v1Rank != v2Rank)
+ bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
+ bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
+ if (!wellFormed0DCase && !wellFormedNDCase)
return emitOpError("rank mismatch");
+
// Verify all but leading dimension sizes.
for (int64_t r = 1; r < v1Rank; ++r) {
int64_t resDim = resultType.getDimSize(r);
@@ -1812,7 +1815,8 @@ LogicalResult ShuffleOp::verify() {
if (maskLength != resultType.getDimSize(0))
return emitOpError("mask length mismatch");
// Verify all indices.
- int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
+ int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
+ (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (const auto &en : llvm::enumerate(maskAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
@@ -1828,12 +1832,15 @@ ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes);
auto v1Type = op.getV1().getType().cast<VectorType>();
- // Construct resulting type: leading dimension matches mask length,
- // all trailing dimensions match the operands.
+ auto v1Rank = v1Type.getRank();
+ // Construct resulting type: leading dimension matches mask
+ // length, all trailing dimensions match the operands.
SmallVector<int64_t, 4> shape;
- shape.reserve(v1Type.getRank());
+ shape.reserve(v1Rank);
shape.push_back(std::max<size_t>(1, op.getMask().size()));
- llvm::append_range(shape, v1Type.getShape().drop_front());
+ // In the 0-D case there is no trailing shape to append.
+ if (v1Rank > 0)
+ llvm::append_range(shape, v1Type.getShape().drop_front());
inferredReturnTypes.push_back(
VectorType::get(shape, v1Type.getElementType()));
return success();
@@ -1849,9 +1856,15 @@ static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
}
OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+ VectorType v1Type = getV1VectorType();
+ // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
+ // but must be a canonicalization into a vector.broadcast.
+ if (v1Type.getRank() == 0)
+ return {};
+
// fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
- if (!getV1VectorType().isScalable() &&
- isStepIndexArray(getMask(), 0, getV1VectorType().getDimSize(0)))
+ if (!v1Type.isScalable() &&
+ isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
return getV1();
// fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
@@ -1887,6 +1900,30 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
namespace {
+// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
+// to a broadcast.
+struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
+ using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
+ PatternRewriter &rewriter) const override {
+ VectorType v1VectorType = shuffleOp.getV1VectorType();
+ ArrayAttr mask = shuffleOp.getMask();
+ if (v1VectorType.getRank() > 0)
+ return failure();
+ if (mask.size() != 1)
+ return failure();
+ Type resType = VectorType::Builder(v1VectorType).setShape({1});
+ if (mask[0].cast<IntegerAttr>().getInt() == 0)
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
+ shuffleOp.getV1());
+ else
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
+ shuffleOp.getV2());
+ return success();
+ }
+};
+
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
@@ -1912,7 +1949,7 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShuffleSplat>(context);
+ results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 828fc22f18346..ed4d398780e16 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -416,6 +416,19 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
// CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
// CHECK: return %[[T19]] : vector<2x3xf32>
+
+// -----
+
+func.func @shuffle_0D_direct(%arg0: vector<f32>) -> vector<3xf32> {
+ %1 = vector.shuffle %arg0, %arg0 [0, 1, 0] : vector<f32>, vector<f32>
+ return %1 : vector<3xf32>
+}
+// CHECK-LABEL: @shuffle_0D_direct(
+// CHECK-SAME: %[[A:.*]]: vector<f32>
+// CHECK: %[[c:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: %[[s:.*]] = llvm.shufflevector %[[c]], %[[c]] [0, 1, 0] : vector<1xf32>
+// CHECK: return %[[s]] : vector<3xf32>
+
// -----
func.func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index be86fe91c17c1..6fe6c2776f563 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1487,6 +1487,13 @@ func.func @shuffle_1d() -> vector<4xi32> {
return %shuffle : vector<4xi32>
}
+// CHECK-LABEL: func @shuffle_canonicalize_0d
+func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
+ // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
+ return %shuffle : vector<1xi32>
+}
+
// CHECK-LABEL: func @shuffle_fold1
// CHECK: %arg0 : vector<4xi32>
func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index fa2516466ad50..3ebf6dffd0ec3 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -56,6 +56,13 @@ func.func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) {
%1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<4x2xf32>
}
+// -----
+
+func.func @shuffle_rank_mismatch_0d(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+ // expected-error@+1 {{'vector.shuffle' op rank mismatch}}
+ %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f32>, vector<1xf32>
+}
+
// -----
func.func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index e4e260a37bb13..e026965eae648 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -165,6 +165,13 @@ func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: ve
return %4 : vector<8x16xf32>
}
+// CHECK-LABEL: @shuffle0D
+func.func @shuffle0D(%a: vector<f32>) -> vector<3xf32> {
+ // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 0] : vector<f32>, vector<f32>
+ %1 = vector.shuffle %a, %a[0, 1, 0] : vector<f32>, vector<f32>
+ return %1 : vector<3xf32>
+}
+
// CHECK-LABEL: @shuffle1D
func.func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> {
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 2, 3] : vector<2xf32>, vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 8a100dcc91b35..4065d7b865fa0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -127,6 +127,13 @@ func.func @transpose_0d(%arg: vector<i32>) {
return
}
+func.func @shuffle_0d(%v0: vector<i32>, %v1: vector<i32>) {
+ %1 = vector.shuffle %v0, %v1 [0, 1, 0] : vector<i32>, vector<i32>
+ // CHECK: ( 42, 43, 42 )
+ vector.print %1: vector<3xi32>
+ return
+}
+
func.func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -159,7 +166,9 @@ func.func @entry() {
%5 = arith.constant dense<4.0> : vector<f32>
call @fma_0d(%5) : (vector<f32>) -> ()
%6 = arith.constant dense<42> : vector<i32>
+ %7 = arith.constant dense<43> : vector<i32>
call @transpose_0d(%6) : (vector<i32>) -> ()
+ call @shuffle_0d(%6, %7) : (vector<i32>, vector<i32>) -> ()
return
}
More information about the Mlir-commits
mailing list