[Mlir-commits] [mlir] d137c05 - [mlir][spirv] Add conversion from vector.reduction

Lei Zhang llvmlistbot at llvm.org
Wed Apr 27 07:29:53 PDT 2022


Author: Lei Zhang
Date: 2022-04-27T10:29:46-04:00
New Revision: d137c05fc9a3aa04c66daaa7a92b90998676f693

URL: https://github.com/llvm/llvm-project/commit/d137c05fc9a3aa04c66daaa7a92b90998676f693
DIFF: https://github.com/llvm/llvm-project/commit/d137c05fc9a3aa04c66daaa7a92b90998676f693.diff

LOG: [mlir][spirv] Add conversion from vector.reduction

Only the `add` and `mul` combining kinds are supported for now; the
remaining kinds (min/max, and/or/xor) are still to be implemented.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D124380

Added: 
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Removed: 
    mlir/test/Conversion/VectorToSPIRV/simple.mlir


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 5bdcc384165c9..e404d75f7d243 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,7 +19,9 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include <numeric>
 
@@ -249,6 +251,69 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
+/// Lowers a 1-D vector.reduction into element extracts and scalar SPIR-V ops.
+struct VectorReductionPattern final
+    : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = typeConverter->convertType(reduceOp.getType());
+    if (!resultType)
+      return failure();
+    Value src = adaptor.getVector();
+    auto srcVectorType = src.getType().dyn_cast<VectorType>();
+    if (srcVectorType ? srcVectorType.getRank() != 1
+                      : !src.getType().isa<spirv::ScalarType>())
+      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
+
+    // Extract all elements. 1-element vectors are converted to scalars.
+    int numElements = srcVectorType ? srcVectorType.getDimSize(0) : 0;
+    SmallVector<Value, 4> values;
+    Location loc = reduceOp.getLoc();
+    for (int i = 0; i < numElements; ++i)
+      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+          loc, srcVectorType.getElementType(), src,
+          rewriter.getI32ArrayAttr({i})));
+    if (!srcVectorType)
+      values.push_back(src);
+    if (Value acc = adaptor.getAcc())
+      values.push_back(acc);
+
+    // Reduce them left to right, starting from the first element.
+    Value result = values.front();
+    for (Value next : llvm::makeArrayRef(values).drop_front()) {
+      switch (reduceOp.getKind()) {
+#define INT_FLOAT_CASE(kind, iop, fop)                                         \
+  case vector::CombiningKind::kind:                                            \
+    if (resultType.isa<IntegerType>()) {                                       \
+      result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
+    } else {                                                                   \
+      assert(resultType.isa<FloatType>());                                     \
+      result = rewriter.create<spirv::fop>(loc, resultType, result, next);     \
+    }                                                                          \
+    break
+        INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
+        INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
+#undef INT_FLOAT_CASE
+      case vector::CombiningKind::MINUI:
+      case vector::CombiningKind::MINSI:
+      case vector::CombiningKind::MINF:
+      case vector::CombiningKind::MAXUI:
+      case vector::CombiningKind::MAXSI:
+      case vector::CombiningKind::MAXF:
+      case vector::CombiningKind::AND:
+      case vector::CombiningKind::OR:
+      case vector::CombiningKind::XOR:
+        return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      }
+    }
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:
   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
@@ -312,6 +377,7 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                VectorExtractElementOpConvert, VectorExtractOpConvert,
                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
                VectorInsertElementOpConvert, VectorInsertOpConvert,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+               VectorReductionPattern, VectorInsertStridedSliceOpConvert,
+               VectorShuffleOpConvert, VectorSplatPattern>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
similarity index 79%
rename from mlir/test/Conversion/VectorToSPIRV/simple.mlir
rename to mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 421c803b5a91a..8cd4258e83d2f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -46,7 +46,7 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 //       CHECK:   return %[[R]]
 func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
   %0 = vector.extract %arg0[0] : vector<1xf32>
-	return %0: f32
+  return %0: f32
 }
 
 // -----
@@ -56,7 +56,7 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 //       CHECK:   spv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
 func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
   %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32>
-	return %1: vector<4xf32>
+  return %1: vector<4xf32>
 }
 
 // -----
@@ -67,7 +67,7 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 //       CHECK:   return %[[R]]
 func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
   %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
-	return %1 : vector<1xf32>
+  return %1 : vector<1xf32>
 }
 
 // -----
@@ -84,7 +84,7 @@ func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
 
 // CHECK-LABEL: @extract_element_index
 func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
-	// CHECK: vector.extractelement
+  // CHECK: vector.extractelement
   %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
   return %0: f32
 }
@@ -93,7 +93,7 @@ func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
 
 // CHECK-LABEL: @extract_element_size5_vector
 func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 {
-	// CHECK: vector.extractelement
+  // CHECK: vector.extractelement
   %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
   return %0: f32
 }
@@ -124,7 +124,7 @@ func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector
 
 // CHECK-LABEL: @insert_element_index
 func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
-	// CHECK: vector.insertelement
+  // CHECK: vector.insertelement
   %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
   return %0: vector<4xf32>
 }
@@ -133,19 +133,19 @@ func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -
 
 // CHECK-LABEL: @insert_element_size5_vector
 func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> {
-	// CHECK: vector.insertelement
+  // CHECK: vector.insertelement
   %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
-	return %0 : vector<5xf32>
+  return %0 : vector<5xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @insert_strided_slice
 //  CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
-//       CHECK: 	spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+//       CHECK:   spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
 func.func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
-	return %0 : vector<4xf32>
+  return %0 : vector<4xf32>
 }
 
 // -----
@@ -156,7 +156,7 @@ func.func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> v
 //       CHECK:   spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
 func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
   %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
-	return %1 : vector<3xf32>
+  return %1 : vector<3xf32>
 }
 
 // -----
@@ -166,7 +166,7 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> v
 //       CHECK:   spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
 func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.fma %a, %b, %c: vector<4xf32>
-	return %0 : vector<4xf32>
+  return %0 : vector<4xf32>
 }
 
 // -----
@@ -210,3 +210,36 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
   %shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32>
   return %shuffle : vector<3x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @reduction
+//  CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>)
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<4xi32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<4xi32>
+//       CHECK:   %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<4xi32>
+//       CHECK:   %[[E3:.+]] = spv.CompositeExtract %[[VEC]][3 : i32] : vector<4xi32>
+//       CHECK:   %[[SUM0:.+]] = spv.IAdd %[[E0]], %[[E1]]
+//       CHECK:   %[[SUM1:.+]] = spv.IAdd %[[SUM0]], %[[E2]]
+//       CHECK:   %[[SUM2:.+]] = spv.IAdd %[[SUM1]], %[[E3]]
+//       CHECK:   return %[[SUM2]]
+func.func @reduction(%vec : vector<4xi32>) -> i32 {
+  %sum = vector.reduction <add>, %vec : vector<4xi32> into i32
+  return %sum : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_mul
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[ACC:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MUL0:.+]] = spv.FMul %[[S0]], %[[S1]]
+//       CHECK:   %[[MUL1:.+]] = spv.FMul %[[MUL0]], %[[S2]]
+//       CHECK:   %[[MUL2:.+]] = spv.FMul %[[MUL1]], %[[ACC]]
+//       CHECK:   return %[[MUL2]]
+func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}


        


More information about the Mlir-commits mailing list