[Mlir-commits] [mlir] [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (PR #146618)
Eric Feng
llvmlistbot at llvm.org
Wed Jul 2 11:40:27 PDT 2025
https://github.com/efric updated https://github.com/llvm/llvm-project/pull/146618
>From 2141816222b374c296f432e441334bd6ead2d5b1 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Mon, 30 Jun 2025 23:28:50 -0700
Subject: [PATCH 1/8] add initial lowering pieces
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 38 +++++++++++++++++++
.../VectorToSPIRV/vector-to-spirv.mlir | 30 +++++++++++++++
2 files changed, 68 insertions(+)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index de2af69eba9ec..8e38b5280c527 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1022,6 +1022,44 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
}
};
+struct VectorToElementOpConvert final
+ : public OpConversionPattern<vector::ToElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type srcType =
+ getTypeConverter()->convertType(toElementsOp.getSource().getType());
+ if (!srcType)
+ return failure();
+
+ // If the input vector was size 1, then it would have been converted to a
+ // scalar. Replace with it directly
+ if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+ rewriter.replaceOp(toElementsOp, adaptor.getSource());
+ return success();
+ }
+
+ Location loc = toElementsOp.getLoc();
+ SmallVector<Value> results(toElementsOp->getNumResults());
+
+ for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+ if (element.use_empty())
+ continue;
+
+ Value result = rewriter.create<spirv::CompositeExtractOp>(
+ loc, toElementsOp->getResult(idx).getType(), adaptor.getSource(),
+ rewriter.getI32ArrayAttr({static_cast<int>(idx)}));
+ results[idx] = result;
+ }
+
+ rewriter.replaceOp(toElementsOp, results);
+ return success();
+ }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4701ac5d96009..c8d0c25f252cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -244,6 +244,36 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
return %0: f32
}
+// -----
+// we need a test for with dead, no dead, 1 element
+
+// CHECK-LABEL: @from_elements_0d
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[RETVAL]]
+func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
+ %0 = vector.from_elements %arg0 : vector<f32>
+ return %0: vector<f32>
+}
+
+// CHECK-LABEL: @from_elements_1x
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[RETVAL]]
+func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
+ %0 = vector.from_elements %arg0 : vector<1xf32>
+ return %0: vector<1xf32>
+}
+
+// CHECK-LABEL: @from_elements_3x
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
+// CHECK: %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
+// CHECK: return %[[RETVAL]]
+func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
+ return %0: vector<3xf32>
+}
+
// -----
// CHECK-LABEL: @from_elements_0d
>From 9fbc662254359835d5dccbbb264fcc03feb34088 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 14:12:19 -0700
Subject: [PATCH 2/8] register VectorToElementOpConvert and convert extracted element type
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 8e38b5280c527..b51553ad9d33f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -1049,9 +1050,10 @@ struct VectorToElementOpConvert final
if (element.use_empty())
continue;
+ auto spirvType = getTypeConverter()->convertType(element.getType());
Value result = rewriter.create<spirv::CompositeExtractOp>(
- loc, toElementsOp->getResult(idx).getType(), adaptor.getSource(),
- rewriter.getI32ArrayAttr({static_cast<int>(idx)}));
+ loc, spirvType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
results[idx] = result;
}
@@ -1076,7 +1078,7 @@ void mlir::populateVectorToSPIRVPatterns(
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, VectorToElementOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
>From 751cdd47c9fc15f9315de484e78ffda1a34d4624 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 15:54:03 -0700
Subject: [PATCH 3/8] nits: drop unused include, null-check converted element type
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 21 +++++++++++--------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b51553ad9d33f..9c7a413db8d35 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -13,7 +13,6 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -1036,23 +1035,27 @@ struct VectorToElementOpConvert final
if (!srcType)
return failure();
- // If the input vector was size 1, then it would have been converted to a
- // scalar. Replace with it directly
+ SmallVector<Value> results(toElementsOp->getNumResults());
+ Location loc = toElementsOp.getLoc();
+
+ // Input vectors of size 1 are converted to scalars by the type converter.
+ // We cannot use `spirv::CompositeExtractOp` directly in this case.
+ // For a scalar source, the result is just the scalar itself.
if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
- rewriter.replaceOp(toElementsOp, adaptor.getSource());
+ results[0] = adaptor.getSource();
+ rewriter.replaceOp(toElementsOp, results);
return success();
}
- Location loc = toElementsOp.getLoc();
- SmallVector<Value> results(toElementsOp->getNumResults());
-
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
if (element.use_empty())
continue;
- auto spirvType = getTypeConverter()->convertType(element.getType());
+ auto elementType = getTypeConverter()->convertType(element.getType());
+ if (!elementType)
+ return failure();
Value result = rewriter.create<spirv::CompositeExtractOp>(
- loc, spirvType, adaptor.getSource(),
+ loc, elementType, adaptor.getSource(),
rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
results[idx] = result;
}
>From 33b28dedfaa33d2e301c290ed351a7c79d444be3 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 18:26:08 -0700
Subject: [PATCH 4/8] nits; add to_elements tests, drop duplicated from_elements tests
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 1 +
.../VectorToSPIRV/vector-to-spirv.mlir | 57 ++++++++++---------
2 files changed, 32 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9c7a413db8d35..b350ef7a42b95 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1048,6 +1048,7 @@ struct VectorToElementOpConvert final
}
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+ // Create an CompositeExtract operation only for results that are not dead.
if (element.use_empty())
continue;
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c8d0c25f252cf..99ab0e1dc4eef 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -245,33 +245,38 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
}
// -----
-// we need a test for with dead, no dead, 1 element
-// CHECK-LABEL: @from_elements_0d
-// CHECK-SAME: %[[ARG0:.+]]: f32
-// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-// CHECK: return %[[RETVAL]]
-func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
- %0 = vector.from_elements %arg0 : vector<f32>
- return %0: vector<f32>
-}
-
-// CHECK-LABEL: @from_elements_1x
-// CHECK-SAME: %[[ARG0:.+]]: f32
-// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-// CHECK: return %[[RETVAL]]
-func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
- %0 = vector.from_elements %arg0 : vector<1xf32>
- return %0: vector<1xf32>
-}
-
-// CHECK-LABEL: @from_elements_3x
-// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
-// CHECK: %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
-// CHECK: return %[[RETVAL]]
-func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
- %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
- return %0: vector<3xf32>
+// CHECK-LABEL: func.func @to_elements_one_element
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>)
+// CHECK: %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32
+// CHECK: return %[[ELEM0]] : f32
+func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) {
+ %0:1 = vector.to_elements %a : vector<1xf32>
+ return %0#0 : f32
+}
+
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+// CHECK: %[[ELEM0:.*]] = spirv.CompositeExtract %[[A]][0 : i32] : vector<4xf32>
+// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+// CHECK: %[[ELEM2:.*]] = spirv.CompositeExtract %[[A]][2 : i32] : vector<4xf32>
+// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @to_elements_dead_elements
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+// CHECK-NOT: spirv.CompositeExtract %[[A]][0 : i32]
+// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+// CHECK-NOT: spirv.CompositeExtract %[[A]][2 : i32]
+// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#1, %0#3 : f32, f32
}
// -----
>From b02c31ed0b164b7928362b71407f2c5967bd3e4d Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 18:33:58 -0700
Subject: [PATCH 5/8] format
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b350ef7a42b95..475fd76c667e6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1048,7 +1048,8 @@ struct VectorToElementOpConvert final
}
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
- // Create an CompositeExtract operation only for results that are not dead.
+ // Create an CompositeExtract operation only for results that are not
+ // dead.
if (element.use_empty())
continue;
>From 98ab5f3100ce3ac436bdd217671e85683bced4b5 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 2 Jul 2025 10:49:57 -0700
Subject: [PATCH 6/8] nits: convert element type once and report match failure
---
.../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 475fd76c667e6..c79b368b333e4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1023,18 +1023,13 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
};
struct VectorToElementOpConvert final
- : public OpConversionPattern<vector::ToElementsOp> {
+ : OpConversionPattern<vector::ToElementsOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type srcType =
- getTypeConverter()->convertType(toElementsOp.getSource().getType());
- if (!srcType)
- return failure();
-
SmallVector<Value> results(toElementsOp->getNumResults());
Location loc = toElementsOp.getLoc();
@@ -1047,15 +1042,18 @@ struct VectorToElementOpConvert final
return success();
}
+ Type srcElementType = toElementsOp.getElements().getType().front();
+ Type elementType = getTypeConverter()->convertType(srcElementType);
+ if (!elementType)
+ return rewriter.notifyMatchFailure(
+ toElementsOp, "unsupported element type in source vector");
+
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
// Create an CompositeExtract operation only for results that are not
// dead.
if (element.use_empty())
continue;
- auto elementType = getTypeConverter()->convertType(element.getType());
- if (!elementType)
- return failure();
Value result = rewriter.create<spirv::CompositeExtractOp>(
loc, elementType, adaptor.getSource(),
rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
>From 18c09ce355f509d0f9339dd9e3d53fd8948090be Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 2 Jul 2025 11:32:43 -0700
Subject: [PATCH 7/8] nit on error message
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c79b368b333e4..abe0b55692d88 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1046,7 +1046,9 @@ struct VectorToElementOpConvert final
Type elementType = getTypeConverter()->convertType(srcElementType);
if (!elementType)
return rewriter.notifyMatchFailure(
- toElementsOp, "unsupported element type in source vector");
+ toElementsOp,
+ llvm::formatv("failed to convert element type '{0}' to SPIR-V",
+ srcElementType));
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
// Create an CompositeExtract operation only for results that are not
>From ba50736a874ad526d5527f13b95036c13b5e6e57 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 2 Jul 2025 11:41:21 -0700
Subject: [PATCH 8/8] format
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index abe0b55692d88..21d8e1d9f1156 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1083,9 +1083,9 @@ void mlir::populateVectorToSPIRVPatterns(
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, VectorToElementOpConvert,
- VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+ VectorToElementOpConvert, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
More information about the Mlir-commits
mailing list