[Mlir-commits] [mlir] db6f8eb - [mlir][Vector] Support 0-D vectors in ShuffleOp

Nicolas Vasilache llvmlistbot at llvm.org
Mon Aug 29 00:40:06 PDT 2022


Author: Nicolas Vasilache
Date: 2022-08-29T00:39:57-07:00
New Revision: db6f8ebe066f0be13f94418f090c626a228225c4

URL: https://github.com/llvm/llvm-project/commit/db6f8ebe066f0be13f94418f090c626a228225c4
DIFF: https://github.com/llvm/llvm-project/commit/db6f8ebe066f0be13f94418f090c626a228225c4.diff

LOG: [mlir][Vector] Support 0-D vectors in ShuffleOp

Co-authored-by: Michal Terepeta <michalt at google.com>

Reviewed-by: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115744

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b7783c601a92b..d2fe879dcd2b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -437,22 +437,25 @@ def Vector_ShuffleOp :
      PredOpTrait<"second operand v2 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 1>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
-     Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>,
+     Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
+                    I64ArrayAttr:$mask)>,
      Results<(outs AnyVector:$vector)> {
   let summary = "shuffle operation";
   let description = [{
     The shuffle operation constructs a permutation (or duplication) of elements
     from two input vectors, returning a vector with the same element type as
     the input and a length that is the same as the shuffle mask. The two input
-    vectors must have the same element type, rank, and trailing dimension sizes
-    and shuffles their values in the leading dimension (which may differ in size)
-    according to the given mask. The legality rules are:
+    vectors must have the same element type, same rank, and trailing dimension
+    sizes and shuffles their values in the
+    leading dimension (which may differ in size) according to the given mask.
+    The legality rules are:
     * the two operands must have the same element type as the result
-    * the two operands and the result must have the same rank and trailing
-      dimension sizes, viz. given two k-D operands
-              v1 : <s_1 x s_2 x .. x s_k x type> and
-              v2 : <t_1 x t_2 x .. x t_k x type>
-      we have s_i = t_i for all 1 < i <= k
+      - Either, the two operands and the result must have the same
+        rank and trailing dimension sizes, viz. given two k-D operands
+                v1 : <s_1 x s_2 x .. x s_k x type> and
+                v2 : <t_1 x t_2 x .. x t_k x type>
+        we have s_i = t_i for all 1 < i <= k
+      - Or, the two operands must be 0-D vectors and the result is a 1-D vector.
     * the mask length equals the leading dimension size of the result
     * numbering the input vector indices left to right across the operands, all
       mask values must be within range, viz. given two k-D operands v1 and v2
@@ -467,12 +470,15 @@ def Vector_ShuffleOp :
                : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
     %2 = vector.shuffle %a, %b[3, 2, 1, 0]
                : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
+    %3 = vector.shuffle %a, %b[0, 1]
+               : vector<f32>, vector<f32>           ; yields vector<2xf32>
     ```
   }];
   let builders = [
     OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
   ];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let extraClassDeclaration = [{
     static StringRef getMaskAttrStrName() { return "mask"; }
     VectorType getV1VectorType() {

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8804f971ad5d2..ce1168dbcfb33 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -595,13 +595,15 @@ class VectorShuffleOpConversion
 
     // Get rank and dimension sizes.
     int64_t rank = vectorType.getRank();
-    assert(v1Type.getRank() == rank);
-    assert(v2Type.getRank() == rank);
-    int64_t v1Dim = v1Type.getDimSize(0);
-
-    // For rank 1, where both operands have *exactly* the same vector type,
-    // there is direct shuffle support in LLVM. Use it!
-    if (rank == 1 && v1Type == v2Type) {
+    bool wellFormed0DCase =
+        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
+    bool wellFormedNDCase =
+        v1Type.getRank() == rank && v2Type.getRank() == rank;
+    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
+
+    // For rank 0 and 1, where both operands have *exactly* the same vector
+    // type, there is direct shuffle support in LLVM. Use it!
+    if (rank <= 1 && v1Type == v2Type) {
       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
           loc, adaptor.getV1(), adaptor.getV2(),
           LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
@@ -610,6 +612,7 @@ class VectorShuffleOpConversion
     }
 
     // For all other cases, insert the individual values individually.
+    int64_t v1Dim = v1Type.getDimSize(0);
     Type eltType;
     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
       eltType = arrayType.getElementType();

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ef37005ddc913..3d3b872d61edd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1794,8 +1794,11 @@ LogicalResult ShuffleOp::verify() {
   int64_t resRank = resultType.getRank();
   int64_t v1Rank = v1Type.getRank();
   int64_t v2Rank = v2Type.getRank();
-  if (resRank != v1Rank || v1Rank != v2Rank)
+  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
+  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
+  if (!wellFormed0DCase && !wellFormedNDCase)
     return emitOpError("rank mismatch");
+
   // Verify all but leading dimension sizes.
   for (int64_t r = 1; r < v1Rank; ++r) {
     int64_t resDim = resultType.getDimSize(r);
@@ -1812,7 +1815,8 @@ LogicalResult ShuffleOp::verify() {
   if (maskLength != resultType.getDimSize(0))
     return emitOpError("mask length mismatch");
   // Verify all indices.
-  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
+  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
+                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (const auto &en : llvm::enumerate(maskAttr)) {
     auto attr = en.value().dyn_cast<IntegerAttr>();
     if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
@@ -1828,12 +1832,15 @@ ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
   ShuffleOp::Adaptor op(operands, attributes);
   auto v1Type = op.getV1().getType().cast<VectorType>();
-  // Construct resulting type: leading dimension matches mask length,
-  // all trailing dimensions match the operands.
+  auto v1Rank = v1Type.getRank();
+  // Construct resulting type: leading dimension matches mask
+  // length, all trailing dimensions match the operands.
   SmallVector<int64_t, 4> shape;
-  shape.reserve(v1Type.getRank());
+  shape.reserve(v1Rank);
   shape.push_back(std::max<size_t>(1, op.getMask().size()));
-  llvm::append_range(shape, v1Type.getShape().drop_front());
+  // In the 0-D case there is no trailing shape to append.
+  if (v1Rank > 0)
+    llvm::append_range(shape, v1Type.getShape().drop_front());
   inferredReturnTypes.push_back(
       VectorType::get(shape, v1Type.getElementType()));
   return success();
@@ -1849,9 +1856,15 @@ static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
 }
 
 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+  VectorType v1Type = getV1VectorType();
+  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
+  // but must be a canonicalization into a vector.broadcast.
+  if (v1Type.getRank() == 0)
+    return {};
+
   // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
-  if (!getV1VectorType().isScalable() &&
-      isStepIndexArray(getMask(), 0, getV1VectorType().getDimSize(0)))
+  if (!v1Type.isScalable() &&
+      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
     return getV1();
   // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
   if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
@@ -1887,6 +1900,30 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
 
 namespace {
 
+// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
+// to a broadcast.
+struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
+  using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType v1VectorType = shuffleOp.getV1VectorType();
+    ArrayAttr mask = shuffleOp.getMask();
+    if (v1VectorType.getRank() > 0)
+      return failure();
+    if (mask.size() != 1)
+      return failure();
+    Type resType = VectorType::Builder(v1VectorType).setShape({1});
+    if (mask[0].cast<IntegerAttr>().getInt() == 0)
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
+                                                       shuffleOp.getV1());
+    else
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
+                                                       shuffleOp.getV2());
+    return success();
+  }
+};
+
 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
 public:
@@ -1912,7 +1949,7 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
 
 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ShuffleSplat>(context);
+  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 828fc22f18346..ed4d398780e16 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -416,6 +416,19 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
 // CHECK:       %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK:       return %[[T19]] : vector<2x3xf32>
 
+
+// -----
+ 
+func.func @shuffle_0D_direct(%arg0: vector<f32>) -> vector<3xf32> {
+  %1 = vector.shuffle %arg0, %arg0 [0, 1, 0] : vector<f32>, vector<f32>
+  return %1 : vector<3xf32>
+}
+// CHECK-LABEL: @shuffle_0D_direct(
+//  CHECK-SAME:     %[[A:.*]]: vector<f32>
+//       CHECK:   %[[c:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+//       CHECK:   %[[s:.*]] = llvm.shufflevector %[[c]], %[[c]] [0, 1, 0] : vector<1xf32>
+//       CHECK:   return %[[s]] : vector<3xf32>
+
 // -----
 
 func.func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2xf32> {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index be86fe91c17c1..6fe6c2776f563 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1487,6 +1487,13 @@ func.func @shuffle_1d() -> vector<4xi32> {
   return %shuffle : vector<4xi32>
 }
 
+// CHECK-LABEL: func @shuffle_canonicalize_0d
+func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
+  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
+  return %shuffle : vector<1xi32>
+}
+
 // CHECK-LABEL: func @shuffle_fold1
 //       CHECK:   %arg0 : vector<4xi32>
 func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index fa2516466ad50..3ebf6dffd0ec3 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -56,6 +56,13 @@ func.func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) {
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<4x2xf32>
 }
 
+// -----
+ 
+func.func @shuffle_rank_mismatch_0d(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+  // expected-error@+1 {{'vector.shuffle' op rank mismatch}}
+  %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f32>, vector<1xf32>
+}
+
 // -----
 
 func.func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) {

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index e4e260a37bb13..e026965eae648 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -165,6 +165,13 @@ func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: ve
   return %4 : vector<8x16xf32>
 }
 
+// CHECK-LABEL: @shuffle0D
+func.func @shuffle0D(%a: vector<f32>) -> vector<3xf32> {
+  // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 0] : vector<f32>, vector<f32>
+  %1 = vector.shuffle %a, %a[0, 1, 0] : vector<f32>, vector<f32>
+  return %1 : vector<3xf32>
+}
+
 // CHECK-LABEL: @shuffle1D
 func.func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> {
   // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 2, 3] : vector<2xf32>, vector<2xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 8a100dcc91b35..4065d7b865fa0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -127,6 +127,13 @@ func.func @transpose_0d(%arg: vector<i32>) {
   return
 }
 
+func.func @shuffle_0d(%v0: vector<i32>, %v1: vector<i32>) {
+  %1 = vector.shuffle %v0, %v1 [0, 1, 0] : vector<i32>, vector<i32>
+  // CHECK: ( 42, 43, 42 )
+  vector.print %1: vector<3xi32>
+  return
+}
+
 func.func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -159,7 +166,9 @@ func.func @entry() {
   %5 = arith.constant dense<4.0> : vector<f32>
   call  @fma_0d(%5) : (vector<f32>) -> ()
   %6 = arith.constant dense<42> : vector<i32>
+  %7 = arith.constant dense<43> : vector<i32>
   call @transpose_0d(%6) : (vector<i32>) -> ()
+  call @shuffle_0d(%6, %7) : (vector<i32>, vector<i32>) -> ()
 
   return
 }


        


More information about the Mlir-commits mailing list