[Mlir-commits] [mlir] [mlir][vector] Add 1:N vector to llvm conversion (PR #174240)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Jan 5 13:47:02 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/174240

>From 50797d355e8944c9f8b80826ee963756edfdb387 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 22 Dec 2025 10:01:26 -0500
Subject: [PATCH 1/2] Use one to n type conversion for vector.

---
 mlir/include/mlir/Conversion/Passes.td        |  2 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  | 65 +++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..8c5af3c8529b0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1578,6 +1578,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
             "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
           )}]>,
+    Option<"enableOneToNConversion", "enable-one-to-n-conversion",
+           "bool", /*default=*/"false", "1:N conversion">,
   ];
 }
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..43fd1843cda28 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -114,6 +114,71 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
+
+  if (enableOneToNConversion) {
+
+    converter.addConversion(
+        [&](VectorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          auto elementType = converter.convertType(type.getElementType());
+          if (!elementType)
+            return failure();
+          if (type.getShape().empty()) {
+            result.push_back(VectorType::get({1}, elementType));
+            return success();
+          }
+          Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                            type.getScalableDims().back());
+          assert(LLVM::isCompatibleVectorType(vectorType) &&
+                 "expected vector type compatible with the LLVM dialect");
+          // For n-D vector types for which a _non-trailing_ dim is scalable,
+          // return a failure. Supporting such cases would require LLVM
+          // to support something akin "scalable arrays" of vectors.
+          if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+            return failure();
+
+          ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
+          int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
+          for (int64_t i = 0; i < numVectors; i++)
+            result.push_back(vectorType);
+
+          return success();
+        });
+
+    converter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) -> SmallVector<Value> {
+          // from ('vector<4x4xf32>')
+          // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
+          // 'vector<4xf32>')
+          Type ty = resultTypes[0];
+          for (Type ithTy : resultTypes)
+            if (ithTy != ty)
+              return {};
+
+          if (!isa<VectorType>(ty))
+            return {};
+
+          if (inputs.size() != 1)
+            return {};
+
+          Type inputTy = inputs[0].getType();
+          if (!isa<VectorType>(inputTy))
+            return {};
+
+          VectorType inputVectorTy = cast<VectorType>(inputTy);
+          ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
+          size_t numElements =
+              ShapedType::getNumElements(inputShape.drop_back());
+          if (numElements != resultTypes.size())
+            return {};
+
+          return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
+                                                    inputs)
+              .getResults();
+        });
+  }
+
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMConversionPatterns(

>From 9539494c68f077ccd6b14ded103f43b8c33547e1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 31 Dec 2025 11:21:21 -0500
Subject: [PATCH 2/2] [mlir] Add 1:N conversion for vector operations.

* Add lowering for vector.insert
* Add lowering for vector.extract
---
 .../VectorToLLVM/ConvertVectorToLLVM.h        |   4 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 392 +++++++++++++++++-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  66 +--
 .../VectorToLLVM/vector-to-llvm-one-to-n.mlir | 140 +++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |   7 +-
 5 files changed, 534 insertions(+), 75 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir

diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index cfb6cc313bc63..d3eedca39ec96 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -15,9 +15,9 @@ class LLVMTypeConverter;
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
-    bool useVectorAlignment = false);
+    bool useVectorAlignment = false, bool enableOneToNConversion = false);
 
 namespace vector {
 void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05d541fe80356..914c9ccf1673d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1088,6 +1088,162 @@ class VectorShuffleOpConversion
   }
 };
 
+/// VectorExtractOpConversion rewritten to use 1:N conversion.
+///
+/// Handles the same cases as VectorExtractOpConversion:
+/// * Mostly static indices (with the exception of the innermost dimension
+///   index, which may be dynamic when it is used to select a scalar).
+///
+/// Translates:
+///
+/// ```mlir
+/// %tgt = vector.extract %src[0] : vector<2xf32> from vector<2x2xf32>
+/// ```
+///
+/// Since 1:N conversion is used, %src is now a collection of 1-D vectors
+/// and %tgt is replaced (without creating any operations) by the matching
+/// vector of that collection. Following the example above:
+///
+/// %src : {vector<2xf32>, vector<2xf32>}
+/// %tgt = %src[0]
+///
+/// This pattern will insert operations only when extracting scalars.
+/// For example:
+///
+/// ```mlir
+/// %scalar = vector.extract %src[%idx] : f32 from vector<2xf32>
+/// ```
+///
+/// %src : {vector<2xf32>}
+/// %tgt = %src[0]
+///
+/// ```mlir
+/// %scalar = llvm.extractelement %tgt[%idx]
+/// ```
+///
+/// There is another case not handled by VectorExtractOpConversion, where
+/// the result itself is converted into multiple values. E.g.,
+///
+/// ```mlir
+/// %vec = vector.extract %nd[1] : vector<2x2xf32> from vector<2x2x2xf32>
+/// ```
+///
+/// In this case, source is a collection of four vector<2xf32> and target
+/// is a collection of two vector<2xf32>, starting at 1 * 2 = 2:
+///
+/// %src : {vector<2xf32>, vector<2xf32>, vector<2xf32>, vector<2xf32>}
+/// %tgt : {%src[2], %src[3]}
+class VectorExtractOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type resultType = extractOp.getResult().getType();
+    SmallVector<Type> llvmResultTypes;
+    if (failed(typeConverter->convertType(resultType, llvmResultTypes)))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "expected conversion to succeed.");
+
+    // With a 1:N adaptor every dynamic index is a ValueRange. Each index is
+    // expected to map to exactly one value; reject anything else (an empty
+    // range must not be dereferenced) and flatten the indices into a single
+    // SmallVector.
+    SmallVector<Value> dynamicPositionsSafe;
+    for (ValueRange dynamicPosition : adaptor.getDynamicPosition()) {
+      if (dynamicPosition.size() != 1)
+        return rewriter.notifyMatchFailure(
+            extractOp, "expected single value in dynamic position.");
+      dynamicPositionsSafe.push_back(dynamicPosition.front());
+    }
+
+    Location loc = extractOp->getLoc();
+    SmallVector<OpFoldResult> positionVec = getMixedValues(
+        adaptor.getStaticPosition(), dynamicPositionsSafe, rewriter);
+
+    // The Vector -> LLVM 1:N lowering models N-D vectors as a collection of
+    // 1-D vectors. We do this conversion by:
+    //  - Selecting the values that correspond to the target vector. No
+    //    operations are produced at this stage.
+    //  - Extracting a scalar out of the selected vector if needed, using
+    //    `llvm.extractelement`.
+
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank, i.e., the number of indices is
+    // equal to source vector rank.
+    VectorType sourceTy = extractOp.getSourceVectorType();
+    bool extractsScalar =
+        static_cast<int64_t>(positionVec.size()) == sourceTy.getRank();
+
+    // Since the LLVM type converter converts 0-D vectors to 1-D vectors, we
+    // need to add a position into that single 1-D vector.
+    bool isZeroRank = sourceTy.getRank() == 0;
+    if (isZeroRank) {
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionVec.push_back(rewriter.getZeroAttr(idxType));
+    }
+
+    // strides[k + 1] is the number of 1-D vectors spanned by one step along
+    // leading dimension k (the innermost dimension is never linearized).
+    SmallVector<int64_t> strides(sourceTy.getShape());
+    if (!isZeroRank)
+      strides.back() = 1;
+    else
+      strides.push_back(1);
+    for (int64_t i = static_cast<int64_t>(strides.size()) - 2; i >= 0; --i)
+      strides[i] *= strides[i + 1];
+
+    // If we are extracting a scalar, drop the last index: it is used to
+    // extract the scalar out of the selected 1-D vector instead.
+    ArrayRef<OpFoldResult> position(positionVec);
+    if (extractsScalar)
+      position = position.drop_back();
+
+    if (!llvm::all_of(position, llvm::IsaPred<Attribute>))
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected leading indices to be statically known.");
+
+    // Indices are aligned with the outermost dimension, so zip from the
+    // front: a partial position selects a contiguous run of vectors.
+    SmallVector<int64_t> positionInts = getAsIntegers(position);
+    int64_t linearIdx = 0;
+    for (auto [offset, coeff] :
+         llvm::zip(positionInts, ArrayRef<int64_t>(strides).drop_front()))
+      linearIdx += offset * coeff;
+
+    // Unlike VectorExtractOpConversion, the source is a collection of 1-D
+    // vectors; bail out if the selection falls outside of it.
+    SmallVector<Value> extracted = adaptor.getSource();
+    int64_t numSelected = llvmResultTypes.size();
+    if (linearIdx < 0 ||
+        linearIdx + numSelected > static_cast<int64_t>(extracted.size()))
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected selected vectors to be within the source.");
+    ArrayRef<Value> selected =
+        ArrayRef<Value>(extracted).slice(linearIdx, numSelected);
+
+    if (extractsScalar && selected.size() != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected selected vectors to be a single vector");
+
+    SmallVector<Value> replacements;
+    if (extractsScalar) {
+      Value scalar = LLVM::ExtractElementOp::create(
+          rewriter, loc, selected[0],
+          getAsLLVMValue(rewriter, loc, positionVec.back()));
+      replacements.push_back(scalar);
+    } else {
+      replacements = SmallVector<Value>(selected);
+    }
+
+    rewriter.replaceOpWithMultiple(extractOp, {replacements});
+    return success();
+  }
+};
+
 class VectorExtractOpConversion
     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
 public:
@@ -1191,6 +1347,158 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
   }
 };
 
+/// VectorInsertOpConversion rewritten to use 1:N conversion.
+///
+/// Handles the same cases as VectorInsertOpConversion:
+/// * Mostly static indices (with the exception of the innermost dimension
+///   index, which may be dynamic when it is used to select a scalar).
+///
+/// ```mlir
+/// %upd = vector.insert %src, %tgt[0] : vector<2xf32> into vector<2x2xf32>
+/// ```
+///
+/// Since 1:N conversion is used, %src and %tgt may now be collections of
+/// 1-D vectors. %upd is replaced by %tgt's collection with the vectors at
+/// the insertion position taken from %src. Following the example above:
+///
+/// %src : {vector<2xf32>}
+/// %tgt : {vector<2xf32>, vector<2xf32>}
+/// %upd : {%src[0], %tgt[1]}
+///
+/// This pattern will insert operations only when inserting scalars.
+/// For example:
+///
+/// ```mlir
+/// %upd = vector.insert %src, %tgt[%idx] : f32 into vector<2xf32>
+/// ```
+///
+/// %src : {f32}
+/// %tgt : {vector<2xf32>}
+///
+/// ```mlir
+/// %upd = llvm.insertelement %src, %tgt[%idx] : vector<2xf32>
+/// ```
+class VectorInsertOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::InsertOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = insertOp->getLoc();
+    auto destVectorType = insertOp.getDestVectorType();
+    SmallVector<Type> llvmResultTypes;
+    if (failed(typeConverter->convertType(destVectorType, llvmResultTypes)))
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "expected conversion to succeed.");
+
+    // With a 1:N adaptor every dynamic index is a ValueRange. Each index is
+    // expected to map to exactly one value; reject anything else (an empty
+    // range must not be dereferenced) and flatten the indices into a single
+    // SmallVector.
+    SmallVector<Value> dynamicPositionsSafe;
+    for (ValueRange dynamicPosition : adaptor.getDynamicPosition()) {
+      if (dynamicPosition.size() != 1)
+        return rewriter.notifyMatchFailure(
+            insertOp, "expected single value in dynamic position.");
+      dynamicPositionsSafe.push_back(dynamicPosition.front());
+    }
+
+    SmallVector<OpFoldResult> positionVec = getMixedValues(
+        adaptor.getStaticPosition(), dynamicPositionsSafe, rewriter);
+
+    // The logic in this pattern mirrors VectorExtractOneToNOpConversion. Refer
+    // to its explanatory comment about how N-D vectors are converted to a
+    // collection of vectors.
+    //
+    // The innermost dimension of the destination vector, when converted to a
+    // collection of vectors, will always be a 1D vector.
+    //
+    // * Select the 1-D vectors of the destination that correspond to the
+    //   leading (static) position indices. Unlike VectorInsertOpConversion,
+    //   the inserted value may be converted to multiple vectors.
+    // * If the insertion is happening into the innermost dimension of the
+    //   destination vector, use the innermost index to insert the scalar
+    //   into the selected 1-D vector with llvm.insertelement.
+    // * Return the original destination vectors, with the selected ones
+    //   replaced by the inserted values.
+
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool insertIntoInnermostDim =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+    bool isZeroRank = destVectorType.getRank() == 0;
+
+    // The innermost index (if any) addresses an element inside a 1-D vector;
+    // only the leading indices select which 1-D vectors are updated.
+    ArrayRef<OpFoldResult> positionOf1DVector(positionVec);
+    if (insertIntoInnermostDim && !isZeroRank)
+      positionOf1DVector = positionOf1DVector.drop_back();
+
+    if (!llvm::all_of(positionOf1DVector, llvm::IsaPred<Attribute>))
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "dynamic dimensions are not supported for picking 1d vectors.");
+
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (isZeroRank) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (insertIntoInnermostDim) {
+      positionOfScalarWithin1DVector = positionVec.back();
+    }
+
+    // strides[k + 1] is the number of 1-D vectors spanned by one step along
+    // leading dimension k (the innermost dimension is never linearized).
+    SmallVector<int64_t> strides(destVectorType.getShape());
+    if (!isZeroRank)
+      strides.back() = 1;
+    else
+      strides.push_back(1);
+    for (int64_t i = static_cast<int64_t>(strides.size()) - 2; i >= 0; --i)
+      strides[i] *= strides[i + 1];
+
+    // Indices are aligned with the outermost dimension, so zip from the
+    // front: a partial position selects a contiguous run of vectors.
+    SmallVector<int64_t> positionInts = getAsIntegers(positionOf1DVector);
+    int64_t linearIdx = 0;
+    for (auto [offset, coeff] :
+         llvm::zip(positionInts, ArrayRef<int64_t>(strides).drop_front()))
+      linearIdx += offset * coeff;
+
+    // Both operands are collections: the destination of 1-D vectors, and
+    // the value to store of 1-D vectors or a single scalar.
+    SmallVector<Value> sources = adaptor.getValueToStore();
+    SmallVector<Value> dests = adaptor.getDest();
+    int64_t numSources = sources.size();
+    if (linearIdx < 0 ||
+        linearIdx + numSources > static_cast<int64_t>(dests.size()))
+      return rewriter.notifyMatchFailure(
+          insertOp, "expected insertion position to be within bounds.");
+    if (insertIntoInnermostDim && numSources != 1)
+      return rewriter.notifyMatchFailure(
+          insertOp, "expected to store a single scalar.");
+
+    if (insertIntoInnermostDim) {
+      // Insert the scalar into the selected 1-D destination vector.
+      Value destVector = dests[linearIdx];
+      sources[0] = LLVM::InsertElementOp::create(
+          rewriter, loc, destVector.getType(), destVector, sources[0],
+          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
+    }
+
+    // Overwrite the selected destination vectors with the inserted values.
+    for (auto [idx, val] : llvm::enumerate(sources))
+      dests[linearIdx + idx] = val;
+
+    rewriter.replaceOpWithMultiple(insertOp, {dests});
+    return success();
+  }
+};
+
 class VectorInsertOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertOp> {
 public:
@@ -2196,10 +2504,12 @@ void mlir::vector::populateVectorTransposeToFlatTranspose(
 }
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool force32BitVectorIndices,
-    bool useVectorAlignment) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns,
+                                                  bool reassociateFPReductions,
+                                                  bool force32BitVectorIndices,
+                                                  bool useVectorAlignment,
+                                                  bool enableOneToNConversion) {
   // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -2211,8 +2521,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorGatherOpConversion, VectorScatterOpConversion>(
       converter, useVectorAlignment);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
-               VectorExtractOpConversion, VectorFMAOp1DConversion,
-               VectorInsertOpConversion, VectorPrintOpConversion,
+               VectorFMAOp1DConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorBroadcastScalarToLowRankLowering,
@@ -2222,6 +2531,77 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorToElementsLowering, VectorScalableStepOpLowering>(
       converter);
+  if (enableOneToNConversion)
+    patterns
+        .add<VectorInsertOneToNOpConversion, VectorExtractOneToNOpConversion>(
+            converter);
+  else
+    patterns.add<VectorInsertOpConversion, VectorExtractOpConversion>(
+        converter);
+
+  if (enableOneToNConversion) {
+    converter.addConversion(
+        [&](VectorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          auto elementType = converter.convertType(type.getElementType());
+          if (!elementType)
+            return failure();
+          if (type.getShape().empty()) {
+            result.push_back(VectorType::get({1}, elementType));
+            return success();
+          }
+          Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                            type.getScalableDims().back());
+          assert(LLVM::isCompatibleVectorType(vectorType) &&
+                 "expected vector type compatible with the LLVM dialect");
+          // For n-D vector types for which a _non-trailing_ dim is scalable,
+          // return a failure. Supporting such cases would require LLVM
+          // to support something akin "scalable arrays" of vectors.
+          if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+            return failure();
+
+          ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
+          int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
+          for (int64_t i = 0; i < numVectors; i++)
+            result.push_back(vectorType);
+
+          return success();
+        });
+    converter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) -> SmallVector<Value> {
+          // from ('vector<4x4xf32>')
+          // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
+          // 'vector<4xf32>')
+          Type ty = resultTypes[0];
+          for (Type ithTy : resultTypes)
+            if (ithTy != ty)
+              return {};
+
+          if (!isa<VectorType>(ty))
+            return {};
+
+          if (inputs.size() != 1)
+            return {};
+
+          Type inputTy = inputs[0].getType();
+          if (!isa<VectorType>(inputTy))
+            return {};
+
+          VectorType inputVectorTy = cast<VectorType>(inputTy);
+          ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
+          size_t numElements =
+              !inputShape.empty()
+                  ? ShapedType::getNumElements(inputShape.drop_back())
+                  : 1;
+          if (numElements != resultTypes.size())
+            return {};
+
+          return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
+                                                    inputs)
+              .getResults();
+        });
+  }
 }
 
 namespace {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 43fd1843cda28..a5b1255657c96 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -115,75 +115,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
 
-  if (enableOneToNConversion) {
-
-    converter.addConversion(
-        [&](VectorType type,
-            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-          auto elementType = converter.convertType(type.getElementType());
-          if (!elementType)
-            return failure();
-          if (type.getShape().empty()) {
-            result.push_back(VectorType::get({1}, elementType));
-            return success();
-          }
-          Type vectorType = VectorType::get(type.getShape().back(), elementType,
-                                            type.getScalableDims().back());
-          assert(LLVM::isCompatibleVectorType(vectorType) &&
-                 "expected vector type compatible with the LLVM dialect");
-          // For n-D vector types for which a _non-trailing_ dim is scalable,
-          // return a failure. Supporting such cases would require LLVM
-          // to support something akin "scalable arrays" of vectors.
-          if (llvm::is_contained(type.getScalableDims().drop_back(), true))
-            return failure();
-
-          ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
-          int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
-          for (int64_t i = 0; i < numVectors; i++)
-            result.push_back(vectorType);
-
-          return success();
-        });
-
-    converter.addTargetMaterialization(
-        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
-            Location loc) -> SmallVector<Value> {
-          // from ('vector<4x4xf32>')
-          // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
-          // 'vector<4xf32>')
-          Type ty = resultTypes[0];
-          for (Type ithTy : resultTypes)
-            if (ithTy != ty)
-              return {};
-
-          if (!isa<VectorType>(ty))
-            return {};
-
-          if (inputs.size() != 1)
-            return {};
-
-          Type inputTy = inputs[0].getType();
-          if (!isa<VectorType>(inputTy))
-            return {};
-
-          VectorType inputVectorTy = cast<VectorType>(inputTy);
-          ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
-          size_t numElements =
-              ShapedType::getNumElements(inputShape.drop_back());
-          if (numElements != resultTypes.size())
-            return {};
-
-          return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
-                                                    inputs)
-              .getResults();
-        });
-  }
-
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMConversionPatterns(
       converter, patterns, reassociateFPReductions, force32BitVectorIndices,
-      useVectorAlignment);
+      useVectorAlignment, enableOneToNConversion);
 
   // Architecture specific augmentations.
   LLVMConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
new file mode 100644
index 0000000000000..63316a38332f4
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt  --convert-vector-to-llvm="enable-one-to-n-conversion=true" --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @vector_extract_vector(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4x4xf32>
+func.func @vector_extract_vector(%src: vector<4x4xf32>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+  // CHECK-NEXT: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<4x4xf32> to vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+  %row0 = vector.extract %src[0] : vector<4xf32> from vector<4x4xf32>
+  %row1 = vector.extract %src[1] : vector<4xf32> from vector<4x4xf32>
+  %row2 = vector.extract %src[2] : vector<4xf32> from vector<4x4xf32>
+  %row3 = vector.extract %src[3] : vector<4xf32> from vector<4x4xf32>
+  // CHECK-NEXT: return %[[CAST]]#3, %[[CAST]]#2, %[[CAST]]#1, %[[CAST]]#0
+  return %row3, %row2, %row1, %row0 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_linearize(
+// CHECK-SAME: %[[ARG0:.+]]: vector<5x4x3xf32>
+func.func @vector_extract_linearize(%src: vector<5x4x3xf32>) -> (vector<3xf32>, vector<3xf32>, vector<3xf32>) {
+  // CHECK-NEXT: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<5x4x3xf32> to vector<3xf32>
+  %v00 = vector.extract %src[0, 0] : vector<3xf32> from vector<5x4x3xf32>
+  %v01 = vector.extract %src[0, 1] : vector<3xf32> from vector<5x4x3xf32>
+  %v11 = vector.extract %src[1, 1] : vector<3xf32> from vector<5x4x3xf32>
+  // CHECK-NEXT: return %[[CAST]]#0, %[[CAST]]#1, %[[CAST]]#5
+  return %v00, %v01, %v11 : vector<3xf32>, vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_scalar(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+func.func @vector_extract_scalar(%src: vector<2x2xf32>) -> (f32) {
+  // CHECK: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to vector<2xf32>, vector<2xf32>
+  // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64)
+  // CHECK: %[[EXTRACTED:.+]] = llvm.extractelement %[[CAST]]#0[%[[C0]] : i64]
+  %elem = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
+  // CHECK: return %[[EXTRACTED]]
+  return %elem : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_lhs_multiple(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2x2xf32>)
+func.func @vector_extract_lhs_multiple(%src: vector<2x2x2xf32>) -> vector<2x2xf32> {
+  // CHECK: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2x2xf32> to vector<2xf32>
+  // CHECK: %[[SELECTED:.+]] = builtin.unrealized_conversion_cast %[[CAST]]#0, %[[CAST]]#1 : vector<2xf32>, vector<2xf32>
+  %slice = vector.extract %src[0] : vector<2x2xf32> from vector<2x2x2xf32>
+  // CHECK: return %[[SELECTED]]
+  return %slice : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_rank_0(
+// CHECK-SAME: %[[ARG0:.+]]: vector<f32>)
+func.func @vector_extract_rank_0(%src: vector<f32>) -> f32 {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<f32> to vector<1xf32>
+  // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64)
+  // CHECK: %[[ELEM:.+]] = llvm.extractelement %[[CAST]][%[[C0]] : i64]
+  %elem = vector.extract %src[] : f32 from vector<f32>
+  // CHECK: return %[[ELEM]]
+  return %elem : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_vector(
+// CHECK-SAME: %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>, %[[VAL2:.+]]: vector<4xf32>, %[[AGG:.+]]: vector<4x4xf32>)
+func.func @vector_insert_vector(%row0: vector<4xf32>, %row1: vector<4xf32>, %row2: vector<4xf32>, %init: vector<4x4xf32>) -> (vector<4x4xf32>) {
+  // CHECK: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[AGG]] : vector<4x4xf32> to vector<4xf32>
+  %acc0 = vector.insert %row0, %init[0] : vector<4xf32> into vector<4x4xf32>
+  %acc1 = vector.insert %row1, %acc0[1] : vector<4xf32> into vector<4x4xf32>
+  %acc2 = vector.insert %row2, %acc1[2] : vector<4xf32> into vector<4x4xf32>
+
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL0]], %[[VAL1]], %[[VAL2]], %[[CAST]]#3
+  // CHECK: return %[[INSERTION_CAST]]
+  return %acc2 : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_linearize(
+// CHECK-SAME: %[[VAL:.+]]: vector<3xf32>, %[[AGG:.+]]: vector<5x4x3xf32>)
+func.func @vector_insert_linearize(%row: vector<3xf32>, %init: vector<5x4x3xf32>) -> (vector<5x4x3xf32>) {
+  // CHECK: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[AGG]] : vector<5x4x3xf32> to vector<3xf32>
+
+  %acc0 = vector.insert %row, %init[0, 0] : vector<3xf32> into vector<5x4x3xf32>
+  %acc1 = vector.insert %row, %acc0[0, 1] : vector<3xf32> into vector<5x4x3xf32>
+  %acc2 = vector.insert %row, %acc1[1, 1] : vector<3xf32> into vector<5x4x3xf32>
+
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL]], %[[VAL]], %[[CAST]]#2, %[[CAST]]#3, %[[CAST]]#4, %[[VAL]]
+  // CHECK: return %[[INSERTION_CAST]]
+  return %acc2 : vector<5x4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_scalar(
+// CHECK-SAME: %[[VAL:.+]]: f32, %[[AGG:.+]]: vector<2x2xf32>)
+func.func @vector_insert_scalar(%elem: f32, %init: vector<2x2xf32>) -> (vector<2x2xf32>) {
+  // CHECK-DAG: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[AGG]] : vector<2x2xf32> to vector<2xf32>
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i64)
+  // CHECK: %[[MODIFIED_VECTOR:.+]] = llvm.insertelement %[[VAL]], %[[CAST]]#0[%[[C1]] : i64]
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[MODIFIED_VECTOR]], %[[CAST]]#1
+  %updated = vector.insert %elem, %init[0, 1] : f32 into vector<2x2xf32>
+
+  // CHECK: return %[[INSERTION_CAST]]
+  return %updated : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_sources_multiple(
+// CHECK-SAME: %[[TO_STORE:.+]]: vector<2x2xf32>,
+// CHECK-SAME: %[[DEST:.+]]: vector<2x2x2xf32>
+func.func @vector_insert_sources_multiple(%slice: vector<2x2xf32>, %init: vector<2x2x2xf32>) -> (vector<2x2x2xf32>) {
+  // CHECK: %[[CAST_DEST:.+]]:4 = builtin.unrealized_conversion_cast %[[DEST]] : vector<2x2x2xf32> to vector<2xf32>
+  // CHECK: %[[CAST_TO_STORE:.+]]:2 = builtin.unrealized_conversion_cast %[[TO_STORE]] : vector<2x2xf32> to vector<2xf32>
+  // CHECK: %[[INSERT:.+]] = builtin.unrealized_conversion_cast %[[CAST_TO_STORE]]#0, %[[CAST_TO_STORE]]#1, %[[CAST_DEST]]#2, %[[CAST_DEST]]#3
+
+  %updated = vector.insert %slice, %init[0] : vector<2x2xf32> into vector<2x2x2xf32>
+  // CHECK: return %[[INSERT]]
+  return %updated : vector<2x2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_rank_0(
+// CHECK-SAME: %[[TO_STORE:.+]]: f32,
+// CHECK-SAME: %[[DEST:.+]]: vector<f32>
+func.func @vector_insert_rank_0(%elem: f32, %init: vector<f32>) -> (vector<f32>) {
+  // CHECK: %[[CAST_DEST:.+]] = builtin.unrealized_conversion_cast %[[DEST]] : vector<f32> to vector<1xf32>
+  // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64)
+  // CHECK: %[[INSERT:.+]] = llvm.insertelement %[[TO_STORE]], %[[CAST_DEST]][%[[C0]] : i64]
+  // CHECK: %[[INSERT_CAST:.+]] = builtin.unrealized_conversion_cast %[[INSERT]] : vector<1xf32> to vector<f32>
+  %updated = vector.insert %elem, %init[] : f32 into vector<f32>
+  // CHECK: return %[[INSERT_CAST]]
+  return %updated : vector<f32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index db941454f8d8c..544b9521c5d79 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -51,8 +51,8 @@ struct TestVectorToVectorLowering
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<affine::AffineDialect>();
-    registry.insert<vector::VectorDialect>();
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    vector::VectorDialect>();
   }
 
   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
@@ -150,6 +150,9 @@ struct TestVectorUnrollingPatterns
     return "Test lowering patterns to unroll contract ops in the vector "
            "dialect";
   }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
   TestVectorUnrollingPatterns() = default;
   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
       : PassWrapper(pass) {}



More information about the Mlir-commits mailing list