[Mlir-commits] [mlir] 8b9c70d - [mlir] Move vector.{to_elements, from_elements} unrolling to `VectorUnroll.cpp` (#159118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 18 09:03:58 PDT 2025


Author: Erick Ochoa Lopez
Date: 2025-09-18T12:03:54-04:00
New Revision: 8b9c70dcdbc80f01945c6560232717afd228d532

URL: https://github.com/llvm/llvm-project/commit/8b9c70dcdbc80f01945c6560232717afd228d532
DIFF: https://github.com/llvm/llvm-project/commit/8b9c70dcdbc80f01945c6560232717afd228d532.diff

LOG: [mlir] Move vector.{to_elements,from_elements} unrolling to `VectorUnroll.cpp` (#159118)

This PR moves the patterns that unroll vector.to_elements and
vector.from_elements into the file with other vector unrolling
operations. This PR also adds these unrolling patterns to
`populateVectorUnrollPatterns`, and renames
`populateVectorToElementsLoweringPatterns` and
`populateVectorFromElementsLoweringPatterns` to
`populateVectorToElementsUnrollPatterns` and
`populateVectorFromElementsUnrollPatterns`, respectively.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir

Removed: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index b896506f29eef..7bd96c8a6d1a1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -306,20 +306,6 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 void populateVectorToFromElementsToShuffleTreePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollFromElements]
-/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
-/// outermost dimension.
-void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
-                                                PatternBenefit benefit = 1);
-
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollToElements]
-void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit = 1);
-
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 08f439222a9a0..69438011d2287 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -322,6 +322,16 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
+/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
+/// outermost dimension of the operand.
+void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 1);
+
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension.
+void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index e516118f75207..5994b64f3d9a5 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -534,7 +534,7 @@ void GpuToLLVMConversionPass::runOnOperation() {
                                                    /*maxTransferRank=*/1);
     // Transform N-D vector.from_elements to 1-D vector.from_elements before
     // conversion.
-    vector::populateVectorFromElementsLoweringPatterns(patterns);
+    vector::populateVectorFromElementsUnrollPatterns(patterns);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 76a7e0f3831a2..a95263bb55f69 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -372,7 +372,7 @@ struct LowerGpuOpsToNVVMOpsPass final
       populateGpuRewritePatterns(patterns);
       // Transform N-D vector.from_elements to 1-D vector.from_elements before
       // conversion.
-      vector::populateVectorFromElementsLoweringPatterns(patterns);
+      vector::populateVectorFromElementsUnrollPatterns(patterns);
       if (failed(applyPatternsGreedily(m, std::move(patterns))))
         return signalPassFailure();
     }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0b44ca7ceee42..cae490e5f03e7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,8 +94,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
-    populateVectorFromElementsLoweringPatterns(patterns);
-    populateVectorToElementsLoweringPatterns(patterns);
+    populateVectorFromElementsUnrollPatterns(patterns);
+    populateVectorToElementsUnrollPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 18f105ef62e38..7faa222a9e574 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,12 +146,12 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
 
 void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorFromElementsLoweringPatterns(patterns);
+  vector::populateVectorFromElementsUnrollPatterns(patterns);
 }
 
 void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorToElementsLoweringPatterns(patterns);
+  vector::populateVectorToElementsUnrollPatterns(patterns);
 }
 
 void transform::ApplyLowerScanPatternsOp::populatePatterns(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index fecf445720173..4e0f07af95984 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
-  LowerVectorFromElements.cpp
   LowerVectorGather.cpp
   LowerVectorInterleave.cpp
   LowerVectorMask.cpp
@@ -12,7 +11,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorShapeCast.cpp
   LowerVectorShuffle.cpp
   LowerVectorStep.cpp
-  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
deleted file mode 100644
index c22fd54cef46b..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements target-independent rewrites and utilities to lower the
-// 'vector.from_elements' operation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-
-#define DEBUG_TYPE "lower-vector-from-elements"
-
-using namespace mlir;
-
-namespace {
-
-/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
-/// outermost dimension. For example:
-/// ```
-/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
-///
-/// ==>
-///
-/// %0   = ub.poison : vector<2x3xf32>
-/// %v0  = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
-/// %1   = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
-/// %v1  = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
-/// %v   = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
-/// ```
-///
-/// When applied exhaustively, this will produce a sequence of 1-d from_elements
-/// ops.
-struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::FromElementsOp op,
-                                PatternRewriter &rewriter) const override {
-    ValueRange allElements = op.getElements();
-
-    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
-                                    VectorType subTy, int64_t index) {
-      size_t subTyNumElements = subTy.getNumElements();
-      assert((index + 1) * subTyNumElements <= allElements.size() &&
-             "out of bounds");
-      ValueRange subElements =
-          allElements.slice(index * subTyNumElements, subTyNumElements);
-      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
-    };
-
-    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
-  }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorFromElementsLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
-}

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
deleted file mode 100644
index a53a183ec31bc..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements target-independent rewrites and utilities to lower the
-// 'vector.to_elements' operation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-
-#define DEBUG_TYPE "lower-vector-to-elements"
-
-using namespace mlir;
-
-namespace {
-
-struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ToElementsOp op,
-                                PatternRewriter &rewriter) const override {
-
-    TypedValue<VectorType> source = op.getSource();
-    FailureOr<SmallVector<Value>> result =
-        vector::unrollVectorValue(source, rewriter);
-    if (failed(result)) {
-      return failure();
-    }
-    SmallVector<Value> vectors = *result;
-
-    SmallVector<Value> results;
-    for (const Value &vector : vectors) {
-      auto subElements =
-          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
-      llvm::append_range(results, subElements.getResults());
-    }
-    rewriter.replaceOp(op, results);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorToElementsLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
-}

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 79786f33a2d46..14639c5f1cdd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "llvm/ADT/MapVector.h"
@@ -809,6 +810,55 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
+/// outermost dimension of the operand. For example:
+///
+/// ```
+/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
+///
+/// ==>
+///
+/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
+/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
+/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
+/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
+/// ```
+///
+/// When this pattern is applied until a fixed-point is reached,
+/// this will produce a sequence of 1-d to_elements
+/// ops.
+struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
+  UnrollToElements(MLIRContext *context,
+                   const vector::UnrollVectorOptions &options,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+
+    TypedValue<VectorType> source = op.getSource();
+    FailureOr<SmallVector<Value>> result =
+        vector::unrollVectorValue(source, rewriter);
+    if (failed(result)) {
+      return failure();
+    }
+    SmallVector<Value> vectors = *result;
+
+    SmallVector<Value> results;
+    for (Value vector : vectors) {
+      auto subElements =
+          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
+      llvm::append_range(results, subElements.getResults());
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 /// This pattern unrolls `vector.step` operations according to the provided
 /// target unroll shape. It decomposes a large step vector into smaller step
 /// vectors (segments) and assembles the result by inserting each computed
@@ -884,6 +934,51 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0   = ub.poison : vector<2x3xf32>
+/// %v0  = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1   = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1  = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v   = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When this pattern is applied until a fixed-point is reached,
+/// this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+  UnrollFromElements(MLIRContext *context,
+                     const vector::UnrollVectorOptions &options,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::FromElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    ValueRange allElements = op.getElements();
+
+    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+                                    VectorType subTy, int64_t index) {
+      size_t subTyNumElements = subTy.getNumElements();
+      assert((index + 1) * subTyNumElements <= allElements.size() &&
+             "out of bounds");
+      ValueRange subElements =
+          allElements.slice(index * subTyNumElements, subTyNumElements);
+      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+    };
+
+    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -893,6 +988,19 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
-      patterns.getContext(), options, benefit);
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
+               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
+                                                    options, benefit);
+}
+
+void mlir::vector::populateVectorToElementsUnrollPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
+                                 benefit);
+}
+
+void mlir::vector::populateVectorFromElementsUnrollPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
+                                   benefit);
 }

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index c85f4334ff2e5..0957f67690b97 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
   return %0 : vector<3x2xi32>
 }
+
+// -----
+
+// In order to verify that the pattern is applied,
+// we need to make sure that the 2d vector does not
+// come from the parameters. Otherwise, the pattern
+// in unrollVectorsInSignatures which splits the 2d vector
+// parameter will take precedence. Similarly, let's avoid
+// returning a vector as another pattern would take precedence.
+
+// CHECK-LABEL: @unroll_to_elements_2d
+func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
+  %1 = "test.op"() : () -> (vector<2x2xf32>)
+  // CHECK: %[[VEC2D:.+]] = "test.op"
+  // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32>
+  // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32>
+  // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]]
+  // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]]
+  %2:4 = vector.to_elements %1 : vector<2x2xf32>
+  return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// In order to verify that the pattern is applied,
+// we need to make sure that the 2d vector is used
+// by an operation and that extracts are not folded away.
+// In other words we can't use "test.op" nor return the
+// value `%0 = vector.from_elements`
+
+// CHECK-LABEL: @unroll_from_elements_2d
+// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32)
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) {
+  // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+  // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+
+  // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]]
+  // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]]
+  %1 = arith.addf %0, %0 : vector<2x2xf32>
+
+  // return %[[RES0]], %[[RES1]] : vector<2xf32>, vector<2xf32>
+  return %1 : vector<2x2xf32>
+}


        


More information about the Mlir-commits mailing list