[Mlir-commits] [mlir] c6ff244 - [mlir][vector] Add `vector.from_elements` op (#95938)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 19 00:58:40 PDT 2024


Author: Matthias Springer
Date: 2024-06-19T09:58:37+02:00
New Revision: c6ff2446a4650f23afc9faffb55020aa68cf678c

URL: https://github.com/llvm/llvm-project/commit/c6ff2446a4650f23afc9faffb55020aa68cf678c
DIFF: https://github.com/llvm/llvm-project/commit/c6ff2446a4650f23afc9faffb55020aa68cf678c.diff

LOG: [mlir][vector] Add `vector.from_elements` op (#95938)

This commit adds a new operation to the vector dialect:
`vector.from_elements`

The op constructs a new vector from a given list of scalar values. It is
similar to `tensor.from_elements`.
```mlir
%0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector<2x3xf32>
```

Constructing a new vector from elements was tedious before this op
existed: a typical way was to define an `arith.constant ... :
vector<...>`, followed by a chain of `vector.insert`.

Folders/canonicalizations are added that fold `vector.extract` ops whose
source is a `vector.from_elements` op, and that rewrite a
`vector.from_elements` op whose elements are all the same SSA value into a
`vector.splat` op.

The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence
of scalar insertions in the form of `llvm.insertelement`. Only 0-D and
1-D vectors are currently supported in the LLVM lowering.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 56d866ac5b40c..c30996351c672 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -720,10 +720,9 @@ def Vector_ExtractOp :
       return getStaticPosition().size();
     }
 
+    /// Return "true" if the op has at least one dynamic position, i.e. at
+    /// least one position index is supplied as an SSA operand (returned by
+    /// getDynamicPosition()) instead of being a static constant.
     bool hasDynamicPosition() {
-      auto dynPos = getDynamicPosition();
-      return std::any_of(dynPos.begin(), dynPos.end(),
-                         [](Value operand) { return operand != nullptr; });
+      return !getDynamicPosition().empty();
     }
   }];
 
@@ -769,6 +768,41 @@ def Vector_FMAOp :
   }];
 }
 
+// vector.from_elements: builds a vector from a flat, row-major list of scalar
+// operands.
+//
+// The TypesMatchWith trait derives the expected operand types from the result
+// type: getNumElements() copies of the result element type. This is why the
+// assembly format only spells out `type($result)`; the operand count and
+// operand types are both checked against it.
+//
+// NOTE(review): AnyVectorOfAnyRank also admits scalable vectors. For those,
+// getNumElements() presumably counts only the static (minimum) size. The
+// extract folders/canonicalizers in VectorOps.cpp bail out on scalable types.
+// Consider restricting the result to fixed-length vectors -- TODO confirm
+// intended semantics.
+def Vector_FromElementsOp : Vector_Op<"from_elements", [
+    Pure,
+    TypesMatchWith<"operand types match result element type",
+                   "result", "elements", "SmallVector<Type>("
+                   "::llvm::cast<VectorType>($_self).getNumElements(), "
+                   "::llvm::cast<VectorType>($_self).getElementType())">]> {
+  let summary = "operation that defines a vector from scalar elements";
+  let description = [{
+    This operation defines a vector from one or multiple scalar elements. The
+    number of elements must match the number of elements in the result type.
+    All elements must have the same type, which must match the element type of
+    the result vector type.
+
+    `elements` are a flattened version of the result vector in row-major order.
+
+    Example:
+
+    ```mlir
+    // %f1
+    %0 = vector.from_elements %f1 : vector<f32>
+    // [%f1, %f2]
+    %1 = vector.from_elements %f1, %f2 : vector<2xf32>
+    // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+    %2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
+    // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
+    %3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$elements);
+  let results = (outs AnyVectorOfAnyRank:$result);
+  // Operand types are inferred from the result type via TypesMatchWith above.
+  let assemblyFormat = "$elements attr-dict `:` type($result)";
+  // Canonicalization rewrites uniform element lists into vector.splat.
+  let hasCanonicalizer = 1;
+}
+
 def Vector_InsertElementOp :
   Vector_Op<"insertelement", [Pure,
      TypesMatchWith<"source operand type matches element type of result",

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 60f7e95ade689..0eac55255b133 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering
   }
 };
 
+/// Conversion pattern for a `vector.from_elements`.
+///
+/// Only 0-D and 1-D fixed-size vectors are handled. The result is built from
+/// an `llvm.mlir.undef` of the converted type, followed by one `vector.insert`
+/// per element at consecutive positions. That `vector.insert` is subsequently
+/// lowered to `llvm.insertelement`. 0-D vectors convert to a 1-element LLVM
+/// vector.
+struct VectorFromElementsLowering
+    : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = fromElementsOp.getLoc();
+    VectorType vectorType = fromElementsOp.getType();
+    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
+    // Such ops should be handled in the same way as vector.insert.
+    if (vectorType.getRank() > 1)
+      return rewriter.notifyMatchFailure(fromElementsOp,
+                                         "rank > 1 vectors are not supported");
+    // The operand list only provides values for the static (minimum) number of
+    // lanes. Lowering a scalable vector would leave every lane beyond that
+    // undefined, so reject it instead of miscompiling.
+    if (vectorType.isScalable())
+      return rewriter.notifyMatchFailure(fromElementsOp,
+                                         "scalable vectors are not supported");
+    Type llvmType = typeConverter->convertType(vectorType);
+    if (!llvmType)
+      return rewriter.notifyMatchFailure(fromElementsOp,
+                                         "failed to convert result type");
+    Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+    for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
+      result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
+    rewriter.replaceOp(fromElementsOp, result);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
-               VectorDeinterleaveOpLowering>(converter);
+               VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
+      converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2bf4f16f96e6a..89805d90ea1b0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1877,6 +1877,45 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
   return Value();
 }
 
+/// Try to fold the extraction of a scalar from a vector defined by
+/// vector.from_elements. E.g.:
+///
+/// %0 = vector.from_elements %a, %b : vector<2xf32>
+/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
+/// ==> fold to %a
+///
+/// Returns the matching `from_elements` operand. Returns a null Value if the
+/// fold does not apply, which happens when:
+///   - the position is dynamic,
+///   - the source is not a from_elements op,
+///   - the source is scalable, or
+///   - the result is not a scalar.
+static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
+  // Dynamic extractions cannot be folded.
+  if (extractOp.hasDynamicPosition())
+    return {};
+
+  // Look for extract(from_elements).
+  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+  if (!fromElementsOp)
+    return {};
+
+  // Scalable vectors are not supported.
+  VectorType vecType = fromElementsOp.getType();
+  if (vecType.isScalable())
+    return {};
+
+  // Only extractions of scalars are supported.
+  int64_t rank = vecType.getRank();
+  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
+  if (extractOp.getType() != vecType.getElementType())
+    return {};
+  assert(static_cast<int64_t>(indices.size()) == rank &&
+         "unexpected number of indices");
+
+  // Compute the row-major flattened/linearized index and fold to the
+  // corresponding operand. Indices and dim sizes are int64_t; keep the
+  // arithmetic in int64_t to avoid narrowing.
+  int64_t flatIndex = 0;
+  int64_t stride = 1;
+  for (int64_t i = rank - 1; i >= 0; --i) {
+    flatIndex += indices[i] * stride;
+    stride *= vecType.getDimSize(i);
+  }
+  return fromElementsOp.getElements()[flatIndex];
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1895,6 +1934,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return val;
   if (auto val = foldExtractStridedOpFromInsertChain(*this))
     return val;
+  if (auto val = foldScalarExtractFromFromElements(*this))
+    return val;
   return OpFoldResult();
 }
 
@@ -2099,6 +2140,52 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
   return success();
 }
 
+/// Try to canonicalize the extraction of a subvector from a vector defined by
+/// vector.from_elements. E.g.:
+///
+/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
+/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
+/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
+///
+/// In row-major order, the extracted subvector is a contiguous run of the
+/// `from_elements` operands. The replacement therefore takes a slice of
+/// `resultType.getNumElements()` operands, starting at the linearized
+/// position of the first extracted element.
+LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
+                                          PatternRewriter &rewriter) {
+  // Dynamic positions are not supported.
+  if (extractOp.hasDynamicPosition())
+    return failure();
+
+  // Scalar extracts are handled by the folder.
+  auto resultType = dyn_cast<VectorType>(extractOp.getType());
+  if (!resultType)
+    return failure();
+
+  // Look for extracts from a from_elements op.
+  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+  if (!fromElementsOp)
+    return failure();
+  VectorType inputType = fromElementsOp.getType();
+
+  // Scalable vectors are not supported.
+  if (resultType.isScalable() || inputType.isScalable())
+    return failure();
+
+  // Compute the position of the first extracted element: the static position
+  // padded with zeros for the trailing (extracted) dimensions. Then
+  // linearize it in row-major order. Use int64_t to match the index and
+  // dimension types and avoid narrowing.
+  SmallVector<int64_t> firstElementPos =
+      llvm::to_vector(extractOp.getStaticPosition());
+  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
+  int64_t flatIndex = 0;
+  int64_t stride = 1;
+  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
+    flatIndex += firstElementPos[i] * stride;
+    stride *= inputType.getDimSize(i);
+  }
+
+  // Replace the op with a smaller from_elements op.
+  rewriter.replaceOpWithNewOp<FromElementsOp>(
+      extractOp, resultType,
+      fromElementsOp.getElements().slice(flatIndex,
+                                         resultType.getNumElements()));
+  return success();
+}
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2106,6 +2193,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
+  results.add(foldExtractFromFromElements);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2122,6 +2210,29 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// FromElementsOp
+//===----------------------------------------------------------------------===//
+
+/// Canonicalization: a vector.from_elements whose operands are all the same
+/// SSA value is equivalent to a vector.splat of that value. E.g.:
+///
+/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
+/// ==> rewrite to vector.splat %a : vector<3xf32>
+///
+/// Fails (leaving the IR untouched) when at least two operands differ.
+static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
+                                                PatternRewriter &rewriter) {
+  auto elements = fromElementsOp.getElements();
+  const bool isUniform = llvm::all_equal(elements);
+  if (isUniform) {
+    Value splatValue = elements.front();
+    rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp,
+                                         fromElementsOp.getType(), splatValue);
+    return success();
+  }
+  return failure();
+}
+
+/// Registers the canonicalization patterns of vector.from_elements.
+void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                 MLIRContext *context) {
+  results.add(&rewriteFromElementsAsSplat);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index bf4281ebcdec9..09b79708a9ab2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2590,3 +2590,34 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
   %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
   return %0 : vector<2x2xi64>
 }
+
+// -----
+
+// 1-D lowering: an llvm.mlir.undef of the result vector type, followed by one
+// llvm.insertelement per operand at constant indices 0, 1, 2.
+// CHECK-LABEL: func.func @vector_from_elements_1d(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+//       CHECK:   %[[undef:.*]] = llvm.mlir.undef : vector<3xf32>
+//       CHECK:   %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<3xf32>
+//       CHECK:   %[[c1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[insert1:.*]] = llvm.insertelement %[[b]], %[[insert0]][%[[c1]] : i64] : vector<3xf32>
+//       CHECK:   %[[c2:.*]] = llvm.mlir.constant(2 : i64) : i64
+//       CHECK:   %[[insert2:.*]] = llvm.insertelement %[[a]], %[[insert1]][%[[c2]] : i64] : vector<3xf32>
+//       CHECK:   return %[[insert2]]
+func.func @vector_from_elements_1d(%a: f32, %b: f32) -> vector<3xf32> {
+  %0 = vector.from_elements %a, %b, %a : vector<3xf32>
+  return %0 : vector<3xf32>
+}
+
+// -----
+
+// 0-D lowering: the result is built as a 1-element LLVM vector and cast back
+// to vector<f32> with an unrealized_conversion_cast.
+// CHECK-LABEL: func.func @vector_from_elements_0d(
+//  CHECK-SAME:     %[[a:.*]]: f32)
+//       CHECK:   %[[undef:.*]] = llvm.mlir.undef : vector<1xf32>
+//       CHECK:   %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<1xf32>
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[insert0]] : vector<1xf32> to vector<f32>
+//       CHECK:   return %[[cast]]
+func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
+  %0 = vector.from_elements %a : vector<f32>
+  return %0 : vector<f32>
+}

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index caccd1f1c9c24..8181f1a8c5d13 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2642,3 +2642,72 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   // CHECK:   return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
   return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
 }
+
+// -----
+
+// A scalar vector.extract from a vector.from_elements folds to the operand at
+// the row-major linearized position (e.g. [1, 2] in 2x3 -> operand 5).
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+  // Extract from 0D.
+  %0 = vector.from_elements %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Extract from 1D.
+  %2 = vector.from_elements %a : vector<1xf32>
+  %3 = vector.extract %2[0] : f32 from vector<1xf32>
+  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+  %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+  // Extract from 2D.
+  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// Extracting a row yields a smaller from_elements over that row's operands.
+// Each row here repeats a single value, so it is further canonicalized into
+// vector.splat.
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// Extracting a 2-D (or 1-D) slice produces a smaller from_elements over the
+// contiguous, row-major run of operands covered by the slice. None of these
+// slices is uniform, so they are not further rewritten into vector.splat.
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>, vector<2xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+  // CHECK: %[[from_el1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: %[[from_el2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+  // Multi-index position: [2, 1] starts at flat index 2 * 4 + 1 * 2 = 10.
+  // CHECK: %[[from_el3:.*]] = vector.from_elements %[[a]], %[[b]] : vector<2xf32>
+  %3 = vector.extract %0[2, 1] : vector<2xf32> from vector<3x2x2xf32>
+  // CHECK: return %[[from_el1]], %[[from_el2]], %[[from_el3]]
+  return %1, %2, %3 : vector<2x2xf32>, vector<2x2xf32>, vector<2xf32>
+}
+
+// -----
+
+// from_elements whose operands are all the same SSA value becomes vector.splat
+// (including the 0-D case). A list with at least one differing operand is
+// left unchanged.
+// CHECK-LABEL: func @from_elements_to_splat(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+  %2 = vector.from_elements %a : vector<f32>
+  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1516f51fe1458..d0eaed8f98cc5 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1854,3 +1854,20 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
   %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
   return
 }
+
+// -----
+
+// The operand count must equal the number of elements of the result type.
+func.func @invalid_from_elements(%a: f32) {
+  // expected-error @+1 {{'vector.from_elements' 1 operands present, but expected 2}}
+  vector.from_elements %a : vector<2xf32>
+  return
+}
+
+// -----
+
+// expected-note @+1 {{prior use here}}
+func.func @invalid_from_elements(%a: f32, %b: i32) {
+  // Operand types are derived from the result element type, so an i32 operand
+  // for a vector<2xf32> result is rejected by the parser.
+  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
+  vector.from_elements %a, %b : vector<2xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c868c881d079a..4da09584db88b 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1158,3 +1158,17 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
   %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
   return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
 }
+
+// Round-trip (parse/print) of vector.from_elements for 0-D, 1-D and 2-D
+// result types.
+// CHECK-LABEL: func @from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
+  // CHECK: vector.from_elements %[[a]] : vector<f32>
+  %0 = vector.from_elements %a : vector<f32>
+  // CHECK: vector.from_elements %[[a]] : vector<1xf32>
+  %1 = vector.from_elements %a : vector<1xf32>
+  // CHECK: vector.from_elements %[[a]], %[[b]] : vector<1x2xf32>
+  %2 = vector.from_elements %a, %b : vector<1x2xf32>
+  // CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
+  %3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
+  return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list