[Mlir-commits] [mlir] e83b7b9 - [mlir] [VectorOps] Implement vector.reduce operation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 11:32:09 PST 2020


Author: aartbik
Date: 2020-02-11T11:31:59-08:00
New Revision: e83b7b99da2e0385c567cd3883cad66fb5ce271c

URL: https://github.com/llvm/llvm-project/commit/e83b7b99da2e0385c567cd3883cad66fb5ce271c
DIFF: https://github.com/llvm/llvm-project/commit/e83b7b99da2e0385c567cd3883cad66fb5ce271c.diff

LOG: [mlir] [VectorOps] Implement vector.reduce operation

Summary:
This new operation operates on 1-D vectors and
forms the bridge between vector.contract and
llvm intrinsics for vector reductions.

Reviewers: nicolasvasilache, andydavis1, ftynse

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74370

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/VectorOps/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/VectorOps/invalid.mlir
    mlir/test/Dialect/VectorOps/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index de7007e3b509..074a6d005376 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -183,6 +183,39 @@ def Vector_ContractionOp :
   }];
 }
 
+def Vector_ReductionOp :
+  Vector_Op<"reduction", [NoSideEffect,
+     PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
+    Results<(outs AnyType:$dest)> {
+  let summary = "reduction operation";
+  let description = [{
+    Reduces a 1-D vector "horizontally" into a scalar using the given
+    operation: add/mul/min/max for i32/i64/f32/f64, and/or/xor for i32/i64.
+    Note that these operations are restricted to 1-D vectors to remain
+    close to the corresponding LLVM intrinsics:
+
+    http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
+
+    Examples:
+    ```
+      %1 = vector.reduction "add", %0 : vector<16xf32> into f32
+
+      %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let assemblyFormat = [{
+    $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a3d724bea834..9fcad2f5063a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -124,6 +124,7 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
 }
 
 namespace {
+
 class VectorBroadcastOpConversion : public LLVMOpLowering {
 public:
   explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -272,6 +273,73 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
   }
 };
 
+class VectorReductionOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorReductionOpConversion(MLIRContext *context,
+                                       LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::ReductionOp::getOperationName(), context,
+                       typeConverter) {}
+  /// Lowers to the LLVM reduction intrinsic; fails on unsupported kinds/types.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<vector::ReductionOp>(op);
+    auto kind = reductionOp.kind();
+    Type eltType = reductionOp.dest().getType();
+    Type llvmType = lowering.convertType(eltType);
+    if (eltType.isInteger(32) || eltType.isInteger(64)) {
+      // Integer reductions; min/max lower to the signed smin/smax intrinsics.
+      if (kind == "add")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
+            op, llvmType, operands[0]);
+      else if (kind == "mul")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
+            op, llvmType, operands[0]);
+      else if (kind == "min")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
+            op, llvmType, operands[0]);
+      else if (kind == "max")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
+            op, llvmType, operands[0]);
+      else if (kind == "and")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
+            op, llvmType, operands[0]);
+      else if (kind == "or")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
+            op, llvmType, operands[0]);
+      else if (kind == "xor")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
+            op, llvmType, operands[0]);
+      else
+        return matchFailure();
+      return matchSuccess();
+    }
+    if (eltType.isF32() || eltType.isF64()) {
+      // FP reductions. NOTE(review): LangRef gives -0.0 as the fadd identity.
+      if (kind == "add") {
+        Value zero = rewriter.create<LLVM::ConstantOp>(
+            op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+            op, llvmType, zero, operands[0]);
+      } else if (kind == "mul") {
+        Value one = rewriter.create<LLVM::ConstantOp>(
+            op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+            op, llvmType, one, operands[0]);
+      } else if (kind == "min")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
+            op, llvmType, operands[0]);
+      else if (kind == "max")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
+            op, llvmType, operands[0]);
+      else
+        return matchFailure();
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
 class VectorShuffleOpConversion : public LLVMOpLowering {
 public:
   explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1056,12 +1124,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
-  patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
-                  VectorExtractElementOpConversion, VectorExtractOpConversion,
-                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
-                  VectorInsertOpConversion, VectorOuterProductOpConversion,
-                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
-      ctx, converter);
+  patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
+                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
+                  VectorExtractOpConversion, VectorFMAOp1DConversion,
+                  VectorInsertElementOpConversion, VectorInsertOpConversion,
+                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+                  VectorPrintOpConversion>(ctx, converter);
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index a987a54f5ea1..b4d7aee70b17 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -60,6 +60,33 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
   return builder.getI64ArrayAttr(values);
 }
 
+//===----------------------------------------------------------------------===//
+// ReductionOp
+//===----------------------------------------------------------------------===//
+/// Verifies that the operand is 1-D and that the kind suits the element type.
+static LogicalResult verify(ReductionOp op) {
+  // Only 1-D vectors are supported, mirroring the LLVM reduction intrinsics.
+  int64_t rank = op.getVectorType().getRank();
+  if (rank != 1)
+    return op.emitOpError("unsupported reduction rank: ") << rank;
+
+  // Arithmetic kinds accept i32/i64/f32/f64; bitwise kinds only i32/i64.
+  auto kind = op.kind();
+  Type eltType = op.dest().getType();
+  if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
+    if (eltType.isF32() || eltType.isF64() || eltType.isInteger(32) ||
+        eltType.isInteger(64))
+      return success();
+    return op.emitOpError("unsupported reduction type");
+  }
+  if (kind == "and" || kind == "or" || kind == "xor") {
+    if (eltType.isInteger(32) || eltType.isInteger(64))
+      return success();
+    return op.emitOpError("unsupported reduction type");
+  }
+  return op.emitOpError("unknown reduction kind: ") << kind;
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d1535d59593c..5159031339aa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -645,7 +645,7 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
   //         CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
   //    CHECK-SAME:   (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
   %0 = vector.fma %a, %a, %a : vector<8xf32>
-  
+
   //       CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
@@ -659,7 +659,45 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
   //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
   //       CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
   %1 = vector.fma %b, %b, %b : vector<2x4xf32>
-  
+
   return %0, %1: vector<8xf32>, vector<2x4xf32>
 }
-        
+// Lowering of vector.reduction "add"; fp operands get an explicit 0.0 start.
+func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: llvm.func @reduce_f32
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
+//      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+//      CHECK: llvm.return %[[V]] : !llvm.float
+
+func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
+  %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64
+  return %0 : f64
+}
+// CHECK-LABEL: llvm.func @reduce_f64
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">
+//      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double
+//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+//      CHECK: llvm.return %[[V]] : !llvm.double
+
+func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: llvm.func @reduce_i32
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>">
+//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
+//      CHECK: llvm.return %[[V]] : !llvm.i32
+
+func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
+  %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64
+  return %0 : i64
+}
+// CHECK-LABEL: llvm.func @reduce_i64
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>">
+//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
+//      CHECK: llvm.return %[[V]] : !llvm.i64
+

diff  --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index a1fee2cf9a62..2a45820be7b0 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -990,3 +990,31 @@ func @shape_cast_different_tuple_sizes(
   %1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                                  tuple<vector<20x2xf32>>
 }
+
+// -----
+
+func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
+  // expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}}
+  %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32
+}
+
+// -----
+
+func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 {
+  // expected-error@+1 {{'vector.reduction' op failed to verify that source operand and result have same element type}}
+  %0 = vector.reduction "add", %arg0 : vector<16xf32> into i32
+}
+
+// -----
+
+func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 {
+  // expected-error@+1 {{'vector.reduction' op unsupported reduction type}}
+  %0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32
+}
+
+// -----
+
+func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
+  // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
+  %0 = vector.reduction "add", %arg0 : vector<4x16xf32> into f32
+}

diff  --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index ff0078310af2..bb5ca6eb8538 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -277,3 +277,37 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
   vector.fma %b, %b, %b : vector<8x4xf32>
   return
 }
+// Round-trip of vector.reduction for every kind accepted on f32 operands.
+// CHECK-LABEL: reduce_fp
+func @reduce_fp(%arg0: vector<16xf32>) -> f32 {
+  // CHECK:    vector.reduction "add", %{{.*}} : vector<16xf32> into f32
+  vector.reduction "add", %arg0 : vector<16xf32> into f32
+  // CHECK:    vector.reduction "mul", %{{.*}} : vector<16xf32> into f32
+  vector.reduction "mul", %arg0 : vector<16xf32> into f32
+  // CHECK:    vector.reduction "min", %{{.*}} : vector<16xf32> into f32
+  vector.reduction "min", %arg0 : vector<16xf32> into f32
+  // CHECK:    %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32
+  %0 = vector.reduction "max", %arg0 : vector<16xf32> into f32
+  // CHECK:    return %[[X]] : f32
+  return %0 : f32
+}
+// Round-trip of vector.reduction for every kind accepted on i32 operands.
+// CHECK-LABEL: reduce_int
+func @reduce_int(%arg0: vector<16xi32>) -> i32 {
+  // CHECK:    vector.reduction "add", %{{.*}} : vector<16xi32> into i32
+  vector.reduction "add", %arg0 : vector<16xi32> into i32
+  // CHECK:    vector.reduction "mul", %{{.*}} : vector<16xi32> into i32
+  vector.reduction "mul", %arg0 : vector<16xi32> into i32
+  // CHECK:    vector.reduction "min", %{{.*}} : vector<16xi32> into i32
+  vector.reduction "min", %arg0 : vector<16xi32> into i32
+  // CHECK:    vector.reduction "max", %{{.*}} : vector<16xi32> into i32
+  vector.reduction "max", %arg0 : vector<16xi32> into i32
+  // CHECK:    vector.reduction "and", %{{.*}} : vector<16xi32> into i32
+  vector.reduction "and", %arg0 : vector<16xi32> into i32
+  // CHECK:    vector.reduction "or", %{{.*}} : vector<16xi32> into i32
+  vector.reduction "or", %arg0 : vector<16xi32> into i32
+  // CHECK:    %[[X:.*]] = vector.reduction "xor", %{{.*}} : vector<16xi32> into i32
+  %0 = vector.reduction "xor", %arg0 : vector<16xi32> into i32
+  // CHECK:    return %[[X]] : i32
+  return %0 : i32
+}


        


More information about the Mlir-commits mailing list