[Mlir-commits] [mlir] [mlir][vector] Add 1:N vector to llvm conversion (PR #174240)

Erick Ochoa Lopez llvmlistbot at llvm.org
Fri Jan 2 12:31:03 PST 2026


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/174240

None

>From 1e77401114d03ef989f59c18bc856b8c7a9fffbf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 22 Dec 2025 10:01:26 -0500
Subject: [PATCH 1/2] Use one to n type conversion for vector.

---
 mlir/include/mlir/Conversion/Passes.td        |  2 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  | 65 +++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..8c5af3c8529b0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1578,6 +1578,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
             "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
           )}]>,
+    Option<"enableOneToNConversion", "enable-one-to-n-conversion",
+           "bool", /*default=*/"false", "1:N conversion">,
   ];
 }
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..43fd1843cda28 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -114,6 +114,71 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
+
+  if (enableOneToNConversion) {
+
+    converter.addConversion(
+        [&](VectorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          auto elementType = converter.convertType(type.getElementType());
+          if (!elementType)
+            return failure();
+          if (type.getShape().empty()) {
+            result.push_back(VectorType::get({1}, elementType));
+            return success();
+          }
+          Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                            type.getScalableDims().back());
+          assert(LLVM::isCompatibleVectorType(vectorType) &&
+                 "expected vector type compatible with the LLVM dialect");
+          // For n-D vector types for which a _non-trailing_ dim is scalable,
+          // return a failure. Supporting such cases would require LLVM
+          // to support something akin to "scalable arrays" of vectors.
+          if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+            return failure();
+
+          ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
+          int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
+          for (int64_t i = 0; i < numVectors; i++)
+            result.push_back(vectorType);
+
+          return success();
+        });
+
+    converter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) -> SmallVector<Value> {
+          // from ('vector<4x4xf32>')
+          // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
+          // 'vector<4xf32>')
+          Type ty = resultTypes[0];
+          for (Type ithTy : resultTypes)
+            if (ithTy != ty)
+              return {};
+
+          if (!isa<VectorType>(ty))
+            return {};
+
+          if (inputs.size() != 1)
+            return {};
+
+          Type inputTy = inputs[0].getType();
+          if (!isa<VectorType>(inputTy))
+            return {};
+
+          VectorType inputVectorTy = cast<VectorType>(inputTy);
+          ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
+          size_t numElements =
+              ShapedType::getNumElements(inputShape.drop_back());
+          if (numElements != resultTypes.size())
+            return {};
+
+          return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
+                                                    inputs)
+              .getResults();
+        });
+  }
+
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMConversionPatterns(

>From 398ef7a6796fc6586ade0d26393f0b79ccf15882 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 31 Dec 2025 11:21:21 -0500
Subject: [PATCH 2/2] Ok, this looks like it is looking ok

---
 .../VectorToLLVM/ConvertVectorToLLVM.h        |   5 +
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 176 ++++++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  11 +-
 .../VectorToLLVM/vector-to-llvm-one-to-n.mlir |  85 +++++++++
 4 files changed, 274 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir

diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index cfb6cc313bc63..4f52950d597e5 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -19,6 +19,11 @@ void populateVectorToLLVMConversionPatterns(
     bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
     bool useVectorAlignment = false);
 
+void populateVectorOneToNLLVMConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+    bool useVectorAlignment = false);
+
 namespace vector {
 void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
 }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05d541fe80356..0df169f99dc25 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1088,6 +1088,82 @@ class VectorShuffleOpConversion
   }
 };
 
+/// Lowers `vector.extract` with static indices when the n-D source vector has
+/// been 1:N converted into one 1-D vector per combination of leading indices.
+/// The converted values are addressed in row-major order of the leading
+/// dimensions; a full-rank position additionally emits `llvm.extractelement`
+/// on the selected 1-D vector to produce the scalar.
+class VectorExtractOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType sourceTy = extractOp.getSourceVectorType();
+    if (sourceTy.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "source vector type must be rank 2 or higher.");
+
+    if (adaptor.getDynamicPosition().size())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "position must be statically known.");
+
+    // Negative static indices cannot address a converted value; reject them
+    // rather than indexing out of bounds below.
+    ArrayRef<int64_t> position = adaptor.getStaticPosition();
+    if (llvm::any_of(position, [](int64_t index) { return index < 0; }))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "position must be non-negative.");
+
+    Type resultType = extractOp.getResult().getType();
+    Type llvmResultType = typeConverter->convertType(resultType);
+    if (!llvmResultType)
+      return rewriter.notifyMatchFailure(extractOp, "type conversion failed.");
+
+    // Linearize, in row-major order, the indices that select a 1-D vector.
+    // The innermost dimension lives inside each converted value, so neither
+    // its size nor its index takes part in the linear index. For example,
+    // position [1, 0] in vector<5x4x3xf32> selects value 1 * 4 + 0 = 4 of 20.
+    ArrayRef<int64_t> leadingDims = sourceTy.getShape().drop_back();
+    ArrayRef<int64_t> vectorPos = position.take_front(leadingDims.size());
+    int64_t linearIdx = 0;
+    for (auto [index, dimSize] : llvm::zip(vectorPos, leadingDims))
+      linearIdx = linearIdx * dimSize + index;
+    // Leading dimensions not covered by `position` still scale the index.
+    linearIdx *= ShapedType::getNumElements(
+        leadingDims.drop_front(vectorPos.size()));
+
+    SmallVector<Value> sources = adaptor.getSource();
+    if (linearIdx >= static_cast<int64_t>(sources.size()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "position is out of range.");
+
+    Value replacement = sources[linearIdx];
+    Location loc = extractOp.getLoc();
+
+    // A full-rank position addresses a scalar inside the selected 1-D vector,
+    // so extract that element using the innermost index.
+    bool extractsScalar =
+        static_cast<int64_t>(position.size()) == sourceTy.getRank();
+    if (extractsScalar) {
+      Type idxType = rewriter.getIndexType();
+      Type llvmIdxType = typeConverter->convertType(idxType);
+      assert(llvmIdxType && "expected type conversion to succeed.");
+      auto posAttr = rewriter.getIntegerAttr(llvmIdxType, position.back());
+      Value pos =
+          arith::ConstantOp::create(rewriter, loc, llvmIdxType, posAttr);
+      replacement = LLVM::ExtractElementOp::create(
+          rewriter, loc, replacement, getAsLLVMValue(rewriter, loc, pos));
+    }
+
+    rewriter.replaceOp(extractOp, replacement);
+
+    return success();
+  }
+};
+
 class VectorExtractOpConversion
     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
 public:
@@ -1191,6 +1267,79 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
   }
 };
 
+/// Lowers `vector.insert` with static indices when the n-D destination vector
+/// has been 1:N converted into one 1-D vector per combination of leading
+/// indices, addressed in row-major order. Inserting a 1-D vector swaps the
+/// selected converted value; inserting a scalar rewrites the selected value
+/// with `llvm.insertelement`. All other converted values are forwarded as is.
+class VectorInsertOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::InsertOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getDynamicPosition().size())
+      return rewriter.notifyMatchFailure(
+          insertOp, "position is expected to be statically known.");
+
+    SmallVector<Value> valueToStore = adaptor.getValueToStore();
+    if (valueToStore.size() != 1)
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "expected to insert single value into a collection of vectors.");
+
+    // 0-D destinations have no leading dimensions to linearize, and negative
+    // static indices cannot address a converted value; reject both rather
+    // than indexing out of bounds below.
+    VectorType destVectorType = insertOp.getDestVectorType();
+    ArrayRef<int64_t> positionVec = adaptor.getStaticPosition();
+    if (destVectorType.getRank() == 0 ||
+        llvm::any_of(positionVec, [](int64_t index) { return index < 0; }))
+      return rewriter.notifyMatchFailure(
+          insertOp, "expected non-0-D destination and non-negative position.");
+
+    // Linearize, in row-major order, the indices that select a 1-D vector.
+    // The innermost dimension lives inside each converted value, so neither
+    // its size nor its index takes part in the linear index. For example,
+    // position [1, 1] in vector<5x4x3xf32> selects value 1 * 4 + 1 = 5 of 20.
+    ArrayRef<int64_t> leadingDims = destVectorType.getShape().drop_back();
+    ArrayRef<int64_t> vectorPos = positionVec.take_front(leadingDims.size());
+    int64_t linearIdx = 0;
+    for (auto [index, dimSize] : llvm::zip(vectorPos, leadingDims))
+      linearIdx = linearIdx * dimSize + index;
+    // Leading dimensions not covered by the position still scale the index.
+    linearIdx *= ShapedType::getNumElements(
+        leadingDims.drop_front(vectorPos.size()));
+
+    SmallVector<Value> dest = adaptor.getDest();
+    if (linearIdx >= static_cast<int64_t>(dest.size()))
+      return rewriter.notifyMatchFailure(insertOp, "position is out of range.");
+
+    // A full-rank position addresses a scalar inside the selected 1-D vector;
+    // otherwise the stored value replaces the selected vector wholesale.
+    bool insertIntoInnermostDim =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+    Value replacement = valueToStore[0];
+    if (insertIntoInnermostDim) {
+      Location loc = insertOp->getLoc();
+      Type llvmIdxType = typeConverter->convertType(rewriter.getIndexType());
+      assert(llvmIdxType && "expected type conversion to succeed.");
+      auto posAttr = rewriter.getIntegerAttr(llvmIdxType, positionVec.back());
+      Value pos =
+          arith::ConstantOp::create(rewriter, loc, llvmIdxType, posAttr);
+      replacement = LLVM::InsertElementOp::create(
+          rewriter, loc, dest[linearIdx].getType(), dest[linearIdx],
+          valueToStore[0], pos);
+    }
+
+    dest[linearIdx] = replacement;
+    rewriter.replaceOpWithMultiple(insertOp, {dest});
+    return success();
+  }
+};
+
 class VectorInsertOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertOp> {
 public:
@@ -2195,6 +2344,33 @@ void mlir::vector::populateVectorTransposeToFlatTranspose(
                                                        benefit);
 }
 
+void mlir::populateVectorOneToNLLVMConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool reassociateFPReductions, bool force32BitVectorIndices,
+    bool useVectorAlignment) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+               VectorLoadStoreConversion<vector::MaskedLoadOp>,
+               VectorLoadStoreConversion<vector::StoreOp>,
+               VectorLoadStoreConversion<vector::MaskedStoreOp>,
+               VectorGatherOpConversion, VectorScatterOpConversion>(
+      converter, useVectorAlignment);
+  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
+               VectorFMAOp1DConversion, VectorPrintOpConversion,
+               VectorExtractOneToNOpConversion, VectorInsertOneToNOpConversion,
+               VectorTypeCastOpConversion, VectorScaleOpConversion,
+               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+               VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToNdLowering,
+               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+               MaskedReductionOpConversion, VectorInterleaveOpLowering,
+               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+               VectorToElementsLowering, VectorScalableStepOpLowering>(
+      converter);
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 43fd1843cda28..93609d147e240 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -181,9 +181,14 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
-  populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, force32BitVectorIndices,
-      useVectorAlignment);
+  if (enableOneToNConversion)
+    populateVectorOneToNLLVMConversionPatterns(
+        converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+        useVectorAlignment);
+  else
+    populateVectorToLLVMConversionPatterns(
+        converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+        useVectorAlignment);
 
   // Architecture specific augmentations.
   LLVMConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
new file mode 100644
index 0000000000000..b90bac1f0c8cf
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt  --convert-vector-to-llvm="enable-one-to-n-conversion=true" --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @vector_extract_vector(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4x4xf32>
+func.func @vector_extract_vector(%arg0: vector<4x4xf32>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+    // CHECK-NEXT: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<4x4xf32> to vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+    %0 = vector.extract %arg0[0] : vector<4xf32> from vector<4x4xf32>
+    %1 = vector.extract %arg0[1] : vector<4xf32> from vector<4x4xf32>
+    %2 = vector.extract %arg0[2] : vector<4xf32> from vector<4x4xf32>
+    %3 = vector.extract %arg0[3] : vector<4xf32> from vector<4x4xf32>
+    // CHECK-NEXT: return %[[CAST]]#3, %[[CAST]]#2, %[[CAST]]#1, %[[CAST]]#0
+    return %3, %2, %1, %0 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
+
+// -----
+// The 5x4 leading dims are linearized row-major: [1, 0] -> 1 * 4 + 0 = 4.
+// CHECK-LABEL: func @vector_extract_linearize(
+// CHECK-SAME: %[[ARG0:.+]]: vector<5x4x3xf32>
+func.func @vector_extract_linearize(%arg0: vector<5x4x3xf32>) -> (vector<3xf32>, vector<3xf32>, vector<3xf32>) {
+    // CHECK-NEXT: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<5x4x3xf32> to vector<3xf32>
+    %0 = vector.extract %arg0[0, 0] : vector<3xf32> from vector<5x4x3xf32>
+    %1 = vector.extract %arg0[0, 1] : vector<3xf32> from vector<5x4x3xf32>
+    %2 = vector.extract %arg0[1, 0] : vector<3xf32> from vector<5x4x3xf32>
+    // CHECK-NEXT: return %[[CAST]]#0, %[[CAST]]#1, %[[CAST]]#4
+    return %0, %1, %2 : vector<3xf32>, vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_scalar(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+func.func @vector_extract_scalar(%arg0: vector<2x2xf32>) -> (f32) {
+  // CHECK: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to vector<2xf32>, vector<2xf32>
+  // CHECK: %[[C0:.+]] = arith.constant 0 : i64
+  // CHECK: %[[EXTRACTED:.+]] = llvm.extractelement %[[CAST]]#0[%[[C0]] : i64]
+  %0 = vector.extract %arg0[0, 0] : f32 from vector<2x2xf32>
+  // CHECK: return %[[EXTRACTED]]
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_vector(
+// CHECK-SAME: %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>, %[[VAL2:.+]]: vector<4xf32>, %[[AGG:.+]]: vector<4x4xf32>)
+func.func @vector_insert_vector(%val0: vector<4xf32>, %val1: vector<4xf32>, %val2: vector<4xf32>, %agg: vector<4x4xf32>) -> (vector<4x4xf32>) {
+  // CHECK: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[AGG]] : vector<4x4xf32> to vector<4xf32>
+  %0 = vector.insert %val0, %agg[0] : vector<4xf32> into vector<4x4xf32>
+  %1 = vector.insert %val1, %0[1] : vector<4xf32> into vector<4x4xf32>
+  %2 = vector.insert %val2, %1[2] : vector<4xf32> into vector<4x4xf32>
+
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL0]], %[[VAL1]], %[[VAL2]], %[[CAST]]#3
+  // CHECK: return %[[INSERTION_CAST]]
+  return %2 : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_linearize(
+// CHECK-SAME: %[[VAL:.+]]: vector<3xf32>, %[[AGG:.+]]: vector<5x4x3xf32>)
+func.func @vector_insert_linearize(%val: vector<3xf32>, %agg: vector<5x4x3xf32>) -> (vector<5x4x3xf32>) {
+  // CHECK: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[AGG]] : vector<5x4x3xf32> to vector<3xf32>
+  // The 5x4 leading dims are linearized row-major: [1, 1] -> 1 * 4 + 1 = 5.
+  %0 = vector.insert %val, %agg[0, 0] : vector<3xf32> into vector<5x4x3xf32>
+  %1 = vector.insert %val, %0[0, 1] : vector<3xf32> into vector<5x4x3xf32>
+  %2 = vector.insert %val, %1[1, 1] : vector<3xf32> into vector<5x4x3xf32>
+
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL]], %[[VAL]], %[[CAST]]#2, %[[CAST]]#3, %[[CAST]]#4, %[[VAL]]
+  // CHECK: return %[[INSERTION_CAST]]
+  return %2 : vector<5x4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_scalar(
+// CHECK-SAME: %[[VAL:.+]]: f32, %[[AGG:.+]]: vector<2x2xf32>)
+func.func @vector_insert_scalar(%val: f32, %agg: vector<2x2xf32>) -> (vector<2x2xf32>) {
+  // CHECK-DAG: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[AGG]] : vector<2x2xf32> to vector<2xf32>
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
+  // CHECK: %[[MODIFIED_VECTOR:.+]] = llvm.insertelement %[[VAL]], %[[CAST]]#0[%[[C1]] : i64]
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[MODIFIED_VECTOR]], %[[CAST]]#1
+  %0 = vector.insert %val, %agg[0, 1] : f32 into vector<2x2xf32>
+
+  // CHECK: return %[[INSERTION_CAST]]
+  return %0 : vector<2x2xf32>
+}



More information about the Mlir-commits mailing list