[Mlir-commits] [mlir] 3ff4e5f - [mlir][Vector] Thread 0-d vectors through InsertElementOp.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Nov 23 04:56:53 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-23T12:55:11Z
New Revision: 3ff4e5f2a4a6a0e124356b2ad8793270ebbb16c1
URL: https://github.com/llvm/llvm-project/commit/3ff4e5f2a4a6a0e124356b2ad8793270ebbb16c1
DIFF: https://github.com/llvm/llvm-project/commit/3ff4e5f2a4a6a0e124356b2ad8793270ebbb16c1.diff
LOG: [mlir][Vector] Thread 0-d vectors through InsertElementOp.
This revision makes concrete use of 0-d vectors to extend the semantics of
InsertElementOp.
Reviewed By: dcaballe, pifon2a
Differential Revision: https://reviews.llvm.org/D114388
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index d8ffdff7667e9..f274ff656f253 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -666,16 +666,18 @@ def Vector_InsertElementOp :
"result", "source",
"$_self.cast<ShapedType>().getElementType()">,
AllTypesMatch<["dest", "result"]>]>,
- Arguments<(ins AnyType:$source, AnyVector:$dest,
- AnySignlessIntegerOrIndex:$position)>,
- Results<(outs AnyVector:$result)> {
+ Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
+ Optional<AnySignlessIntegerOrIndex>:$position)>,
+ Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "insertelement operation";
let description = [{
- Takes a scalar source, an 1-D destination vector and a dynamic index
- position and inserts the source into the destination at the proper
- position. Note that this instruction resembles vector.insert, but
- is restricted to 1-D vectors and relaxed to dynamic indices. It is
- meant to be closer to LLVM's version:
+ Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
+ position (required for 1-D vectors, omitted for 0-D vectors) and inserts the source into the destination at the proper position.
+
+ Note that this instruction resembles vector.insert, but is restricted to 0-D
+ and 1-D vectors and relaxed to dynamic indices.
+
+ It is meant to be closer to LLVM's version:
https://llvm.org/docs/LangRef.html#insertelement-instruction
Example:
@@ -684,14 +686,18 @@ def Vector_InsertElementOp :
%c = arith.constant 15 : i32
%f = arith.constant 0.0f : f32
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
+ %2 = vector.insertelement %f, %z[]: vector<f32>
```
}];
let assemblyFormat = [{
- $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
+ $source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:`
type($result)
}];
let builders = [
+ // 0-D builder.
+ OpBuilder<(ins "Value":$source, "Value":$dest)>,
+ // 1-D + position builder.
OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c74eca56b84b2..108e664f03cad 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -663,6 +663,17 @@ class VectorInsertElementOpConversion
if (!llvmType)
return failure();
+ if (vectorType.getRank() == 0) {
+ Location loc = insertEltOp.getLoc();
+ auto idxType = rewriter.getIndexType();
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
+ rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+ insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
adaptor.position());
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d8cd3c178ad15..e4438fc86e5fd 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1553,6 +1553,12 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
// InsertElementOp
//===----------------------------------------------------------------------===//
+void InsertElementOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest) {
+ result.addOperands({source, dest});
+ result.addTypes(dest.getType());
+}
+
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, Value position) {
result.addOperands({source, dest, position});
@@ -1561,8 +1567,15 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
static LogicalResult verify(InsertElementOp op) {
auto dstVectorType = op.getDestVectorType();
+ if (dstVectorType.getRank() == 0) {
+ if (op.position())
+ return op.emitOpError("expected position to be empty with 0-D vector");
+ return success();
+ }
if (dstVectorType.getRank() != 1)
- return op.emitOpError("expected 1-D vector");
+ return op.emitOpError("unexpected >1 vector rank");
+ if (!op.position())
+ return op.emitOpError("expected position for 1-D vector");
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9cce66c7fe58b..033e2d812b3f2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -512,6 +512,19 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
// -----
+// CHECK-LABEL: @insert_element_0d
+// CHECK-SAME: %[[A:.*]]: f32,
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+ // CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} :
+ // CHECK: vector<f32> to vector<1xf32>
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
+ %1 = vector.insertelement %a, %b[] : vector<f32>
+ return %1 : vector<f32>
+}
+
+// -----
+
func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : i32
%1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c327bfe6968ca..593686a425a51 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -79,7 +79,7 @@ func @extract_element(%arg0: vector<f32>) {
}
// -----
-
+
func @extract_element(%arg0: vector<4xf32>) {
%c = arith.constant 3 : i32
// expected-error@+1 {{expected position for 1-D vector}}
@@ -138,9 +138,25 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
+func @insert_element(%arg0: f32, %arg1: vector<f32>) {
+ %c = arith.constant 3 : i32
+ // expected-error@+1 {{expected position to be empty with 0-D vector}}
+ %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
+}
+
+// -----
+
+func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
+ %c = arith.constant 3 : i32
+ // expected-error@+1 {{expected position for 1-D vector}}
+ %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
+}
+
+// -----
+
func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
%c = arith.constant 3 : i32
- // expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
+ // expected-error@+1 {{unexpected >1 vector rank}}
%0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3f7fe75b8cb2a..11b986fc9b87c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -192,6 +192,13 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
}
+// CHECK-LABEL: @insert_element_0d
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+ // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
+ %1 = vector.insertelement %a, %b[] : vector<f32>
+ return %1 : vector<f32>
+}
+
// CHECK-LABEL: @insert_element
func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
// CHECK: %[[C15:.*]] = arith.constant 15 : i32
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 0921bfc1f03a0..b3052ebd7a600 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -10,8 +10,15 @@ func @extract_element_0d(%a: vector<f32>) {
return
}
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
+ %1 = vector.insertelement %a, %b[] : vector<f32>
+ return %1: vector<f32>
+}
+
func @entry() {
- %1 = arith.constant dense<42.0> : vector<f32>
- call @extract_element_0d(%1) : (vector<f32>) -> ()
+ %0 = arith.constant 42.0 : f32
+ %1 = arith.constant dense<0.0> : vector<f32>
+ %2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
+ call @extract_element_0d(%2) : (vector<f32>) -> ()
return
}
More information about the Mlir-commits mailing list