[Mlir-commits] [mlir] [mlir][vector] Add 1:N vector to llvm conversion (PR #174240)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Fri Jan 2 12:32:13 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/174240
>From 1e77401114d03ef989f59c18bc856b8c7a9fffbf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 22 Dec 2025 10:01:26 -0500
Subject: [PATCH 1/2] Use one to n type conversion for vector.
---
mlir/include/mlir/Conversion/Passes.td | 2 +
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 65 +++++++++++++++++++
2 files changed, 67 insertions(+)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..8c5af3c8529b0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1578,6 +1578,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
)}]>,
+    // Opt-in switch read by ConvertVectorToLLVMPass::runOnOperation.
+    Option<"enableOneToNConversion", "enable-one-to-n-conversion",
+           "bool", /*default=*/"false",
+           "Convert each n-D vector into a flat list of 1-D vectors (1:N type "
+           "conversion).">,
];
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..43fd1843cda28 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -114,6 +114,71 @@ void ConvertVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LowerToLLVMOptions options(&getContext());
LLVMTypeConverter converter(&getContext(), options);
+
+  if (enableOneToNConversion) {
+    // 1:N type conversion for vectors: an n-D vector is converted into one
+    // 1-D vector per combination of its leading (non-trailing) dimensions,
+    // e.g. vector<4x4xf32> becomes four vector<4xf32>. A 0-D vector becomes a
+    // single vector<1xT>.
+    converter.addConversion(
+        [&](VectorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          Type elementType = converter.convertType(type.getElementType());
+          if (!elementType)
+            return failure();
+          if (type.getShape().empty()) {
+            result.push_back(VectorType::get({1}, elementType));
+            return success();
+          }
+          Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                            type.getScalableDims().back());
+          assert(LLVM::isCompatibleVectorType(vectorType) &&
+                 "expected vector type compatible with the LLVM dialect");
+          // For n-D vector types for which a _non-trailing_ dim is scalable,
+          // return a failure. Supporting such cases would require LLVM
+          // to support something akin to "scalable arrays" of vectors.
+          if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+            return failure();
+
+          // Emit one 1-D vector type per element of the leading dimensions.
+          ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
+          int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
+          result.append(static_cast<size_t>(numVectors), vectorType);
+
+          return success();
+        });
+
+    // Materializes the N converted values from a single original vector, e.g.
+    //   from ('vector<4x4xf32>')
+    //   to   ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
+    //         'vector<4xf32>')
+    // Returns an empty vector when the request does not have that shape.
+    converter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) -> SmallVector<Value> {
+          // Guard the `resultTypes[0]` / `inputs[0]` accesses below.
+          if (resultTypes.empty() || inputs.size() != 1)
+            return {};
+
+          // All results must be the same vector type.
+          Type ty = resultTypes[0];
+          if (!isa<VectorType>(ty) ||
+              !llvm::all_of(resultTypes, [&](Type t) { return t == ty; }))
+            return {};
+
+          auto inputVectorTy = dyn_cast<VectorType>(inputs[0].getType());
+          if (!inputVectorTy)
+            return {};
+
+          // The number of results must match the number of 1-D vectors the
+          // input decomposes into. A 0-D input has an empty shape, on which
+          // drop_back() is not valid; it has no leading dims and converts to
+          // exactly one vector.
+          ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
+          ArrayRef<int64_t> leadingDims =
+              inputShape.empty() ? inputShape : inputShape.drop_back();
+          int64_t numVectors = ShapedType::getNumElements(leadingDims);
+          if (numVectors != static_cast<int64_t>(resultTypes.size()))
+            return {};
+
+          return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
+                                                    inputs)
+              .getResults();
+        });
+  }
+
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMConversionPatterns(
>From 897cdf4f2dda57b3648aacdc65fdf1dd0d46dc87 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 31 Dec 2025 11:21:21 -0500
Subject: [PATCH 2/2] [mlir] Add 1:N conversion for vector operations.
* Add lowering for vector.insert
* Add lowering for vector.extract
---
.../VectorToLLVM/ConvertVectorToLLVM.h | 5 +
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 176 ++++++++++++++++++
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 11 +-
.../VectorToLLVM/vector-to-llvm-one-to-n.mlir | 85 +++++++++
4 files changed, 274 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index cfb6cc313bc63..4f52950d597e5 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -19,6 +19,11 @@ void populateVectorToLLVMConversionPatterns(
bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
bool useVectorAlignment = false);
+/// Collect a set of patterns to convert from the Vector dialect to LLVM for
+/// use with the 1:N vector type conversion. The set includes 1:N-aware
+/// lowerings for `vector.extract` and `vector.insert`. The boolean parameters
+/// mirror those of `populateVectorToLLVMConversionPatterns`.
+void populateVectorOneToNLLVMConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+ bool useVectorAlignment = false);
+
namespace vector {
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05d541fe80356..0df169f99dc25 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1088,6 +1088,82 @@ class VectorShuffleOpConversion
}
};
+/// Lowers `vector.extract` when vectors use the 1:N type conversion, i.e. when
+/// `adaptor.getSource()` holds several 1-D vectors standing in for one n-D
+/// source vector.
+///
+/// The rewrite either forwards one of those 1-D vectors (vector result) or
+/// emits an `llvm.extractelement` on one of them (scalar result). Only sources
+/// of rank >= 2 with a fully static position are handled.
+class VectorExtractOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType sourceTy = extractOp.getSourceVectorType();
+    if (sourceTy.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "source vector type must be rank 2 or higher.");
+
+    if (adaptor.getDynamicPosition().size())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "position must be statically known.");
+
+    Type resultType = extractOp.getResult().getType();
+    Type llvmResultType = typeConverter->convertType(resultType);
+    if (!llvmResultType)
+      return rewriter.notifyMatchFailure(extractOp, "type conversion failed.");
+
+    SmallVector<Value> sources = adaptor.getSource();
+    ArrayRef<int64_t> position = adaptor.getStaticPosition();
+
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank, i.e., the number of indices is
+    // equal to source vector rank.
+    bool extractsScalar =
+        static_cast<int64_t>(position.size()) == sourceTy.getRank();
+
+    // The part of the position that selects a 1-D vector. The trailing index
+    // of a full-rank position addresses an element inside that vector.
+    ArrayRef<int64_t> vectorPosition =
+        extractsScalar ? position.drop_back() : position;
+    ArrayRef<int64_t> leadingDims = sourceTy.getShape().drop_back();
+    if (vectorPosition.size() != leadingDims.size())
+      return rewriter.notifyMatchFailure(
+          extractOp, "position must select a single 1-D vector or a scalar.");
+
+    // Row-major linearization over the leading dimensions only. The trailing
+    // dimension lives inside each 1-D vector and must not contribute to the
+    // index: for vector<5x4x3xf32>, position [1, 0] is vector 1 * 4 + 0 = 4.
+    int64_t linearIdx = 0;
+    for (auto [offset, dim] : llvm::zip(vectorPosition, leadingDims)) {
+      if (offset < 0 || offset >= dim)
+        return rewriter.notifyMatchFailure(
+            extractOp, "static position must be within bounds.");
+      linearIdx = linearIdx * dim + offset;
+    }
+    if (linearIdx >= static_cast<int64_t>(sources.size()))
+      return rewriter.notifyMatchFailure(
+          extractOp, "unexpected number of converted source values.");
+    if (extractsScalar && position.back() < 0)
+      return rewriter.notifyMatchFailure(
+          extractOp, "static position must be within bounds.");
+
+    Value replacement = sources[linearIdx];
+    if (extractsScalar) {
+      Location loc = extractOp.getLoc();
+      Type llvmIdxType = typeConverter->convertType(rewriter.getIndexType());
+      assert(llvmIdxType && "expected type conversion to succeed.");
+      auto posAttr = rewriter.getIntegerAttr(llvmIdxType, position.back());
+      Value pos =
+          arith::ConstantOp::create(rewriter, loc, llvmIdxType, posAttr);
+      replacement = LLVM::ExtractElementOp::create(
+          rewriter, loc, replacement, getAsLLVMValue(rewriter, loc, pos));
+    }
+
+    rewriter.replaceOp(extractOp, replacement);
+    return success();
+  }
+};
+
+
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1191,6 +1267,79 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
}
};
+/// Lowers `vector.insert` when vectors use the 1:N type conversion, i.e. when
+/// `adaptor.getDest()` holds several 1-D vectors standing in for one n-D
+/// destination vector.
+///
+/// Inserting a 1-D vector swaps out one of those values; inserting a scalar
+/// emits an `llvm.insertelement` on one of them. The op is replaced by the
+/// updated list of values. Only fully static positions are handled.
+class VectorInsertOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::InsertOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (adaptor.getDynamicPosition().size())
+      return rewriter.notifyMatchFailure(
+          insertOp, "position is expected to be statically known.");
+
+    // `valueToStore[0]` is read below, so require exactly one value.
+    SmallVector<Value> valueToStore = adaptor.getValueToStore();
+    if (valueToStore.size() != 1)
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "expected to insert single value into a collection of vectors.");
+
+    // A rank-0 destination has an empty shape with no trailing dim to drop.
+    VectorType destVectorType = insertOp.getDestVectorType();
+    if (destVectorType.getRank() == 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "destination vector type must be rank 1 or higher.");
+
+    SmallVector<Value> dest = adaptor.getDest();
+    ArrayRef<int64_t> position = adaptor.getStaticPosition();
+
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool insertIntoInnermostDim =
+        static_cast<int64_t>(position.size()) == destVectorType.getRank();
+
+    // The part of the position that selects a 1-D vector. The trailing index
+    // of a full-rank position addresses an element inside that vector.
+    ArrayRef<int64_t> vectorPosition =
+        insertIntoInnermostDim ? position.drop_back() : position;
+    ArrayRef<int64_t> leadingDims = destVectorType.getShape().drop_back();
+    if (vectorPosition.size() != leadingDims.size())
+      return rewriter.notifyMatchFailure(
+          insertOp, "position must select a single 1-D vector or a scalar.");
+
+    // `llvm.insertelement` needs a scalar operand.
+    if (insertIntoInnermostDim && isa<VectorType>(valueToStore[0].getType()))
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "expected a scalar when inserting into the innermost dimension.");
+
+    // Row-major linearization over the leading dimensions only. The trailing
+    // dimension lives inside each 1-D vector and must not contribute to the
+    // index: for vector<5x4x3xf32>, position [1, 1] is vector 1 * 4 + 1 = 5.
+    int64_t linearIdx = 0;
+    for (auto [offset, dim] : llvm::zip(vectorPosition, leadingDims)) {
+      if (offset < 0 || offset >= dim)
+        return rewriter.notifyMatchFailure(
+            insertOp, "static position must be within bounds.");
+      linearIdx = linearIdx * dim + offset;
+    }
+    if (linearIdx >= static_cast<int64_t>(dest.size()))
+      return rewriter.notifyMatchFailure(
+          insertOp, "unexpected number of converted destination values.");
+    if (insertIntoInnermostDim && position.back() < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "static position must be within bounds.");
+
+    if (insertIntoInnermostDim) {
+      Location loc = insertOp->getLoc();
+      Type llvmIdxType = typeConverter->convertType(rewriter.getIndexType());
+      assert(llvmIdxType && "expected type conversion to succeed.");
+      auto posAttr = rewriter.getIntegerAttr(llvmIdxType, position.back());
+      Value pos =
+          arith::ConstantOp::create(rewriter, loc, llvmIdxType, posAttr);
+      Value vectorToUpdate = dest[linearIdx];
+      dest[linearIdx] = LLVM::InsertElementOp::create(
+          rewriter, loc, vectorToUpdate.getType(), vectorToUpdate,
+          valueToStore[0], pos);
+    } else {
+      // The inserted value replaces a whole 1-D vector.
+      dest[linearIdx] = valueToStore[0];
+    }
+
+    rewriter.replaceOpWithMultiple(insertOp, {dest});
+    return success();
+  }
+};
+
+
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2195,6 +2344,33 @@ void mlir::vector::populateVectorTransposeToFlatTranspose(
benefit);
}
+/// Populate the given list with the Vector-to-LLVM patterns used together
+/// with the 1:N vector type conversion. `vector.extract` and `vector.insert`
+/// are handled by their 1:N-aware conversions.
+void mlir::populateVectorOneToNLLVMConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool reassociateFPReductions, bool force32BitVectorIndices,
+    bool useVectorAlignment) {
+  MLIRContext *context = converter.getDialect()->getContext();
+
+  // Patterns configured by one of the boolean options.
+  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+  patterns.add<VectorCreateMaskOpConversion>(context, force32BitVectorIndices);
+  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+               VectorLoadStoreConversion<vector::MaskedLoadOp>,
+               VectorLoadStoreConversion<vector::StoreOp>,
+               VectorLoadStoreConversion<vector::MaskedStoreOp>,
+               VectorGatherOpConversion, VectorScatterOpConversion>(
+      converter, useVectorAlignment);
+
+  // Patterns that only need the type converter. They are registered in the
+  // same relative order as a single combined `add` call would use.
+  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
+               VectorFMAOp1DConversion, VectorPrintOpConversion>(converter);
+  patterns.add<VectorExtractOneToNOpConversion,
+               VectorInsertOneToNOpConversion>(converter);
+  patterns.add<VectorTypeCastOpConversion, VectorScaleOpConversion,
+               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+               VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToNdLowering>(converter);
+  patterns.add<VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+               MaskedReductionOpConversion, VectorInterleaveOpLowering,
+               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+               VectorToElementsLowering, VectorScalableStepOpLowering>(
+      converter);
+}
+
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 43fd1843cda28..93609d147e240 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -181,9 +181,14 @@ void ConvertVectorToLLVMPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
- populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, force32BitVectorIndices,
- useVectorAlignment);
+ // Use the 1:N pattern set when the `enable-one-to-n-conversion` option is
+ // set; otherwise fall back to the default Vector-to-LLVM patterns.
+ if (enableOneToNConversion)
+ populateVectorOneToNLLVMConversionPatterns(
+ converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+ useVectorAlignment);
+ else
+ populateVectorToLLVMConversionPatterns(
+ converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+ useVectorAlignment);
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
new file mode 100644
index 0000000000000..b90bac1f0c8cf
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt --convert-vector-to-llvm="enable-one-to-n-conversion=true" --split-input-file %s | FileCheck %s
+
+// Extracting a row of a 2-D vector forwards the matching result of the
+// conversion cast; no extra ops are expected.
+// CHECK-LABEL: func @vector_extract_vector(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4x4xf32>
+func.func @vector_extract_vector(%arg0: vector<4x4xf32>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+ // CHECK-NEXT: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<4x4xf32> to vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<4x4xf32>
+ %1 = vector.extract %arg0[1] : vector<4xf32> from vector<4x4xf32>
+ %2 = vector.extract %arg0[2] : vector<4xf32> from vector<4x4xf32>
+ %3 = vector.extract %arg0[3] : vector<4xf32> from vector<4x4xf32>
+ // CHECK-NEXT: return %[[CAST]]#3, %[[CAST]]#2, %[[CAST]]#1, %[[CAST]]#0
+ return %3, %2, %1, %0 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// A vector<5x4x3xf32> converts to 20 values of type vector<3xf32>. Position
+// [i, j] is linearized row-major over the leading 5x4 dims, i.e. it selects
+// value i * 4 + j: [0, 0] -> #0, [0, 1] -> #1, [1, 0] -> #4.
+// CHECK-LABEL: func @vector_extract_linearize(
+// CHECK-SAME: %[[ARG0:.+]]: vector<5x4x3xf32>
+func.func @vector_extract_linearize(%arg0: vector<5x4x3xf32>) -> (vector<3xf32>, vector<3xf32>, vector<3xf32>) {
+  // CHECK-NEXT: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<5x4x3xf32> to vector<3xf32>
+  %0 = vector.extract %arg0[0, 0] : vector<3xf32> from vector<5x4x3xf32>
+  %1 = vector.extract %arg0[0, 1] : vector<3xf32> from vector<5x4x3xf32>
+  %2 = vector.extract %arg0[1, 0] : vector<3xf32> from vector<5x4x3xf32>
+  // CHECK-NEXT: return %[[CAST]]#0, %[[CAST]]#1, %[[CAST]]#4
+  return %0, %1, %2 : vector<3xf32>, vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// A full-rank extract picks one converted 1-D vector and reads the element
+// with llvm.extractelement at a constant index.
+// CHECK-LABEL: func @vector_extract_scalar(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+func.func @vector_extract_scalar(%arg0: vector<2x2xf32>) -> (f32) {
+ // CHECK: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to vector<2xf32>, vector<2xf32>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : i64
+ // CHECK: %[[EXTRACTED:.+]] = llvm.extractelement %[[CAST]]#0[%[[C0]] : i64]
+ %0 = vector.extract %arg0[0, 0] : f32 from vector<2x2xf32>
+ // CHECK: return %[[EXTRACTED]]
+ return %0 : f32
+}
+
+// -----
+
+// Inserting rows into a 2-D vector replaces the matching converted values;
+// the untouched row (#3) is still taken from the conversion cast.
+// CHECK-LABEL: func @vector_insert_vector(
+// CHECK-SAME: %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>, %[[VAL2:.+]]: vector<4xf32>, %[[AGG:.+]]: vector<4x4xf32>)
+func.func @vector_insert_vector(%val0: vector<4xf32>, %val1: vector<4xf32>, %val2: vector<4xf32>, %agg: vector<4x4xf32>) -> (vector<4x4xf32>) {
+ // CHECK: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[AGG]] : vector<4x4xf32> to vector<4xf32>
+ %0 = vector.insert %val0, %agg[0] : vector<4xf32> into vector<4x4xf32>
+ %1 = vector.insert %val1, %0[1] : vector<4xf32> into vector<4x4xf32>
+ %2 = vector.insert %val2, %1[2] : vector<4xf32> into vector<4x4xf32>
+
+ // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL0]], %[[VAL1]], %[[VAL2]], %[[CAST]]#3
+ // CHECK: return %[[INSERTION_CAST]]
+ return %2 : vector<4x4xf32>
+}
+
+// -----
+
+// A vector<5x4x3xf32> converts to 20 values of type vector<3xf32>. Position
+// [i, j] is linearized row-major over the leading 5x4 dims, i.e. it updates
+// value i * 4 + j: [0, 0] -> 0, [0, 1] -> 1, [1, 1] -> 5. Values 2, 3 and 4
+// are still taken from the conversion cast.
+// CHECK-LABEL: func @vector_insert_linearize(
+// CHECK-SAME: %[[VAL:.+]]: vector<3xf32>, %[[AGG:.+]]: vector<5x4x3xf32>)
+func.func @vector_insert_linearize(%val: vector<3xf32>, %agg: vector<5x4x3xf32>) -> (vector<5x4x3xf32>) {
+  // CHECK: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[AGG]] : vector<5x4x3xf32> to vector<3xf32>
+
+  %0 = vector.insert %val, %agg[0, 0] : vector<3xf32> into vector<5x4x3xf32>
+  %1 = vector.insert %val, %0[0, 1] : vector<3xf32> into vector<5x4x3xf32>
+  %2 = vector.insert %val, %1[1, 1] : vector<3xf32> into vector<5x4x3xf32>
+
+  // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL]], %[[VAL]], %[[CAST]]#2, %[[CAST]]#3, %[[CAST]]#4, %[[VAL]]
+  // CHECK: return %[[INSERTION_CAST]]
+  return %2 : vector<5x4x3xf32>
+}
+
+// -----
+
+// A full-rank insert updates one converted 1-D vector with
+// llvm.insertelement; the other converted vector is passed through unchanged.
+// CHECK-LABEL: func @vector_insert_scalar(
+// CHECK-SAME: %[[VAL:.+]]: f32, %[[AGG:.+]]: vector<2x2xf32>)
+func.func @vector_insert_scalar(%val: f32, %agg: vector<2x2xf32>) -> (vector<2x2xf32>) {
+ // CHECK-DAG: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[AGG]] : vector<2x2xf32> to vector<2xf32>
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
+ // CHECK: %[[MODIFIED_VECTOR:.+]] = llvm.insertelement %[[VAL]], %[[CAST]]#0[%[[C1]] : i64]
+ // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[MODIFIED_VECTOR]], %[[CAST]]#1
+ %0 = vector.insert %val, %agg[0, 1] : f32 into vector<2x2xf32>
+
+ // CHECK: return %[[INSERTION_CAST]]
+ return %0 : vector<2x2xf32>
+}
More information about the Mlir-commits
mailing list