[Mlir-commits] [mlir] e7026ab - [mlir][Vector] Thread 0-d vectors through ExtractElementOp.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 23 04:43:37 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-23T12:39:44Z
New Revision: e7026aba004934cad5487256601af7690757d09f

URL: https://github.com/llvm/llvm-project/commit/e7026aba004934cad5487256601af7690757d09f
DIFF: https://github.com/llvm/llvm-project/commit/e7026aba004934cad5487256601af7690757d09f.diff

LOG: [mlir][Vector] Thread 0-d vectors through ExtractElementOp.

This revision starts making concrete use of 0-d vectors to extend the semantics of
ExtractElementOp.
In the process, a new VectorOfAnyRank TableGen type constraint is added to OpBase.td to allow a progressive transition to supporting 0-d vectors by gradually opting in.

Differential Revision: https://reviews.llvm.org/D114387

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index bbd45b78ecaf2..d8ffdff7667e9 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -482,14 +482,20 @@ def Vector_ExtractElementOp :
      TypesMatchWith<"result type matches element type of vector operand",
                     "vector", "result",
                     "$_self.cast<ShapedType>().getElementType()">]>,
-    Arguments<(ins AnyVector:$vector, AnySignlessIntegerOrIndex:$position)>,
+    Arguments<(ins AnyVectorOfAnyRank:$vector,
+                   Optional<AnySignlessIntegerOrIndex>:$position)>,
     Results<(outs AnyType:$result)> {
   let summary = "extractelement operation";
   let description = [{
-    Takes an 1-D vector and a dynamic index position and extracts the
-    scalar at that position. Note that this instruction resembles
-    vector.extract, but is restricted to 1-D vectors and relaxed
-    to dynamic indices. It is meant to be closer to LLVM's version:
+    Takes a 0-D or 1-D vector and an optional dynamic index position and
+    extracts the scalar at that position.
+
+    Note that this instruction resembles vector.extract, but is restricted to
+    0-D and 1-D vectors and relaxed to dynamic indices.
+    If the vector is 0-D, the position must be omitted.
+
+
+    It is meant to be closer to LLVM's version:
     https://llvm.org/docs/LangRef.html#extractelement-instruction
 
     Example:
@@ -497,14 +503,18 @@ def Vector_ExtractElementOp :
     ```mlir
     %c = arith.constant 15 : i32
     %1 = vector.extractelement %0[%c : i32]: vector<16xf32>
+    %2 = vector.extractelement %z[]: vector<f32>
     ```
   }];
   let assemblyFormat = [{
-    $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+    $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "Value":$position)>
+    // 0-D builder.
+    OpBuilder<(ins "Value":$source)>,
+    // 1-D + position builder.
+    OpBuilder<(ins "Value":$source, "Value":$position)>,
   ];
   let extraClassDeclaration = [{
     VectorType getVectorType() {

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index efe6cbc296fa1..15d1ffd1c70f0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -208,7 +208,12 @@ class SuccessorConstraint<Pred predicate, string summary = ""> :
 //===----------------------------------------------------------------------===//
 
 // Whether a type is a VectorType.
-def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
+// Explicitly disallow 0-D vectors for now until we have good enough coverage.
+def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
+                            CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
+
+// Temporary vector type clone that allows gradual transition to 0-D vectors.
+def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
 
 // Whether a type is a TensorType.
 def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
@@ -598,6 +603,10 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
 class VectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
                       "::mlir::VectorType">;
+// Temporary vector type clone that allows gradual transition to 0-D vectors.
+class VectorOfAnyRankOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
+                      "::mlir::VectorType">;
 
 // Whether the number of elements of a vector is from the given
 // `allowedRanks` list
@@ -649,6 +658,8 @@ class VectorOfLengthAndType<list<int> allowedLengths,
   "::mlir::VectorType">;
 
 def AnyVector : VectorOf<[AnyType]>;
+// Temporary vector type clone that allows gradual transition to 0-D vectors.
+def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
 // Shaped types.
 

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 28be3300fb382..5446aed95d137 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -369,13 +369,17 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
   return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
 }
 
-/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
-/// when n > 1. For example, `vector<4 x f32>` remains as is while,
-/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`.
+/// Convert an n-D vector type to an LLVM vector type:
+///  * 0-D `vector<T>` is converted to `vector<1xT>`;
+///  * 1-D `vector<axT>` remains as is;
+///  * n>1 `vector<ax...xkxT>` converts via an (n-1)-D array type to
+///    `!llvm.array<ax...array<jxvector<kxT>>>`.
 Type LLVMTypeConverter::convertVectorType(VectorType type) {
   auto elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
+  if (type.getShape().empty())
+    return VectorType::get({1}, elementType);
   Type vectorType = VectorType::get(type.getShape().back(), elementType);
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 65816a2d0580e..c74eca56b84b2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -40,6 +40,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val1, Value val2, Type llvmType, int64_t rank,
                        int64_t pos) {
+  assert(rank > 0 && "0-D vector corner case should have been handled already");
   if (rank == 1) {
     auto idxType = rewriter.getIndexType();
     auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -56,6 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
 static Value extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &typeConverter, Location loc,
                         Value val, Type llvmType, int64_t rank, int64_t pos) {
+  assert(rank > 0 && "0-D vector corner case should have been handled already");
   if (rank == 1) {
     auto idxType = rewriter.getIndexType();
     auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -542,6 +544,17 @@ class VectorExtractElementOpConversion
     if (!llvmType)
       return failure();
 
+    if (vectorType.getRank() == 0) {
+      Location loc = extractEltOp.getLoc();
+      auto idxType = rewriter.getIndexType();
+      auto zero = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter->convertType(idxType),
+          rewriter.getIntegerAttr(idxType, 0));
+      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+          extractEltOp, llvmType, adaptor.vector(), zero);
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
     return success();

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 4b67b39b2fdb0..d8cd3c178ad15 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -832,6 +832,12 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
+void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
+                                     Value source) {
+  result.addOperands({source});
+  result.addTypes(source.getType().cast<VectorType>().getElementType());
+}
+
 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                      Value source, Value position) {
   result.addOperands({source, position});
@@ -840,8 +846,15 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
 
 static LogicalResult verify(vector::ExtractElementOp op) {
   VectorType vectorType = op.getVectorType();
+  if (vectorType.getRank() == 0) {
+    if (op.position())
+      return op.emitOpError("expected position to be empty with 0-D vector");
+    return success();
+  }
   if (vectorType.getRank() != 1)
-    return op.emitOpError("expected 1-D vector");
+    return op.emitOpError("unexpected >1 vector rank");
+  if (!op.position())
+    return op.emitOpError("expected position for 1-D vector");
   return success();
 }
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d5d8509cfa61f..9cce66c7fe58b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -418,6 +418,16 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: @extract_element_0d
+func @extract_element_0d(%a: vector<f32>) -> f32 {
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
+  %1 = vector.extractelement %a[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
 func @extract_element(%arg0: vector<16xf32>) -> f32 {
   %0 = arith.constant 15 : i32
   %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6aa9679117cb0..c327bfe6968ca 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -72,9 +72,25 @@ func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
 
 // -----
 
+func @extract_element(%arg0: vector<f32>) {
+  %c = arith.constant 3 : i32
+  // expected-error @+1 {{expected position to be empty with 0-D vector}}
+  %1 = vector.extractelement %arg0[%c : i32] : vector<f32>
+}
+
+// -----
+
+func @extract_element(%arg0: vector<4xf32>) {
+  %c = arith.constant 3 : i32
+  // expected-error @+1 {{expected position for 1-D vector}}
+  %1 = vector.extractelement %arg0[] : vector<4xf32>
+}
+
+// -----
+
 func @extract_element(%arg0: vector<4x4xf32>) {
   %c = arith.constant 3 : i32
-  // expected-error @+1 {{'vector.extractelement' op expected 1-D vector}}
+  // expected-error @+1 {{unexpected >1 vector rank}}
   %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f4bf29dae7fbb..3f7fe75b8cb2a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -163,6 +163,13 @@ func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
   return %1 : vector<3x4xf32>
 }
 
+// CHECK-LABEL: @extract_element_0d
+func @extract_element_0d(%a: vector<f32>) -> f32 {
+  // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
+  %1 = vector.extractelement %a[] : vector<f32>
+  return %1 : f32
+}
+
 // CHECK-LABEL: @extract_element
 func @extract_element(%a: vector<16xf32>) -> f32 {
   // CHECK:      %[[C15:.*]] = arith.constant 15 : i32

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
new file mode 100644
index 0000000000000..0921bfc1f03a0
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @extract_element_0d(%a: vector<f32>) {
+  %1 = vector.extractelement %a[] : vector<f32>
+  // CHECK: 42
+  vector.print %1: f32
+  return
+}
+
+func @entry() {
+  %1 = arith.constant dense<42.0> : vector<f32>
+  call  @extract_element_0d(%1) : (vector<f32>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list