[Mlir-commits] [mlir] [mlir][vector] Add 1:N vector to llvm conversion (PR #174240)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Jan 5 13:51:54 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/174240
>From 50797d355e8944c9f8b80826ee963756edfdb387 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 22 Dec 2025 10:01:26 -0500
Subject: [PATCH 1/2] Use one to n type conversion for vector.
---
mlir/include/mlir/Conversion/Passes.td | 2 +
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 65 +++++++++++++++++++
2 files changed, 67 insertions(+)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..8c5af3c8529b0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1578,6 +1578,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
)}]>,
+ Option<"enableOneToNConversion", "enable-one-to-n-conversion",
+ "bool", /*default=*/"false", "1:N conversion">,
];
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..43fd1843cda28 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -114,6 +114,71 @@ void ConvertVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LowerToLLVMOptions options(&getContext());
LLVMTypeConverter converter(&getContext(), options);
+
+ if (enableOneToNConversion) {
+
+ converter.addConversion(
+ [&](VectorType type,
+ SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+ auto elementType = converter.convertType(type.getElementType());
+ if (!elementType)
+ return failure();
+ if (type.getShape().empty()) {
+ result.push_back(VectorType::get({1}, elementType));
+ return success();
+ }
+ Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ type.getScalableDims().back());
+ assert(LLVM::isCompatibleVectorType(vectorType) &&
+ "expected vector type compatible with the LLVM dialect");
+ // For n-D vector types for which a _non-trailing_ dim is scalable,
+ // return a failure. Supporting such cases would require LLVM
+ // to support something akin "scalable arrays" of vectors.
+ if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+ return failure();
+
+ ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
+ int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
+ for (int64_t i = 0; i < numVectors; i++)
+ result.push_back(vectorType);
+
+ return success();
+ });
+
+ converter.addTargetMaterialization(
+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc) -> SmallVector<Value> {
+ // from ('vector<4x4xf32>')
+ // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
+ // 'vector<4xf32>')
+ Type ty = resultTypes[0];
+ for (Type ithTy : resultTypes)
+ if (ithTy != ty)
+ return {};
+
+ if (!isa<VectorType>(ty))
+ return {};
+
+ if (inputs.size() != 1)
+ return {};
+
+ Type inputTy = inputs[0].getType();
+ if (!isa<VectorType>(inputTy))
+ return {};
+
+ VectorType inputVectorTy = cast<VectorType>(inputTy);
+ ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
+ size_t numElements =
+ ShapedType::getNumElements(inputShape.drop_back());
+ if (numElements != resultTypes.size())
+ return {};
+
+ return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
+ inputs)
+ .getResults();
+ });
+ }
+
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMConversionPatterns(
>From 634fd863606131665abc36f8c78485b42d9ffc81 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 31 Dec 2025 11:21:21 -0500
Subject: [PATCH 2/2] [mlir] Add 1:N conversion for vector operations.
* Add lowering for vector.insert
* Add lowering for vector.extract
---
.../VectorToLLVM/ConvertVectorToLLVM.h | 4 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 389 +++++++++++++++++-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 66 +--
.../VectorToLLVM/vector-to-llvm-one-to-n.mlir | 140 +++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 7 +-
5 files changed, 531 insertions(+), 75 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index cfb6cc313bc63..d3eedca39ec96 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -15,9 +15,9 @@ class LLVMTypeConverter;
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
- bool useVectorAlignment = false);
+ bool useVectorAlignment = false, bool enableOneToNConversion = false);
namespace vector {
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05d541fe80356..d073603c0241e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1088,6 +1088,161 @@ class VectorShuffleOpConversion
}
};
+/// VectorExtractOpConversion rewritten to use 1:N conversion.
+///
+/// Handles the same cases as VectorExtractOpConversion:
+/// * Mostly static indices (with the exception of the innermost dimension index
+///   which can be used to select a scalar).
+///
+/// Translates:
+///
+/// ```mlir
+/// %tgt = vector.extract %src[0] : vector<2xf32> from vector<2x2xf32>
+/// ```
+///
+/// Since 1:N conversion is used, src is now a collection of vectors, and
+/// the result is replaced (without creating any operations) by the matching
+/// vector(s) of that collection. Following the example above:
+///
+/// %src : {vector<2xf32>, vector<2xf32>}
+/// %tgt = %src[0]
+///
+/// This pattern will insert operations only when extracting scalars.
+/// For example:
+///
+/// ```mlir
+/// %scalar = vector.extract %src[%idx] : f32 from vector<2xf32>
+/// ```
+///
+/// %src : {vector<2xf32>}
+/// %tgt = %src[0]
+///
+/// ```mlir
+/// %scalar = llvm.extractelement %tgt[%idx]
+/// ```
+///
+/// There is another case not present in VectorExtractOpConversion where the
+/// result itself is converted into multiple values. E.g.,
+///
+/// ```mlir
+/// %vec = vector.extract %nd[0] : vector<2x2xf32> from vector<2x2x2xf32>
+/// ```
+///
+/// In this case, source is a collection of four vector<2xf32> and target
+/// is a collection of two vector<2xf32>
+///
+/// %src : {vector<2xf32>, vector<2xf32>, vector<2xf32>, vector<2xf32>}
+/// %tgt : {%src[0], %src[1]}
+class VectorExtractOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type resultType = extractOp.getResult().getType();
+    SmallVector<Type> llvmResultTypes;
+    if (failed(typeConverter->convertType(resultType, llvmResultTypes)))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "expected conversion to succeed.");
+
+    // Unlike VectorExtractOpConversion, each dynamic position arrives as a
+    // ValueRange that may hold several values (or none), even though we don't
+    // realistically expect that. Verify that every range holds exactly one
+    // value and flatten them into a single SmallVector.
+    ArrayRef<ValueRange> dynamicPositions = adaptor.getDynamicPosition();
+    SmallVector<Value> dynamicPositionsSafe;
+    for (ValueRange dynamicPosition : dynamicPositions) {
+      if (dynamicPosition.size() != 1)
+        return rewriter.notifyMatchFailure(
+            extractOp, "expected single value in dynamic position.");
+      dynamicPositionsSafe.push_back(dynamicPosition[0]);
+    }
+
+    auto loc = extractOp->getLoc();
+    SmallVector<OpFoldResult> positionVec = getMixedValues(
+        adaptor.getStaticPosition(), dynamicPositionsSafe, rewriter);
+
+    // The Vector -> LLVM 1:N lowering models N-D vectors as a collection of
+    // 1-d vectors. We do this conversion by:
+    // - Selecting the correct values that correspond to the target vector.
+    //   No operations are produced at this stage.
+    // - Extract a scalar out of the vector if needed. This is done using
+    //   `llvm.extractelement`.
+
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank, i.e., the number of indices is
+    // equal to source vector rank.
+    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
+                          extractOp.getSourceVectorType().getRank();
+
+    // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
+    // need to add a position for this change.
+    VectorType sourceTy = extractOp.getSourceVectorType();
+    bool isZeroRank = sourceTy.getRank() == 0;
+    if (isZeroRank) {
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionVec.push_back(rewriter.getZeroAttr(idxType));
+    }
+
+    ArrayRef<int64_t> sourceShape = sourceTy.getShape();
+    // strides[d] counts the 1-d vectors spanned by dims [d, rank - 1).
+    SmallVector<int64_t> strides(sourceShape);
+    if (!isZeroRank)
+      strides[strides.size() - 1] = 1;
+    else
+      strides.push_back(1);
+    // Suffix product over the leading dims; the innermost entry stays 1.
+    for (int64_t i = strides.size() - 2; i >= 0; --i)
+      strides[i] *= strides[i + 1];
+
+    // Unlike VectorExtractOp, a source here will be a SmallVector<Value>
+    SmallVector<Value> extracted = adaptor.getSource();
+    ArrayRef<Value> selected;
+
+    ArrayRef<OpFoldResult> position(positionVec);
+
+    // If we are extracting a scalar from the extracted member, we drop
+    // the last index, which will be used to extract the scalar out of the
+    // vector.
+    if (extractsScalar)
+      position = position.drop_back();
+
+    if (!llvm::all_of(position, llvm::IsaPred<Attribute>))
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected leading indices to be statically known.");
+
+    SmallVector<int64_t> positionInts = getAsIntegers(position);
+    // Index `d` steps over strides[d + 1] vectors; other dims stay whole.
+    int64_t linearIdx = 0;
+    for (auto [offset, coeff] :
+         llvm::zip(positionInts, ArrayRef<int64_t>(strides).drop_front()))
+      linearIdx += offset * coeff;
+
+    selected =
+        ArrayRef<Value>(extracted).slice(linearIdx, llvmResultTypes.size());
+
+    if (extractsScalar && selected.size() != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected selected vectors to be a single vector");
+
+    SmallVector<Value> replacements;
+    if (extractsScalar) {
+      Value scalar = LLVM::ExtractElementOp::create(
+          rewriter, loc, selected[0],
+          getAsLLVMValue(rewriter, loc, positionVec.back()));
+      replacements.push_back(scalar);
+    } else {
+      replacements = SmallVector<Value>(selected);
+    }
+
+    rewriter.replaceOpWithMultiple(extractOp, {replacements});
+    return success();
+  }
+};
+
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1191,6 +1346,156 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
}
};
+/// VectorInsertOpConversion rewritten to use 1:N conversion.
+///
+/// Handles the same cases as VectorInsertOpConversion:
+/// * Mostly static indices (with the exception of the innermost dimension index
+///   which can be used to select a scalar).
+///
+/// ```mlir
+/// %upd = vector.insert %src, %tgt[0] : vector<2xf32> into vector<2x2xf32>
+/// ```
+///
+/// Since 1:N conversion is used, src and tgt may now be a collection of
+/// vectors. The result is the tgt collection with the addressed vectors
+/// replaced by the ones from src. Following the example above:
+///
+/// %src : {vector<2xf32>}
+/// %tgt : {vector<2xf32>, vector<2xf32>}
+/// %upd : {%src[0], %tgt[1]}
+///
+/// This pattern will insert operations only when inserting scalars.
+/// For example:
+///
+/// ```mlir
+/// %upd = vector.insert %src, %tgt[%idx] : f32 into vector<2xf32>
+/// ```
+///
+/// %src : {f32}
+/// %tgt : {vector<2xf32>}
+///
+/// ```mlir
+/// %upd = llvm.insertelement %src, %tgt[%idx] : vector<2xf32>
+/// ```
+class VectorInsertOneToNOpConversion
+    : public ConvertOpToLLVMPattern<vector::InsertOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = insertOp->getLoc();
+    auto destVectorType = insertOp.getDestVectorType();
+    SmallVector<Type> llvmResultTypes;
+    if (failed(typeConverter->convertType(destVectorType, llvmResultTypes)))
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "expected conversion to succeed.");
+
+    // Unlike VectorInsertOpConversion, each dynamic position arrives as a
+    // ValueRange that may hold several values (or none), even though we don't
+    // realistically expect that. Verify that every range holds exactly one
+    // value and flatten them into a single SmallVector.
+    ArrayRef<ValueRange> dynamicPositions = adaptor.getDynamicPosition();
+    SmallVector<Value> dynamicPositionsSafe;
+    for (ValueRange dynamicPosition : dynamicPositions) {
+      if (dynamicPosition.size() != 1)
+        return rewriter.notifyMatchFailure(
+            insertOp, "expected single value in dynamic position.");
+      dynamicPositionsSafe.push_back(dynamicPosition[0]);
+    }
+
+    SmallVector<OpFoldResult> positionVec = getMixedValues(
+        adaptor.getStaticPosition(), dynamicPositionsSafe, rewriter);
+
+    // The logic in this pattern mirrors VectorExtractOneToNOpConversion. Refer
+    // to its explanatory comment about how N-D vectors are converted to a
+    // collection of vectors.
+    //
+    // The innermost dimension of the destination vector, when converted to a
+    // collection of vectors, will always be a 1D vector.
+    //
+    // * If the insertion is happening into the innermost dimension of the
+    //   destination vector:
+    //   - Select the appropriate vectors that correspond to the position
+    //     indices.
+    //     Unlike VectorInsertOpConversion, the inserted element may be
+    //     converted to multiple values.
+    //   - From the selection just done, use the innermost dimension's index
+    //     to decide where to insert. This is done with
+    //     llvm.insertelement.
+    // * Return the original destination vector but with the elements selected
+    //   above replacing the original ones.
+
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool insertIntoInnermostDim =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+    bool isZeroRank = destVectorType.getRank() == 0;
+
+    ArrayRef<OpFoldResult> positionOf1DVector(
+        positionVec.begin(), insertIntoInnermostDim && !isZeroRank
+                                 ? positionVec.size() - 1
+                                 : positionVec.size());
+
+    if (!llvm::all_of(positionOf1DVector, llvm::IsaPred<Attribute>))
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "dynamic dimensions are not supported for picking 1d vectors.");
+
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (isZeroRank) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (insertIntoInnermostDim) {
+      positionOfScalarWithin1DVector = positionVec.back();
+    }
+    // strides[d] counts the 1-d vectors spanned by dims [d, rank - 1).
+    ArrayRef<int64_t> destShape = destVectorType.getShape();
+    SmallVector<int64_t> strides(destShape);
+    if (!isZeroRank)
+      strides[strides.size() - 1] = 1;
+    else
+      strides.push_back(1);
+    // Suffix product over the leading dims; the innermost entry stays 1.
+    for (int64_t i = strides.size() - 2; i >= 0; --i)
+      strides[i] *= strides[i + 1];
+    // Index `d` steps over strides[d + 1] vectors; other dims stay whole.
+    SmallVector<int64_t> positionInts = getAsIntegers(positionOf1DVector);
+    int64_t linearIdx = 0;
+    for (auto [offset, coeff] :
+         llvm::zip(positionInts, ArrayRef<int64_t>(strides).drop_front()))
+      linearIdx += offset * coeff;
+
+    SmallVector<Value> sources = adaptor.getValueToStore();
+    SmallVector<Value> dests = adaptor.getDest();
+
+    SmallVector<Value> selected(
+        ArrayRef<Value>(dests).slice(linearIdx, sources.size()));
+
+    if (insertIntoInnermostDim) {
+      assert(selected.size() == 1 && "expected a single vector to be selected");
+      Value destVector = selected[0];
+      assert(sources.size() == 1 && "expected to store one scalar");
+      Value scalar = sources[0];
+
+      // Insert the scalar into the 1D vector.
+      sources[0] = LLVM::InsertElementOp::create(
+          rewriter, loc, destVector.getType(), destVector, scalar,
+          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
+    }
+
+    for (auto [idx, val] : llvm::enumerate(sources))
+      dests[linearIdx + idx] = val;
+
+    rewriter.replaceOpWithMultiple(insertOp, {dests});
+    return success();
+  }
+};
+
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2196,10 +2501,12 @@ void mlir::vector::populateVectorTransposeToFlatTranspose(
}
/// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool force32BitVectorIndices,
- bool useVectorAlignment) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ bool reassociateFPReductions,
+ bool force32BitVectorIndices,
+ bool useVectorAlignment,
+ bool enableOneToNConversion) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -2211,8 +2518,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractOpConversion, VectorFMAOp1DConversion,
- VectorInsertOpConversion, VectorPrintOpConversion,
+ VectorFMAOp1DConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorBroadcastScalarToLowRankLowering,
@@ -2222,6 +2528,77 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
VectorToElementsLowering, VectorScalableStepOpLowering>(
converter);
+ if (enableOneToNConversion)
+ patterns
+ .add<VectorInsertOneToNOpConversion, VectorExtractOneToNOpConversion>(
+ converter);
+ else
+ patterns.add<VectorInsertOpConversion, VectorExtractOpConversion>(
+ converter);
+
+ if (enableOneToNConversion) {
+ converter.addConversion(
+ [&](VectorType type,
+ SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+ auto elementType = converter.convertType(type.getElementType());
+ if (!elementType)
+ return failure();
+ if (type.getShape().empty()) {
+ result.push_back(VectorType::get({1}, elementType));
+ return success();
+ }
+ Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ type.getScalableDims().back());
+ assert(LLVM::isCompatibleVectorType(vectorType) &&
+ "expected vector type compatible with the LLVM dialect");
+ // For n-D vector types for which a _non-trailing_ dim is scalable,
+ // return a failure. Supporting such cases would require LLVM
+ // to support something akin "scalable arrays" of vectors.
+ if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+ return failure();
+
+ ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
+ int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
+ for (int64_t i = 0; i < numVectors; i++)
+ result.push_back(vectorType);
+
+ return success();
+ });
+ converter.addTargetMaterialization(
+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc) -> SmallVector<Value> {
+ // from ('vector<4x4xf32>')
+ // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
+ // 'vector<4xf32>')
+ Type ty = resultTypes[0];
+ for (Type ithTy : resultTypes)
+ if (ithTy != ty)
+ return {};
+
+ if (!isa<VectorType>(ty))
+ return {};
+
+ if (inputs.size() != 1)
+ return {};
+
+ Type inputTy = inputs[0].getType();
+ if (!isa<VectorType>(inputTy))
+ return {};
+
+ VectorType inputVectorTy = cast<VectorType>(inputTy);
+ ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
+ size_t numElements =
+ !inputShape.empty()
+ ? ShapedType::getNumElements(inputShape.drop_back())
+ : 1;
+ if (numElements != resultTypes.size())
+ return {};
+
+ return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
+ inputs)
+ .getResults();
+ });
+ }
}
namespace {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 43fd1843cda28..a5b1255657c96 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -115,75 +115,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
LowerToLLVMOptions options(&getContext());
LLVMTypeConverter converter(&getContext(), options);
- if (enableOneToNConversion) {
-
- converter.addConversion(
- [&](VectorType type,
- SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
- auto elementType = converter.convertType(type.getElementType());
- if (!elementType)
- return failure();
- if (type.getShape().empty()) {
- result.push_back(VectorType::get({1}, elementType));
- return success();
- }
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
- type.getScalableDims().back());
- assert(LLVM::isCompatibleVectorType(vectorType) &&
- "expected vector type compatible with the LLVM dialect");
- // For n-D vector types for which a _non-trailing_ dim is scalable,
- // return a failure. Supporting such cases would require LLVM
- // to support something akin "scalable arrays" of vectors.
- if (llvm::is_contained(type.getScalableDims().drop_back(), true))
- return failure();
-
- ArrayRef<int64_t> shapeLeadingDims = type.getShape().drop_back();
- int64_t numVectors = ShapedType::getNumElements(shapeLeadingDims);
- for (int64_t i = 0; i < numVectors; i++)
- result.push_back(vectorType);
-
- return success();
- });
-
- converter.addTargetMaterialization(
- [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
- Location loc) -> SmallVector<Value> {
- // from ('vector<4x4xf32>')
- // to ('vector<4xf32>', 'vector<4xf32>', 'vector<4xf32>',
- // 'vector<4xf32>')
- Type ty = resultTypes[0];
- for (Type ithTy : resultTypes)
- if (ithTy != ty)
- return {};
-
- if (!isa<VectorType>(ty))
- return {};
-
- if (inputs.size() != 1)
- return {};
-
- Type inputTy = inputs[0].getType();
- if (!isa<VectorType>(inputTy))
- return {};
-
- VectorType inputVectorTy = cast<VectorType>(inputTy);
- ArrayRef<int64_t> inputShape = inputVectorTy.getShape();
- size_t numElements =
- ShapedType::getNumElements(inputShape.drop_back());
- if (numElements != resultTypes.size())
- return {};
-
- return UnrealizedConversionCastOp::create(builder, loc, resultTypes,
- inputs)
- .getResults();
- });
- }
-
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
- useVectorAlignment);
+ useVectorAlignment, enableOneToNConversion);
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
new file mode 100644
index 0000000000000..63316a38332f4
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-one-to-n.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt --convert-vector-to-llvm="enable-one-to-n-conversion=true" --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @vector_extract_vector(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4x4xf32>
+func.func @vector_extract_vector(%arg0: vector<4x4xf32>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+ // CHECK-NEXT: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<4x4xf32> to vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<4x4xf32>
+ %1 = vector.extract %arg0[1] : vector<4xf32> from vector<4x4xf32>
+ %2 = vector.extract %arg0[2] : vector<4xf32> from vector<4x4xf32>
+ %3 = vector.extract %arg0[3] : vector<4xf32> from vector<4x4xf32>
+ // CHECK-NEXT: return %[[CAST]]#3, %[[CAST]]#2, %[[CAST]]#1, %[[CAST]]#0
+ return %3, %2, %1, %0 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_linearize(
+// CHECK-SAME: %[[ARG0:.+]]: vector<5x4x3xf32>
+func.func @vector_extract_linearize(%arg0: vector<5x4x3xf32>) -> (vector<3xf32>, vector<3xf32>, vector<3xf32>) {
+ // CHECK-NEXT: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<5x4x3xf32> to vector<3xf32>
+ %0 = vector.extract %arg0[0, 0] : vector<3xf32> from vector<5x4x3xf32>
+ %1 = vector.extract %arg0[0, 1] : vector<3xf32> from vector<5x4x3xf32>
+ %2 = vector.extract %arg0[1, 1] : vector<3xf32> from vector<5x4x3xf32>
+ // CHECK-NEXT: return %[[CAST]]#0, %[[CAST]]#1, %[[CAST]]#5
+ return %0, %1, %2 : vector<3xf32>, vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_scalar(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+func.func @vector_extract_scalar(%arg0: vector<2x2xf32>) -> (f32) {
+ // CHECK: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to vector<2xf32>, vector<2xf32>
+ // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64)
+ // CHECK: %[[EXTRACTED:.+]] = llvm.extractelement %[[CAST]]#0[%[[C0]] : i64]
+ %0 = vector.extract %arg0[0, 0] : f32 from vector<2x2xf32>
+ // CHECK: return %[[EXTRACTED]]
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_lhs_multiple(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2x2xf32>)
+func.func @vector_extract_lhs_multiple(%arg0: vector<2x2x2xf32>) -> vector<2x2xf32> {
+ // CHECK: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2x2xf32> to vector<2xf32>
+ // CHECK: %[[SELECTED:.+]] = builtin.unrealized_conversion_cast %[[CAST]]#0, %[[CAST]]#1 : vector<2xf32>, vector<2xf32>
+ %0 = vector.extract %arg0[0] : vector<2x2xf32> from vector<2x2x2xf32>
+ // CHECK: return %[[SELECTED]]
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_extract_rank_0(
+// CHECK-SAME: %[[ARG0:.+]]: vector<f32>)
+func.func @vector_extract_rank_0(%arg0: vector<f32>) -> f32 {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<f32> to vector<1xf32>
+ // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64)
+ // CHECK: %[[ELEM:.+]] = llvm.extractelement %[[CAST]][%[[C0]] : i64]
+ %0 = vector.extract %arg0[] : f32 from vector<f32>
+ // CHECK: return %[[ELEM]]
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_vector(
+// CHECK-SAME: %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>, %[[VAL2:.+]]: vector<4xf32>, %[[AGG:.+]]: vector<4x4xf32>)
+func.func @vector_insert_vector(%val0: vector<4xf32>, %val1: vector<4xf32>, %val2: vector<4xf32>, %agg: vector<4x4xf32>) -> (vector<4x4xf32>) {
+ // CHECK: %[[CAST:.+]]:4 = builtin.unrealized_conversion_cast %[[AGG]] : vector<4x4xf32> to vector<4xf32>
+ %0 = vector.insert %val0, %agg[0] : vector<4xf32> into vector<4x4xf32>
+ %1 = vector.insert %val1, %0[1] : vector<4xf32> into vector<4x4xf32>
+ %2 = vector.insert %val2, %1[2] : vector<4xf32> into vector<4x4xf32>
+
+ // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL0]], %[[VAL1]], %[[VAL2]], %[[CAST]]#3
+ // CHECK: return %[[INSERTION_CAST]]
+ return %2 : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_linearize(
+// CHECK-SAME: %[[VAL:.+]]: vector<3xf32>, %[[AGG:.+]]: vector<5x4x3xf32>)
+func.func @vector_insert_linearize(%val: vector<3xf32>, %agg: vector<5x4x3xf32>) -> (vector<5x4x3xf32>) {
+ // CHECK: %[[CAST:.+]]:20 = builtin.unrealized_conversion_cast %[[AGG]] : vector<5x4x3xf32> to vector<3xf32>
+
+ %0 = vector.insert %val, %agg[0, 0] : vector<3xf32> into vector<5x4x3xf32>
+ %1 = vector.insert %val, %0[0, 1] : vector<3xf32> into vector<5x4x3xf32>
+ %2 = vector.insert %val, %1[1, 1] : vector<3xf32> into vector<5x4x3xf32>
+
+ // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[VAL]], %[[VAL]], %[[CAST]]#2, %[[CAST]]#3, %[[CAST]]#4, %[[VAL]]
+ // CHECK: return %[[INSERTION_CAST]]
+ return %2 : vector<5x4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_scalar(
+// CHECK-SAME: %[[VAL:.+]]: f32, %[[AGG:.+]]: vector<2x2xf32>)
+func.func @vector_insert_scalar(%val: f32, %agg: vector<2x2xf32>) -> (vector<2x2xf32>) {
+ // CHECK-DAG: %[[CAST:.+]]:2 = builtin.unrealized_conversion_cast %[[AGG]] : vector<2x2xf32> to vector<2xf32>
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i64)
+ // CHECK: %[[MODIFIED_VECTOR:.+]] = llvm.insertelement %[[VAL]], %[[CAST]]#0[%[[C1]] : i64]
+ // CHECK: %[[INSERTION_CAST:.+]] = builtin.unrealized_conversion_cast %[[MODIFIED_VECTOR]], %[[CAST]]#1
+ %0 = vector.insert %val, %agg[0, 1] : f32 into vector<2x2xf32>
+
+ // CHECK: return %[[INSERTION_CAST]]
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_sources_multiple(
+// CHECK-SAME: %[[TO_STORE:.+]]: vector<2x2xf32>,
+// CHECK-SAME: %[[DEST:.+]]: vector<2x2x2xf32>
+func.func @vector_insert_sources_multiple(%val: vector<2x2xf32>, %dest: vector<2x2x2xf32>) -> (vector<2x2x2xf32>) {
+ // CHECK: %[[CAST_DEST:.+]]:4 = builtin.unrealized_conversion_cast %[[DEST]] : vector<2x2x2xf32> to vector<2xf32>
+ // CHECK: %[[CAST_TO_STORE:.+]]:2 = builtin.unrealized_conversion_cast %[[TO_STORE]] : vector<2x2xf32> to vector<2xf32>
+ // CHECK: %[[INSERT:.+]] = builtin.unrealized_conversion_cast %[[CAST_TO_STORE]]#0, %[[CAST_TO_STORE]]#1, %[[CAST_DEST]]#2, %[[CAST_DEST]]#3
+
+ %0 = vector.insert %val, %dest[0] : vector<2x2xf32> into vector<2x2x2xf32>
+ // CHECK: return %[[INSERT]]
+ return %0 : vector<2x2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_insert_rank_0(
+// CHECK-SAME: %[[TO_STORE:.+]]: f32,
+// CHECK-SAME: %[[DEST:.+]]: vector<f32>
+func.func @vector_insert_rank_0(%val: f32, %dest: vector<f32>) -> (vector<f32>) {
+ // CHECK: %[[CAST_DEST:.+]] = builtin.unrealized_conversion_cast %[[DEST]] : vector<f32> to vector<1xf32>
+ // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64)
+ // CHECK: %[[INSERT:.+]] = llvm.insertelement %[[TO_STORE]], %[[CAST_DEST]][%[[C0]] : i64]
+ // CHECK: %[[INSERT_CAST:.+]] = builtin.unrealized_conversion_cast %[[INSERT]] : vector<1xf32> to vector<f32>
+ %0 = vector.insert %val, %dest[] : f32 into vector<f32>
+ // CHECK: return %[[INSERT_CAST]]
+ return %0 : vector<f32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index db941454f8d8c..544b9521c5d79 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -51,8 +51,8 @@ struct TestVectorToVectorLowering
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<affine::AffineDialect>();
- registry.insert<vector::VectorDialect>();
+ registry.insert<affine::AffineDialect, arith::ArithDialect,
+ vector::VectorDialect>();
}
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
@@ -150,6 +150,9 @@ struct TestVectorUnrollingPatterns
return "Test lowering patterns to unroll contract ops in the vector "
"dialect";
}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
: PassWrapper(pass) {}
More information about the Mlir-commits
mailing list