[Mlir-commits] [mlir] [mlir][vector] Add vector.step operation (PR #96776)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Jul 3 03:13:57 PDT 2024
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/96776
>From 1ba850c1801b94bcb903bbd6c4d715a0d4b6c959 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 19 Jun 2024 15:29:25 +0000
Subject: [PATCH 1/2] [mlir][vector] Add vector.step operation
This patch adds a new vector.step operation to the Vector dialect. It
produces a linear sequence of index values from 0 to N-1, where N is the
number of elements in the result vector, and can be used to create
vectors of indices.
It supports both fixed-width and scalable vectors. For fixed-width vectors
the canonical representation is `arith.constant dense<[0, .., N-1]>`. A
scalable step cannot be represented as a constant and is lowered to the
`llvm.experimental.stepvector` intrinsic [1].
[1] https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 29 +++++++++++++++++++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 17 +++++++++--
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 +++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 11 +++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 10 +++++++
mlir/test/Dialect/Vector/invalid.mlir | 16 ++++++++++
mlir/test/Dialect/Vector/ops.mlir | 9 +++++-
7 files changed, 103 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 097e5e6fb0d61..94cba7d7882cd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3017,6 +3017,35 @@ def Vector_ScanOp :
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// VectorStepOp
+//===----------------------------------------------------------------------===//
+
+def Vector_StepOp : Vector_Op<"step", [Pure]> {
+  let summary = "A linear sequence of values from 0 to N-1";
+  let description = [{
+    A `step` operation produces an index vector, i.e. a 1-D vector of values of
+    index type that represents a linear sequence from 0 to N-1, where N is the
+    number of elements in the `result` vector.
+
+    Supports fixed-width and scalable vectors. For a fixed-width `step` vector,
+    the canonical representation is `arith.constant dense<[0, .., N-1]>`. A
+    scalable step cannot be represented as a constant and is lowered to the
+    [llvm.experimental.stepvector](https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic)
+    intrinsic.
+
+    Examples:
+
+    ```mlir
+    %0 = vector.step : vector<4xindex>   // [0, 1, 2, 3]
+    %1 = vector.step : vector<[4]xindex> // [0, 1, .., <vscale * 4 - 1>]
+    ```
+  }];
+  let hasFolder = 1;
+  let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
+  let assemblyFormat = "attr-dict `:` type($result)";
+}
+
def Vector_YieldOp : Vector_Op<"yield", [
Pure, ReturnLike, Terminator]> {
let summary = "Terminates and yields values from vector regions.";
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 0eac55255b133..6a8a9d818aad2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1860,6 +1860,19 @@ struct VectorFromElementsLowering
}
};
+/// Conversion pattern for vector.step.
+///
+/// Lowers `vector.step` to `llvm.intr.experimental.stepvector` with the
+/// LLVM-converted result type. If the result type cannot be converted, the
+/// pattern fails to match instead of building an op with a null result type.
+struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type llvmType = typeConverter->convertType(stepOp.getType());
+    // convertType yields a null Type on failure; report a match failure so
+    // the conversion driver treats the op as not legalized by this pattern.
+    if (!llvmType)
+      return rewriter.notifyMatchFailure(stepOp,
+                                         "failed to convert result type");
+    rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
+    return success();
+  }
+};
+
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1885,8 +1898,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
- VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
- converter);
+ VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+ VectorStepOpLowering>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..8efafcab5529e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6316,6 +6316,20 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a fixed-width `vector.step` into the constant index vector
+/// `[0, 1, ..., N-1]`, where N is the number of result elements. Scalable
+/// steps are not folded: their length depends on vscale, so they have no
+/// constant representation.
+OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
+  auto resultType = cast<VectorType>(getType());
+  if (resultType.isScalable())
+    return nullptr;
+  // Keep the counter as wide as the element count, and size each APInt to
+  // the storage width MLIR uses for `index` values instead of a literal 64.
+  int64_t numElements = resultType.getNumElements();
+  SmallVector<APInt> indices;
+  indices.reserve(numElements);
+  for (int64_t i = 0; i < numElements; ++i)
+    indices.emplace_back(IndexType::kInternalStorageBitWidth, i);
+  return DenseElementsAttr::get(resultType, indices);
+}
+
//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09b79708a9ab2..897ff7ad6b43a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2621,3 +2621,14 @@ func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
%0 = vector.from_elements %a : vector<f32>
return %0 : vector<f32>
}
+
+// -----
+
+// A scalable vector.step lowers to the stepvector intrinsic on the converted
+// i64 element type; the index-typed result is recovered through a cast.
+// CHECK-LABEL: @vector_step
+// CHECK: %[[STEPVECTOR:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi64>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[STEPVECTOR]] : vector<[4]xi64> to vector<[4]xindex>
+// CHECK: return %[[CAST]] : vector<[4]xindex>
+func.func @vector_step() -> vector<[4]xindex> {
+ %0 = vector.step : vector<[4]xindex>
+ return %0 : vector<[4]xindex>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8181f1a8c5d13..9c3bbb907cfb4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2711,3 +2711,13 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}
+
+// -----
+
+// A fixed-width step folds to the constant [0, 1, ..., N-1].
+// CHECK-LABEL: @fold_vector_step_to_constant
+// CHECK: %[[CONSTANT:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: return %[[CONSTANT]] : vector<4xindex>
+func.func @fold_vector_step_to_constant() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// A scalable step has no constant representation, so the folder must leave
+// it in place.
+// CHECK-LABEL: @vector_step_scalable_not_folded
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: return %[[STEP]] : vector<[4]xindex>
+func.func @vector_step_scalable_not_folded() -> vector<[4]xindex> {
+  %0 = vector.step : vector<[4]xindex>
+  return %0 : vector<[4]xindex>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d0eaed8f98cc5..db169a6c1f8ae 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1871,3 +1871,19 @@ func.func @invalid_from_elements(%a: f32, %b: i32) {
vector.from_elements %a, %b : vector<2xf32>
return
}
+
+// -----
+
+func.func @invalid_step_0d() {
+  // Use an index element type so that only the rank-1 constraint is violated.
+  // expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<index>'}}
+  vector.step : vector<index>
+  return
+}
+
+// -----
+
+func.func @invalid_step_2d() {
+  // Use an index element type so that only the rank-1 constraint is violated.
+  // expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<2x4xindex>'}}
+  vector.step : vector<2x4xindex>
+  return
+}
+
+// -----
+
+func.func @invalid_step_non_index() {
+  // Rank 1, but a non-index element type.
+  // expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<4xi32>'}}
+  vector.step : vector<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4da09584db88b..7908e61abc704 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1171,4 +1171,11 @@ func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vecto
// CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
%3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
-}
\ No newline at end of file
+}
+
+// vector.step with a fixed-width and a scalable 1-D index result type.
+// CHECK-LABEL: @step
+func.func @step() {
+ %0 = vector.step : vector<2xindex>
+ %1 = vector.step : vector<[4]xindex>
+ return
+}
>From 6a4b1f275f7bae21995163335aeb5df81e39092d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 3 Jul 2024 10:11:33 +0000
Subject: [PATCH 2/2] address comments
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++---
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 4 ++--
mlir/test/Dialect/Vector/ops.mlir | 2 ++
3 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 94cba7d7882cd..afe364fbd9fc3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3028,9 +3028,9 @@ def Vector_StepOp : Vector_Op<"step", [Pure]> {
index type that represents a linear sequence from 0 to N, where N is the
number of elements in the `result` vector.
- Supports fixed-width and scalable vectors. For fixed the canonical
- representation is `arith.constant dense<[0, .., N]>`. A scalable step
- cannot be represented as a constant and is lowered to the
+ Supports fixed-width and scalable vectors. For a fixed-width `step` vector,
+ the canonical representation is `arith.constant dense<[0, .., N]>`. A
+ scalable step cannot be represented as a constant and is lowered to the
[llvm.experimental.stepvector](https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic)
intrinsic.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 897ff7ad6b43a..5f2d2809a0fe8 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2624,11 +2624,11 @@ func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
// -----
+// Scalable steps have no constant form and must lower to the stepvector intrinsic.
-// CHECK-LABEL: @vector_step
+// CHECK-LABEL: @vector_step_scalable
// CHECK: %[[STEPVECTOR:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi64>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[STEPVECTOR]] : vector<[4]xi64> to vector<[4]xindex>
// CHECK: return %[[CAST]] : vector<[4]xindex>
-func.func @vector_step() -> vector<[4]xindex> {
+func.func @vector_step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 7908e61abc704..531e2db636431 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1175,7 +1175,9 @@ func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vecto
+// Check that both the fixed-width and scalable forms print in the custom
+// `vector.step : type` syntax.
// CHECK-LABEL: @step
func.func @step() {
+ // CHECK: vector.step : vector<2xindex>
%0 = vector.step : vector<2xindex>
+ // CHECK: vector.step : vector<[4]xindex>
%1 = vector.step : vector<[4]xindex>
return
}
More information about the Mlir-commits
mailing list