[Mlir-commits] [mlir] 3ff4e5f - [mlir][Vector] Thread 0-d vectors through InsertElementOp.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 23 04:56:53 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-23T12:55:11Z
New Revision: 3ff4e5f2a4a6a0e124356b2ad8793270ebbb16c1

URL: https://github.com/llvm/llvm-project/commit/3ff4e5f2a4a6a0e124356b2ad8793270ebbb16c1
DIFF: https://github.com/llvm/llvm-project/commit/3ff4e5f2a4a6a0e124356b2ad8793270ebbb16c1.diff

LOG: [mlir][Vector] Thread 0-d vectors through InsertElementOp.

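This revision makes concrete use of 0-d vectors to extend the semantics of
InsertElementOp: the destination may now be a 0-d vector, in which case the
position operand must be omitted (e.g. `vector.insertelement %f, %z[] : vector<f32>`).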

Reviewed By: dcaballe, pifon2a

Differential Revision: https://reviews.llvm.org/D114388

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index d8ffdff7667e9..f274ff656f253 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -666,16 +666,18 @@ def Vector_InsertElementOp :
                     "result", "source",
                     "$_self.cast<ShapedType>().getElementType()">,
      AllTypesMatch<["dest", "result"]>]>,
-     Arguments<(ins AnyType:$source, AnyVector:$dest,
-                    AnySignlessIntegerOrIndex:$position)>,
-     Results<(outs AnyVector:$result)> {
+     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
+                    Optional<AnySignlessIntegerOrIndex>:$position)>,
+     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "insertelement operation";
   let description = [{
-    Takes a scalar source, an 1-D destination vector and a dynamic index
-    position and inserts the source into the destination at the proper
-    position.  Note that this instruction resembles vector.insert, but
-    is restricted to 1-D vectors and relaxed to dynamic indices. It is
-    meant to be closer to LLVM's version:
+    Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
+    position (absent for 0-D vectors), and inserts the source at that position.
+
+    Note that this instruction resembles vector.insert, but is restricted to 0-D
+    and 1-D vectors and relaxed to dynamic indices.
+
+    It is meant to be closer to LLVM's version:
     https://llvm.org/docs/LangRef.html#insertelement-instruction
 
     Example:
@@ -684,14 +686,18 @@ def Vector_InsertElementOp :
     %c = arith.constant 15 : i32
     %f = arith.constant 0.0f : f32
     %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
+    %2 = vector.insertelement %f, %z[]: vector<f32>
     ```
   }];
   let assemblyFormat = [{
-    $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
+    $source `,` $dest `[` ($position^ `:` type($position))? `]`  attr-dict `:`
     type($result)
   }];
 
   let builders = [
+    // 0-D builder.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>,
+    // 1-D + position builder.
     OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
   ];
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c74eca56b84b2..108e664f03cad 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -663,6 +663,17 @@ class VectorInsertElementOpConversion
     if (!llvmType)
       return failure();
 
+    if (vectorType.getRank() == 0) {
+      Location loc = insertEltOp.getLoc();
+      auto idxType = rewriter.getIndexType();
+      auto zero = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter->convertType(idxType),
+          rewriter.getIntegerAttr(idxType, 0));
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
         adaptor.position());

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d8cd3c178ad15..e4438fc86e5fd 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1553,6 +1553,12 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
 // InsertElementOp
 //===----------------------------------------------------------------------===//
 
+void InsertElementOp::build(OpBuilder &builder, OperationState &result,
+                            Value source, Value dest) {
+  result.addOperands({source, dest});
+  result.addTypes(dest.getType());
+}
+
 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, Value position) {
   result.addOperands({source, dest, position});
@@ -1561,8 +1567,15 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
 
 static LogicalResult verify(InsertElementOp op) {
   auto dstVectorType = op.getDestVectorType();
+  if (dstVectorType.getRank() == 0) {
+    if (op.position())
+      return op.emitOpError("expected position to be empty with 0-D vector");
+    return success();
+  }
   if (dstVectorType.getRank() != 1)
-    return op.emitOpError("expected 1-D vector");
+    return op.emitOpError("unexpected >1 vector rank");
+  if (!op.position())
+    return op.emitOpError("expected position for 1-D vector");
   return success();
 }
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9cce66c7fe58b..033e2d812b3f2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -512,6 +512,19 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @insert_element_0d
+// CHECK-SAME: %[[A:.*]]: f32,
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+  // CHECK: %[[B:.*]] =  builtin.unrealized_conversion_cast %{{.*}} :
+  // CHECK:   vector<f32> to vector<1xf32>
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1 : vector<f32>
+}
+
+// -----
+
 func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : i32
   %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c327bfe6968ca..593686a425a51 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -79,7 +79,7 @@ func @extract_element(%arg0: vector<f32>) {
 }
 
 // -----
-
+ 
 func @extract_element(%arg0: vector<4xf32>) {
   %c = arith.constant 3 : i32
   // expected-error@+1 {{expected position for 1-D vector}}
@@ -138,9 +138,25 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func @insert_element(%arg0: f32, %arg1: vector<f32>) {
+  %c = arith.constant 3 : i32
+  // expected-error@+1 {{expected position to be empty with 0-D vector}}
+  %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
+}
+
+// -----
+
+func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
+  %c = arith.constant 3 : i32
+  // expected-error@+1 {{expected position for 1-D vector}}
+  %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
+}
+
+// -----
+
 func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
   %c = arith.constant 3 : i32
-  // expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
+  // expected-error@+1 {{unexpected >1 vector rank}}
   %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3f7fe75b8cb2a..11b986fc9b87c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -192,6 +192,13 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32
   return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
+// CHECK-LABEL: @insert_element_0d
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+  // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1 : vector<f32>
+}
+
 // CHECK-LABEL: @insert_element
 func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
   // CHECK:      %[[C15:.*]] = arith.constant 15 : i32

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 0921bfc1f03a0..b3052ebd7a600 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -10,8 +10,15 @@ func @extract_element_0d(%a: vector<f32>) {
   return
 }
 
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1: vector<f32>
+}
+
 func @entry() {
-  %1 = arith.constant dense<42.0> : vector<f32>
-  call  @extract_element_0d(%1) : (vector<f32>) -> ()
+  %0 = arith.constant 42.0 : f32
+  %1 = arith.constant dense<0.0> : vector<f32>
+  %2 = call  @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
+  call  @extract_element_0d(%2) : (vector<f32>) -> ()
   return
 }


        


More information about the Mlir-commits mailing list