[Mlir-commits] [mlir] 80e0bf1 - Add vector.scan op

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 28 12:08:40 PST 2022


Author: harsh
Date: 2022-01-28T20:07:57Z
New Revision: 80e0bf1af11c0c7bd2a6261d83dcb18987cd7f11

URL: https://github.com/llvm/llvm-project/commit/80e0bf1af11c0c7bd2a6261d83dcb18987cd7f11
DIFF: https://github.com/llvm/llvm-project/commit/80e0bf1af11c0c7bd2a6261d83dcb18987cd7f11.diff

LOG: Add vector.scan op

This patch adds the vector.scan op which computes the
scan for a given n-d vector. It requires specifying the operator,
the initial value and whether the scan is inclusive or
exclusive.

TEST: Added tests in ops.mlir, invalid.mlir, vector-scan-transforms.mlir and the CPU integration test test-scan.mlir

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D117171

Added: 
    mlir/test/Dialect/Vector/vector-scan-transforms.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 1f501ac6b89ea..6a5e0d961b1dc 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2429,4 +2429,62 @@ def VectorScaleOp : Vector_Op<"vscale",
   let verifier = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScanOp
+//===----------------------------------------------------------------------===//
+
+def Vector_ScanOp :
+  Vector_Op<"scan", [NoSideEffect,
+    AllTypesMatch<["source", "dest"]>,
+    AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
+    Arguments<(ins Vector_CombiningKindAttr:$kind,
+                   AnyVector:$source,
+                   AnyVectorOfAnyRank:$initial_value,
+                   I64Attr:$reduction_dim,
+                   BoolAttr:$inclusive)>,
+    Results<(outs AnyVector:$dest,
+                  AnyVectorOfAnyRank:$accumulated_value)> {
+  let summary = "Scan operation";
+  let description = [{
+    Performs an inclusive/exclusive scan on an n-D vector along a single
+    dimension, `reduction_dim`, returning an n-D result vector of the same
+    type as `source`. The combining operation is given by `kind`
+    (add/mul/min/max for int/fp and and/or/xor for int only).
+
+    `initial_value` must have the shape of `source` with `reduction_dim`
+    removed. In an exclusive scan it forms the first slice of the result
+    along `reduction_dim`, and every later result slice combines the previous
+    result slice with the previous source slice. In an inclusive scan the
+    first result slice is the first source slice, and every later result
+    slice combines the previous result slice with the current source slice.
+
+    The op returns the scan result as well as the last slice of that result
+    along `reduction_dim` (with that dimension dropped), which has the same
+    type as `initial_value`.
+
+    Example:
+
+    ```mlir
+    %1:2 = vector.scan <add>, %0, %acc {inclusive = false, reduction_dim = 1 : i64} :
+      vector<4x8x16x32xf32>, vector<4x16x32xf32>
+    ```
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$initial_value,
+                   "CombiningKind":$kind,
+                   CArg<"int64_t", "0">:$reduction_dim,
+                   CArg<"bool", "true">:$inclusive)>
+  ];
+  let extraClassDeclaration = [{
+    static StringRef getKindAttrName() { return "kind"; }
+    static StringRef getReductionDimAttrName() { return "reduction_dim"; }
+    VectorType getSourceType() {
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getDestType() {
+      return dest().getType().cast<VectorType>();
+    }
+    VectorType getAccumulatorType() {
+      return accumulated_value().getType().cast<VectorType>();
+    }
+    VectorType getInitialValueType() {
+      return initial_value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat =
+    "$kind `,` $source `,` $initial_value attr-dict `:` "
+    "type($source) `,` type($initial_value) ";
+}
+
 #endif // VECTOR_OPS

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 563f67d817804..8f01b0523ef24 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -171,6 +171,9 @@ void populateVectorContractLoweringPatterns(
 /// transpose/broadcast ops into the contract.
 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
 
+/// Collect patterns to convert vector.scan ops into arith ops combined with
+/// vector.extract_strided_slice/insert_strided_slice ops, unrolling each scan
+/// along its reduction dimension.
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Vector.transfer patterns.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index e1289d5e7fad7..9493187461e49 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -4263,6 +4263,44 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CreateMaskFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ScanOp
+//===----------------------------------------------------------------------===//
+
+/// Build a vector.scan from its operands. The result types are taken from
+/// `source` and `initialValue`, as required by the op's AllTypesMatch
+/// constraints.
+void vector::ScanOp::build(OpBuilder &builder, OperationState &result,
+                           Value source, Value initialValue,
+                           CombiningKind kind, int64_t reductionDim,
+                           bool inclusive) {
+  result.addOperands(source);
+  result.addOperands(initialValue);
+  result.addTypes(source.getType());
+  result.addTypes(initialValue.getType());
+  result.addAttribute(getKindAttrName(),
+                      CombiningKindAttr::get(kind, builder.getContext()));
+  result.addAttribute(getReductionDimAttrName(),
+                      builder.getI64IntegerAttr(reductionDim));
+  result.addAttribute("inclusive", builder.getBoolAttr(inclusive));
+}
+
+/// Verify vector.scan. The reduction dimension must be a valid dimension of
+/// the source. The initial value must have the source's shape with that
+/// dimension removed, and the source's element type.
+static LogicalResult verify(ScanOp op) {
+  VectorType srcType = op.getSourceType();
+  VectorType initialType = op.getInitialValueType();
+  // Check 0 <= reduction dimension < rank. A negative dimension would
+  // otherwise pass verification and later be used as an index.
+  int64_t srcRank = srcType.getRank();
+  int64_t reductionDim = op.reduction_dim();
+  if (reductionDim < 0)
+    return op.emitOpError("reduction dimension ")
+           << reductionDim << " has to be non-negative";
+  if (reductionDim >= srcRank)
+    return op.emitOpError("reduction dimension ")
+           << reductionDim << " has to be less than " << srcRank;
+
+  // Check that rank(initial_value) = rank(src) - 1.
+  int64_t initialValueRank = initialType.getRank();
+  if (initialValueRank != srcRank - 1)
+    return op.emitOpError("initial value rank ")
+           << initialValueRank << " has to be equal to " << srcRank - 1;
+
+  // Check that the initial value's shape is the source shape with the
+  // reduction dimension dropped.
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
+  SmallVector<int64_t> expectedShape(srcShape.begin(), srcShape.end());
+  expectedShape.erase(expectedShape.begin() + reductionDim);
+  if (!initialValueShapes.equals(expectedShape))
+    return op.emitOpError("incompatible input/initial value shapes");
+
+  // The initial value is combined with source elements, so the element types
+  // have to agree.
+  if (srcType.getElementType() != initialType.getElementType())
+    return op.emitOpError("incompatible input/initial value element types");
+
+  return success();
+}
+
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns) {
   patterns

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3ead467bc5181..83553b642a3c2 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2348,6 +2348,204 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+namespace {
+
+/// Returns true when the combining `kind` can be applied to elements of the
+/// given class (`isInt` selects integer/index vs. floating point).
+/// MINF/MAXF need floats. The integer min/max variants and AND/OR/XOR need
+/// integers. ADD and MUL exist in both forms.
+static bool isValidKind(bool isInt, vector::CombiningKind kind) {
+  using vector::CombiningKind;
+  switch (kind) {
+  // Floating-point-only kinds.
+  case CombiningKind::MINF:
+  case CombiningKind::MAXF:
+    return !isInt;
+  // Integer-only kinds.
+  case CombiningKind::MINUI:
+  case CombiningKind::MINSI:
+  case CombiningKind::MAXUI:
+  case CombiningKind::MAXSI:
+  case CombiningKind::AND:
+  case CombiningKind::OR:
+  case CombiningKind::XOR:
+    return isInt;
+  // Kinds with both an integer and a floating-point lowering.
+  case CombiningKind::ADD:
+  case CombiningKind::MUL:
+    return true;
+  }
+  // Any kind not listed above is rejected.
+  return false;
+}
+
+/// Emit the arith op that combines `x` and `y` for the combining `kind`.
+/// For ADD and MUL, the integer form (addi/muli) is used when the element
+/// type of `x` is an integer or index type, and the float form (addf/mulf)
+/// otherwise. Every other kind maps to exactly one arith op. A null Value is
+/// returned for a kind that is not handled. ScanToArithOps checks the kind
+/// against the element type with isValidKind before calling this.
+static Value genOperator(Location loc, Value x, Value y,
+                         vector::CombiningKind kind,
+                         PatternRewriter &rewriter) {
+  using vector::CombiningKind;
+
+  auto elementType = x.getType().cast<VectorType>().getElementType();
+  const bool intLike = elementType.isIntOrIndex();
+
+  switch (kind) {
+  case CombiningKind::ADD:
+    if (intLike)
+      return rewriter.create<arith::AddIOp>(loc, x, y);
+    return rewriter.create<arith::AddFOp>(loc, x, y);
+  case CombiningKind::MUL:
+    if (intLike)
+      return rewriter.create<arith::MulIOp>(loc, x, y);
+    return rewriter.create<arith::MulFOp>(loc, x, y);
+  case CombiningKind::MINUI:
+    return rewriter.create<arith::MinUIOp>(loc, x, y);
+  case CombiningKind::MINSI:
+    return rewriter.create<arith::MinSIOp>(loc, x, y);
+  case CombiningKind::MAXUI:
+    return rewriter.create<arith::MaxUIOp>(loc, x, y);
+  case CombiningKind::MAXSI:
+    return rewriter.create<arith::MaxSIOp>(loc, x, y);
+  case CombiningKind::AND:
+    return rewriter.create<arith::AndIOp>(loc, x, y);
+  case CombiningKind::OR:
+    return rewriter.create<arith::OrIOp>(loc, x, y);
+  case CombiningKind::XOR:
+    return rewriter.create<arith::XOrIOp>(loc, x, y);
+  case CombiningKind::MINF:
+    return rewriter.create<arith::MinFOp>(loc, x, y);
+  case CombiningKind::MAXF:
+    return rewriter.create<arith::MaxFOp>(loc, x, y);
+  }
+  return Value();
+}
+
+/// Convert vector.scan op into arith ops and
+/// vector.insert_strided_slice/extract_strided_slice ops.
+///
+/// The scan is unrolled along `reduction_dim`. For each position i,
+/// slice i of the source is extracted with vector.extract_strided_slice and
+/// combined with the previous output slice. The result is then written into
+/// the result vector with vector.insert_strided_slice:
+///   * inclusive: out[0] = in[0],         out[i] = out[i-1] <op> in[i]
+///   * exclusive: out[0] = initial value, out[i] = out[i-1] <op> in[i-1]
+/// The second result is the last output slice converted to the initial
+/// value's type (vector.shape_cast, or vector.extract + vector.broadcast
+/// when the initial value is 0-D).
+///
+/// The pattern fails to match when the combining kind does not fit the
+/// element type, when `reduction_dim` is out of range, or when the initial
+/// value's element type differs from the source's.
+///
+/// Ex:
+/// ```
+///   %0:2 = vector.scan <mul>, %arg0, %arg1
+///            {inclusive = true, reduction_dim = 1} :
+///            vector<2x3xi32>, vector<2xi32>
+/// ```
+/// Gets converted to (%0#0 and %0#1 are replaced by %7 and %8):
+/// ```
+///   %cst = arith.constant dense<0> : vector<2x3xi32>
+///   %0 = vector.extract_strided_slice %arg0
+///          {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
+///          : vector<2x3xi32> to vector<2x1xi32>
+///   %1 = vector.insert_strided_slice %0, %cst
+///          {offsets = [0, 0], strides = [1, 1]}
+///          : vector<2x1xi32> into vector<2x3xi32>
+///   %2 = vector.extract_strided_slice %arg0
+///          {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
+///          : vector<2x3xi32> to vector<2x1xi32>
+///   %3 = arith.muli %0, %2 : vector<2x1xi32>
+///   %4 = vector.insert_strided_slice %3, %1
+///          {offsets = [0, 1], strides = [1, 1]}
+///          : vector<2x1xi32> into vector<2x3xi32>
+///   %5 = vector.extract_strided_slice %arg0
+///          {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+///          : vector<2x3xi32> to vector<2x1xi32>
+///   %6 = arith.muli %3, %5 : vector<2x1xi32>
+///   %7 = vector.insert_strided_slice %6, %4
+///          {offsets = [0, 2], strides = [1, 1]}
+///          : vector<2x1xi32> into vector<2x3xi32>
+///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
+/// ```
+struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
+  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = scanOp.getLoc();
+    VectorType destType = scanOp.getDestType();
+    ArrayRef<int64_t> destShape = destType.getShape();
+    auto elType = destType.getElementType();
+    bool isInt = elType.isIntOrIndex();
+    if (!isValidKind(isInt, scanOp.kind()))
+      return failure();
+
+    int64_t reductionDim = scanOp.reduction_dim();
+    int64_t destRank = destType.getRank();
+    // reduction_dim indexes the shape vectors below; bail out instead of
+    // indexing out of bounds if it is not a valid dimension.
+    if (reductionDim < 0 || reductionDim >= destRank)
+      return failure();
+
+    VectorType initialValueType = scanOp.getInitialValueType();
+    int64_t initialValueRank = initialValueType.getRank();
+    // The initial value is cast into a source slice (exclusive scans), and
+    // the last slice is cast back to its type, so element types must agree.
+    if (initialValueType.getElementType() != elType)
+      return failure();
+
+    bool inclusive = scanOp.inclusive();
+
+    // Start from a zero vector of the destination type. Each iteration below
+    // overwrites one slice along the reduction dimension.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, destType, rewriter.getZeroAttr(destType));
+
+    // Every slice keeps the destination shape, except for a size of 1 along
+    // the reduction dimension.
+    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
+    reductionShape[reductionDim] = 1;
+    VectorType reductionType = VectorType::get(reductionShape, elType);
+    SmallVector<int64_t> offsets(destRank, 0);
+    SmallVector<int64_t> strides(destRank, 1);
+    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(reductionShape);
+    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
+
+    Value lastOutput, lastInput;
+    for (int64_t i = 0, e = destShape[reductionDim]; i < e; ++i) {
+      offsets[reductionDim] = i;
+      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
+      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
+          scanStrides);
+      Value output;
+      if (i == 0) {
+        if (inclusive) {
+          // Inclusive: the first output slice is the first input slice.
+          output = input;
+        } else if (initialValueRank == 0) {
+          // Exclusive: the first output slice is the initial value.
+          // ShapeCastOp cannot handle 0-D vectors, so broadcast instead.
+          output = rewriter.create<vector::BroadcastOp>(
+              loc, input.getType(), scanOp.initial_value());
+        } else {
+          output = rewriter.create<vector::ShapeCastOp>(
+              loc, input.getType(), scanOp.initial_value());
+        }
+      } else {
+        // Inclusive: out[i] = out[i-1] <op> in[i].
+        // Exclusive: out[i] = out[i-1] <op> in[i-1].
+        Value y = inclusive ? input : lastInput;
+        output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
+        assert(output && "unsupported combining kind");
+      }
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, output, result, offsets, strides);
+      lastOutput = output;
+      lastInput = input;
+    }
+
+    // The second result is the last output slice in the initial value's type.
+    Value reduction;
+    if (initialValueRank == 0) {
+      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
+      reduction =
+          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
+    } else {
+      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
+                                                       lastOutput);
+    }
+
+    rewriter.replaceOp(scanOp, {result, reduction});
+    return success();
+  }
+};
+
+} // namespace
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool indexOptimizations) {
   patterns.add<VectorCreateMaskOpConversion,
@@ -2421,3 +2619,8 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
           patterns.getContext());
 }
+
+/// Register ScanToArithOps, which unrolls each vector.scan along its
+/// reduction dimension into arith and strided-slice vector ops.
+void mlir::vector::populateVectorScanLoweringPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ScanToArithOps>(context);
+}

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index a384d42ef6112..cd795b7bab083 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1480,3 +1480,29 @@ func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
   %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
 }
 
+// -----
+
+// reduction_dim (5) is not smaller than the source rank (2).
+func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
+  // expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
+    vector<2x3xi32>, vector<3xi32>
+  return %0#1 : vector<3xi32>
+}
+
+// -----
+
+// The initial value rank (2) must be the source rank (2) minus one.
+func @scan_ival_rank_constraint(%arg0: vector<2x3xi32>, %arg1: vector<1x3xi32>) -> vector<1x3xi32> {
+  // expected-error@+1 {{initial value rank 2 has to be equal to 1}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2x3xi32>, vector<1x3xi32>
+  return %0#1 : vector<1x3xi32>
+}
+
+// -----
+
+// The initial value must be vector<3xi32> (source shape without dim 0).
+func @scan_incompatible_shapes(%arg0: vector<2x3xi32>, %arg1: vector<5xi32>) -> vector<2x3xi32> {
+  // expected-error@+1 {{incompatible input/initial value shapes}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2x3xi32>, vector<5xi32>
+  return %0#0 : vector<2x3xi32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a77d7ba2ebf54..a839476602141 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -717,3 +717,11 @@ func @get_vector_scale() -> index {
   %0 = vector.vscale
   return %0 : index
 }
+
+// CHECK-LABEL: @vector_scan
+func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
+  %1 = arith.constant dense<0.0> : vector<4x16x32xf32>
+  // Verify that the op is printed back with its attributes and types.
+  // CHECK: vector.scan
+  // CHECK-SAME: inclusive = true
+  // CHECK-SAME: reduction_dim = 1
+  // CHECK-SAME: vector<4x8x16x32xf32>, vector<4x16x32xf32>
+  %2:2 = vector.scan <add>, %0, %1 {reduction_dim = 1 : i64, inclusive = true} :
+    vector<4x8x16x32xf32>, vector<4x16x32xf32>
+  return %2#0 : vector<4x8x16x32xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
new file mode 100644
index 0000000000000..2ff84bac6142f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s
+
+// Inclusive add scan of a 1-D vector with a 0-D initial value. The initial
+// value is not read; the last output slice (vector<1xi32>) is converted back
+// to vector<i32> with vector.extract + vector.broadcast.
+// CHECK-LABEL: func @scan1d_inc
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK:      %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK:      %[[E:.*]] = arith.addi %[[B]], %[[D]] : vector<1xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[G:.*]] = vector.extract %[[E]][0] : vector<1xi32>
+// CHECK:      %[[H:.*]] = vector.broadcast %[[G]] : i32 to vector<i32>
+// CHECK:      return %[[F]], %[[H]] : vector<2xi32>, vector<i32>
+func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi32>, vector<i32>) {
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2xi32>, vector<i32>
+  return %0#0, %0#1 : vector<2xi32>, vector<i32>
+}
+
+// Exclusive add scan of a 1-D vector: the 0-D initial value is broadcast to
+// vector<1xi32> to form the first output slice, and the second slice adds it
+// to the first input slice.
+// CHECK-LABEL: func @scan1d_exc
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK:      %[[C:.*]] = vector.broadcast %[[ARG1]] : vector<i32> to vector<1xi32>
+// CHECK:      %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[E:.*]] = arith.addi %[[C]], %[[B]] : vector<1xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[G:.*]] = vector.extract %[[E]][0] : vector<1xi32>
+// CHECK:      %[[H:.*]] = vector.broadcast %[[G]] : i32 to vector<i32>
+// CHECK:      return %[[F]], %[[H]] : vector<2xi32>, vector<i32>
+func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi32>, vector<i32>) {
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = false, reduction_dim = 0} :
+    vector<2xi32>, vector<i32>
+  return %0#0, %0#1 : vector<2xi32>, vector<i32>
+}
+
+// Inclusive mul scan along dim 0 of a 2-D vector: slices are rows of type
+// vector<1x3xi32>, the initial value is not read, and the last slice is
+// reshaped to vector<3xi32> with vector.shape_cast.
+// CHECK-LABEL: func @scan2d_mul_dim0
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<3xi32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x3xi32> to vector<1x3xi32>
+// CHECK:      %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x3xi32> into vector<2x3xi32>
+// CHECK:      %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x3xi32> to vector<1x3xi32>
+// CHECK:      %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x3xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x3xi32> into vector<2x3xi32>
+// CHECK:      %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x3xi32> to vector<3xi32>
+// CHECK:      return %[[F]], %[[G]] : vector<2x3xi32>, vector<3xi32>
+func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (vector<2x3xi32>, vector<3xi32>) {
+  %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2x3xi32>, vector<3xi32>
+  return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32>
+}
+
+// Inclusive mul scan along dim 1 of a 2-D vector: slices are columns of type
+// vector<2x1xi32>, accumulated left to right over the three columns.
+// CHECK-LABEL: func @scan2d_mul_dim1
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<2xi32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK:      %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK:      %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK:      %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<2x1xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK:      %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK:      %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<2x1xi32>
+// CHECK:      %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK:      %[[J:.*]] = vector.shape_cast %[[H]] : vector<2x1xi32> to vector<2xi32>
+// CHECK:      return %[[I]], %[[J]] : vector<2x3xi32>, vector<2xi32>
+func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (vector<2x3xi32>, vector<2xi32>) {
+  %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} :
+    vector<2x3xi32>, vector<2xi32>
+  return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32>
+}
+
+// Exclusive f32 mul scan along the middle dim of a 3-D vector: the initial
+// value is reshaped to vector<4x1x3xf32> with vector.shape_cast to form the
+// first output slice.
+// CHECK-LABEL: func @scan3d_mul_dim1
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x3xf32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK:      %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x3xf32> to vector<4x1x3xf32>
+// CHECK:      %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK:      %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x3xf32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK:      %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x3xf32> to vector<4x3xf32>
+// CHECK:      return %[[F]], %[[G]] : vector<4x2x3xf32>, vector<4x3xf32>
+func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) -> (vector<4x2x3xf32>, vector<4x3xf32>) {
+  %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = false, reduction_dim = 1} :
+    vector<4x2x3xf32>, vector<4x3xf32>
+  return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32>
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
new file mode 100644
index 0000000000000..f8b670172ce40
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -test-vector-scan-lowering -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Runs inclusive and exclusive add scans along both dimensions of a 3x2
+// vector and prints each scan result followed by its second result (the
+// last slice of the scan).
+func @entry() {
+  %f1 = arith.constant 1.0: f32
+  %f2 = arith.constant 2.0: f32
+  %f3 = arith.constant 3.0: f32
+  %f4 = arith.constant 4.0: f32
+  %f5 = arith.constant 5.0: f32
+  %f6 = arith.constant 6.0: f32
+
+  // Construct test vector.
+  %0 = vector.broadcast %f1 : f32 to vector<3x2xf32>
+  %1 = vector.insert %f2, %0[0, 1] : f32 into vector<3x2xf32>
+  %2 = vector.insert %f3, %1[1, 0] : f32 into vector<3x2xf32>
+  %3 = vector.insert %f4, %2[1, 1] : f32 into vector<3x2xf32>
+  %4 = vector.insert %f5, %3[2, 0] : f32 into vector<3x2xf32>
+  %x = vector.insert %f6, %4[2, 1] : f32 into vector<3x2xf32>
+  vector.print %x : vector<3x2xf32>
+  // CHECK:  ( ( 1, 2 ), ( 3, 4 ), ( 5, 6 ) )
+
+  // Initial values, all 6.0: %y for scans along dim 0, %z for dim 1.
+  %y = vector.broadcast %f6 : f32 to vector<2xf32>
+  %z = vector.broadcast %f6 : f32 to vector<3xf32>
+  // Inclusive scans (%a along dim 0, %b along dim 1) and exclusive scans
+  // (%c along dim 0, %d along dim 1).
+  %a:2 = vector.scan <add>, %x, %y {inclusive = true, reduction_dim = 0} :
+    vector<3x2xf32>, vector<2xf32>
+  %b:2 = vector.scan <add>, %x, %z {inclusive = true, reduction_dim = 1} :
+    vector<3x2xf32>, vector<3xf32>
+  %c:2 = vector.scan <add>, %x, %y {inclusive = false, reduction_dim = 0} :
+    vector<3x2xf32>, vector<2xf32>
+  %d:2 = vector.scan <add>, %x, %z {inclusive = false, reduction_dim = 1} :
+    vector<3x2xf32>, vector<3xf32>
+
+  // Expected output, in print order. The inclusive results do not depend on
+  // the initial value. The exclusive results start from 6 and their second
+  // result leaves out the last source slice.
+  // CHECK: ( ( 1, 2 ), ( 4, 6 ), ( 9, 12 ) )
+  // CHECK: ( 9, 12 )
+  // CHECK: ( ( 1, 3 ), ( 3, 7 ), ( 5, 11 ) )
+  // CHECK: ( 3, 7, 11 )
+  // CHECK: ( ( 6, 6 ), ( 7, 8 ), ( 10, 12 ) )
+  // CHECK: ( 10, 12 )
+  // CHECK: ( ( 6, 7 ), ( 6, 9 ), ( 6, 11 ) )
+  // CHECK: ( 7, 9, 11 )
+  vector.print %a#0 : vector<3x2xf32>
+  vector.print %a#1 : vector<2xf32>
+  vector.print %b#0 : vector<3x2xf32>
+  vector.print %b#1 : vector<3xf32>
+  vector.print %c#0 : vector<3x2xf32>
+  vector.print %c#1 : vector<2xf32>
+  vector.print %d#0 : vector<3x2xf32>
+  vector.print %d#1 : vector<3xf32>
+
+  return 
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 5e9e2fc1e3ae0..cd8ddf66270c1 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -627,6 +627,20 @@ struct TestFlattenVectorTransferPatterns
   }
 };
 
+/// Test pass that applies only the vector.scan lowering patterns
+/// (populateVectorScanLoweringPatterns) to each function.
+struct TestVectorScanLowering
+    : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> {
+  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that lower the scan op in the vector "
+           "dialect";
+  }
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet scanPatterns(ctx);
+    populateVectorScanLoweringPatterns(scanPatterns);
+    // The greedy driver's LogicalResult is intentionally discarded.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(scanPatterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -661,6 +675,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
 
   PassRegistration<TestFlattenVectorTransferPatterns>();
+
+  PassRegistration<TestVectorScanLowering>();
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list