[Mlir-commits] [mlir] 80e0bf1 - Add vector.scan op
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 28 12:08:40 PST 2022
Author: harsh
Date: 2022-01-28T20:07:57Z
New Revision: 80e0bf1af11c0c7bd2a6261d83dcb18987cd7f11
URL: https://github.com/llvm/llvm-project/commit/80e0bf1af11c0c7bd2a6261d83dcb18987cd7f11
DIFF: https://github.com/llvm/llvm-project/commit/80e0bf1af11c0c7bd2a6261d83dcb18987cd7f11.diff
LOG: Add vector.scan op
This patch adds the vector.scan op which computes the
scan for a given n-d vector. It requires specifying the operator,
the initial value (typically the identity element of the operator)
and whether the scan is inclusive or exclusive.
TEST: Added tests in ops.mlir, invalid.mlir, vector-scan-transforms.mlir and Integration/Dialect/Vector/CPU/test-scan.mlir
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D117171
Added:
mlir/test/Dialect/Vector/vector-scan-transforms.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 1f501ac6b89ea..6a5e0d961b1dc 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2429,4 +2429,62 @@ def VectorScaleOp : Vector_Op<"vscale",
let verifier = ?;
}
+//===----------------------------------------------------------------------===//
+// VectorScanOp
+//===----------------------------------------------------------------------===//
+
+// vector.scan: inclusive/exclusive scan along one dimension of an n-D vector.
+//  * `dest` has the same type as `source`, and `accumulated_value` has the
+//    same type as `initial_value` (the two AllTypesMatch constraints).
+//  * `initial_value` uses AnyVectorOfAnyRank rather than AnyVector,
+//    presumably so that a 1-D source can take a 0-D initial value such as
+//    `vector<i32>` (as the 1-D lowering tests do) -- TODO confirm.
+//  * The rank/shape relation between `source`, `initial_value` and
+//    `reduction_dim` is checked by `verify(ScanOp)` in VectorOps.cpp
+//    (assuming Vector_Op's default `verifier` dispatches to it, as the
+//    `let verifier = ?;` override on VectorScaleOp suggests).
+def Vector_ScanOp :
+ Vector_Op<"scan", [NoSideEffect,
+ AllTypesMatch<["source", "dest"]>,
+ AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
+ Arguments<(ins Vector_CombiningKindAttr:$kind,
+ AnyVector:$source,
+ AnyVectorOfAnyRank:$initial_value,
+ I64Attr:$reduction_dim,
+ BoolAttr:$inclusive)>,
+ Results<(outs AnyVector:$dest,
+ AnyVectorOfAnyRank:$accumulated_value)> {
+ let summary = "Scan operation";
+ let description = [{
+ Performs an inclusive/exclusive scan on an n-D vector along a single
+ dimension returning an n-D result vector using the given
+ operation (add/mul/min/max for int/fp and and/or/xor for
+ int only) and a specified value for the initial value. The operator
+ returns the result of scan as well as the result of the last
+ reduction in the scan.
+
+ Example:
+
+ ```mlir
+ %1:2 = vector.scan <add>, %0, %acc {inclusive = false, reduction_dim = 1 : i64} :
+ vector<4x8x16x32xf32>, vector<4x16x32xf32>
+ ```
+ }];
+ // NOTE(review): this builder is declared without a body, and no
+ // `ScanOp::build` with this signature is added elsewhere in this patch.
+ // Confirm a definition exists before calling it; otherwise uses of it will
+ // fail to link.
+ let builders = [
+ OpBuilder<(ins "Value":$source, "Value":$initial_value,
+ "CombiningKind":$kind,
+ CArg<"int64_t", "0">:$reduction_dim,
+ CArg<"bool", "true">:$inclusive)>
+ ];
+ // Typed accessors for the operand/result vector types; the casts are safe
+ // because every operand and result is constrained to a vector type above.
+ let extraClassDeclaration = [{
+ static StringRef getKindAttrName() { return "kind"; }
+ static StringRef getReductionDimAttrName() { return "reduction_dim"; }
+ VectorType getSourceType() {
+ return source().getType().cast<VectorType>();
+ }
+ VectorType getDestType() {
+ return dest().getType().cast<VectorType>();
+ }
+ VectorType getAccumulatorType() {
+ return accumulated_value().getType().cast<VectorType>();
+ }
+ VectorType getInitialValueType() {
+ return initial_value().getType().cast<VectorType>();
+ }
+ }];
+ // Only the source and initial value types are spelled out; the result
+ // types follow from the AllTypesMatch constraints.
+ let assemblyFormat =
+ "$kind `,` $source `,` $initial_value attr-dict `:` "
+ "type($source) `,` type($initial_value) ";
+}
+
#endif // VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 563f67d817804..8f01b0523ef24 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -171,6 +171,9 @@ void populateVectorContractLoweringPatterns(
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
+/// Collect patterns that lower `vector.scan` into `arith` ops combined with
+/// `vector.extract_strided_slice` / `vector.insert_strided_slice`, using
+/// `vector.shape_cast`, `vector.extract` and `vector.broadcast` to reshape
+/// the initial and accumulated values.
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Vector.transfer patterns.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index e1289d5e7fad7..9493187461e49 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -4263,6 +4263,44 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<CreateMaskFolder>(context);
}
+//===----------------------------------------------------------------------===//
+// ScanOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a vector.scan op:
+///   * `reduction_dim` must index a dimension of the source vector, i.e. lie
+///     in [0, rank(source)). Negative values are rejected explicitly: the
+///     scan lowering uses `reduction_dim` directly as an index into shape
+///     vectors, so an unchecked negative value would index out of bounds.
+///   * rank(initial_value) must equal rank(source) - 1.
+///   * shape(initial_value) must equal shape(source) with the reduction
+///     dimension removed.
+static LogicalResult verify(ScanOp op) {
+  VectorType srcType = op.getSourceType();
+  VectorType initialType = op.getInitialValueType();
+
+  // Check 0 <= reduction dimension < rank.
+  int64_t srcRank = srcType.getRank();
+  int64_t reductionDim = op.reduction_dim();
+  if (reductionDim < 0)
+    return op.emitOpError("reduction dimension ")
+           << reductionDim << " has to be non-negative";
+  if (reductionDim >= srcRank)
+    return op.emitOpError("reduction dimension ")
+           << reductionDim << " has to be less than " << srcRank;
+
+  // Check that rank(initial_value) = rank(src) - 1.
+  int64_t initialValueRank = initialType.getRank();
+  if (initialValueRank != srcRank - 1)
+    return op.emitOpError("initial value rank ")
+           << initialValueRank << " has to be equal to " << srcRank - 1;
+
+  // Check shapes of initial value and src: the initial value must have the
+  // source shape with the reduction dimension dropped. The ranks are known to
+  // agree at this point, so a plain element-wise comparison is sufficient.
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  SmallVector<int64_t> expectedShape;
+  for (int64_t i = 0; i < srcRank; ++i)
+    if (i != reductionDim)
+      expectedShape.push_back(srcShape[i]);
+  if (initialType.getShape() != ArrayRef<int64_t>(expectedShape))
+    return op.emitOpError("incompatible input/initial value shapes");
+
+  return success();
+}
+
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
RewritePatternSet &patterns) {
patterns
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3ead467bc5181..83553b642a3c2 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2348,6 +2348,204 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};
+namespace {
+
+/// Returns true when `kind` can be applied to elements of the given class.
+/// Integer-only kinds (min/max signed/unsigned, and, or, xor) require
+/// `isInt`, float-only kinds (minf, maxf) require `!isInt`, and add/mul have
+/// both an integer and a float flavor so they are always accepted.
+static bool isValidKind(bool isInt, vector::CombiningKind kind) {
+  using vector::CombiningKind;
+  switch (kind) {
+  case CombiningKind::ADD:
+  case CombiningKind::MUL:
+    return true;
+  case CombiningKind::MINF:
+  case CombiningKind::MAXF:
+    return !isInt;
+  case CombiningKind::MINUI:
+  case CombiningKind::MINSI:
+  case CombiningKind::MAXUI:
+  case CombiningKind::MAXSI:
+  case CombiningKind::AND:
+  case CombiningKind::OR:
+  case CombiningKind::XOR:
+    return isInt;
+  }
+  // Unknown kinds are never valid.
+  return false;
+}
+
+/// Emits the arith op that combines `x` and `y` with the operator selected
+/// by `kind` and returns its result.
+///
+/// For add/mul the element type of `x` picks the flavor: integer or index
+/// elements use addi/muli, anything else uses addf/mulf. Every other kind
+/// always emits its single op (minui/minsi/maxui/maxsi/andi/ori/xori for the
+/// integer kinds, minf/maxf for the float kinds) regardless of the element
+/// type, so callers must check compatibility first (see isValidKind).
+/// Returns a null Value for a kind not handled by the switch.
+static Value genOperator(Location loc, Value x, Value y,
+                         vector::CombiningKind kind,
+                         PatternRewriter &rewriter) {
+  using vector::CombiningKind;
+
+  auto elementType = x.getType().cast<VectorType>().getElementType();
+  const bool useIntOps = elementType.isIntOrIndex();
+
+  switch (kind) {
+  case CombiningKind::ADD:
+    if (useIntOps)
+      return rewriter.create<arith::AddIOp>(loc, x, y);
+    return rewriter.create<arith::AddFOp>(loc, x, y);
+  case CombiningKind::MUL:
+    if (useIntOps)
+      return rewriter.create<arith::MulIOp>(loc, x, y);
+    return rewriter.create<arith::MulFOp>(loc, x, y);
+  case CombiningKind::MINUI:
+    return rewriter.create<arith::MinUIOp>(loc, x, y);
+  case CombiningKind::MINSI:
+    return rewriter.create<arith::MinSIOp>(loc, x, y);
+  case CombiningKind::MAXUI:
+    return rewriter.create<arith::MaxUIOp>(loc, x, y);
+  case CombiningKind::MAXSI:
+    return rewriter.create<arith::MaxSIOp>(loc, x, y);
+  case CombiningKind::AND:
+    return rewriter.create<arith::AndIOp>(loc, x, y);
+  case CombiningKind::OR:
+    return rewriter.create<arith::OrIOp>(loc, x, y);
+  case CombiningKind::XOR:
+    return rewriter.create<arith::XOrIOp>(loc, x, y);
+  case CombiningKind::MINF:
+    return rewriter.create<arith::MinFOp>(loc, x, y);
+  case CombiningKind::MAXF:
+    return rewriter.create<arith::MaxFOp>(loc, x, y);
+  }
+  // Only reachable for a kind missing from the switch above.
+  return Value();
+}
+
+/// Convert vector.scan op into arith ops and
+/// vector.insert_strided_slice/extract_strided_slice.
+///
+/// The source is processed one slice at a time along the reduction dimension
+/// (each slice has extent 1 in that dimension), with `op` being `kind`:
+///   * inclusive: out[0] = in[0], out[i] = out[i-1] op in[i]; the initial
+///     value is not read.
+///   * exclusive: out[0] = initial_value, out[i] = out[i-1] op in[i-1].
+/// The second result is the last computed slice, reshaped to the type of the
+/// initial value (via vector.extract + vector.broadcast when that type is
+/// 0-D, since vector.shape_cast cannot handle 0-D vectors).
+///
+/// Fails to match, leaving the op untouched, when `kind` is not valid for the
+/// element type (e.g. an integer-only kind on a float vector).
+///
+/// Ex:
+/// ```
+/// %0:2 = vector.scan <add>, %arg0, %arg1
+///          {inclusive = true, reduction_dim = 1} :
+///          vector<2x3xi32>, vector<2xi32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %cst = arith.constant dense<0> : vector<2x3xi32>
+/// %0 = vector.extract_strided_slice %arg0
+///        {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
+///        : vector<2x3xi32> to vector<2x1xi32>
+/// %1 = vector.insert_strided_slice %0, %cst
+///        {offsets = [0, 0], strides = [1, 1]}
+///        : vector<2x1xi32> into vector<2x3xi32>
+/// %2 = vector.extract_strided_slice %arg0
+///        {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
+///        : vector<2x3xi32> to vector<2x1xi32>
+/// %3 = arith.addi %0, %2 : vector<2x1xi32>
+/// %4 = vector.insert_strided_slice %3, %1
+///        {offsets = [0, 1], strides = [1, 1]}
+///        : vector<2x1xi32> into vector<2x3xi32>
+/// %5 = vector.extract_strided_slice %arg0
+///        {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+///        : vector<2x3xi32> to vector<2x1xi32>
+/// %6 = arith.addi %3, %5 : vector<2x1xi32>
+/// %7 = vector.insert_strided_slice %6, %4
+///        {offsets = [0, 2], strides = [1, 1]}
+///        : vector<2x1xi32> into vector<2x3xi32>
+/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
+/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
+/// ```
+struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
+  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = scanOp.getLoc();
+    VectorType destType = scanOp.getDestType();
+    ArrayRef<int64_t> destShape = destType.getShape();
+    auto elType = destType.getElementType();
+    bool isInt = elType.isIntOrIndex();
+    // Bail out before creating any IR if the kind does not fit the element
+    // type.
+    if (!isValidKind(isInt, scanOp.kind()))
+      return failure();
+
+    // Zero-filled container for the full result; every slice along the
+    // reduction dimension is overwritten by the loop below.
+    VectorType resType = VectorType::get(destShape, elType);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+    int64_t reductionDim = scanOp.reduction_dim();
+    bool inclusive = scanOp.inclusive();
+    int64_t destRank = destType.getRank();
+    VectorType initialValueType = scanOp.getInitialValueType();
+    int64_t initialValueRank = initialValueType.getRank();
+
+    // Every step works on a slice of the destination shape with extent 1 in
+    // the reduction dimension; that shape doubles as the slice `sizes`.
+    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
+    reductionShape[reductionDim] = 1;
+    VectorType reductionType = VectorType::get(reductionShape, elType);
+    SmallVector<int64_t> offsets(destRank, 0);
+    SmallVector<int64_t> strides(destRank, 1);
+    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(reductionShape);
+    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
+
+    Value lastOutput, lastInput;
+    for (int64_t i = 0, e = destShape[reductionDim]; i < e; ++i) {
+      offsets[reductionDim] = i;
+      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
+      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
+          scanStrides);
+      Value output;
+      if (i == 0) {
+        if (inclusive) {
+          output = input;
+        } else if (initialValueRank == 0) {
+          // ShapeCastOp cannot handle 0-D vectors
+          output = rewriter.create<vector::BroadcastOp>(
+              loc, input.getType(), scanOp.initial_value());
+        } else {
+          output = rewriter.create<vector::ShapeCastOp>(
+              loc, input.getType(), scanOp.initial_value());
+        }
+      } else {
+        // Inclusive combines with the current slice, exclusive with the
+        // previous one.
+        Value y = inclusive ? input : lastInput;
+        output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
+        assert(output != nullptr);
+      }
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, output, result, offsets, strides);
+      lastOutput = output;
+      lastInput = input;
+    }
+
+    // The accumulated value is the last output slice, reshaped to the type
+    // of the initial value.
+    Value reduction;
+    if (initialValueRank == 0) {
+      // ShapeCastOp cannot produce a 0-D vector: extract the scalar instead
+      // and broadcast it.
+      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
+      reduction =
+          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
+    } else {
+      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
+                                                       lastOutput);
+    }
+
+    rewriter.replaceOp(scanOp, {result, reduction});
+    return success();
+  }
+};
+
+} // namespace
+
void mlir::vector::populateVectorMaskMaterializationPatterns(
RewritePatternSet &patterns, bool indexOptimizations) {
patterns.add<VectorCreateMaskOpConversion,
@@ -2421,3 +2619,8 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
.add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
patterns.getContext());
}
+
+/// Adds the single vector.scan lowering pattern (ScanToArithOps) to
+/// `patterns`, using the pattern set's own context.
+void mlir::vector::populateVectorScanLoweringPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<ScanToArithOps>(context);
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index a384d42ef6112..cd795b7bab083 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1480,3 +1480,29 @@ func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
%0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
}
+// -----
+
+// reduction_dim must be smaller than the rank of the source vector.
+func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
+  // expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
+    vector<2x3xi32>, vector<3xi32>
+  return %0#1 : vector<3xi32>
+}
+
+// -----
+
+// The initial value must have exactly one dimension fewer than the source.
+func @scan_ival_rank_constraint(%arg0: vector<2x3xi32>, %arg1: vector<1x3xi32>) -> vector<1x3xi32> {
+  // expected-error@+1 {{initial value rank 2 has to be equal to 1}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2x3xi32>, vector<1x3xi32>
+  return %0#1 : vector<1x3xi32>
+}
+
+// -----
+
+// The initial value shape must be the source shape without the reduction
+// dimension (here vector<3xi32>, not vector<5xi32>).
+func @scan_incompatible_shapes(%arg0: vector<2x3xi32>, %arg1: vector<5xi32>) -> vector<2x3xi32> {
+  // expected-error@+1 {{incompatible input/initial value shapes}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    vector<2x3xi32>, vector<5xi32>
+  return %0#0 : vector<2x3xi32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a77d7ba2ebf54..a839476602141 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -717,3 +717,11 @@ func @get_vector_scale() -> index {
%0 = vector.vscale
return %0 : index
}
+
+// Exercises vector.scan with a 4-D source scanned along dimension 1 and a
+// matching 3-D initial value; only the function label is matched.
+// CHECK-LABEL: @vector_scan
+func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
+ %1 = arith.constant dense<0.0> : vector<4x16x32xf32>
+ %2:2 = vector.scan <add>, %0, %1 {reduction_dim = 1 : i64, inclusive = true} :
+ vector<4x8x16x32xf32>, vector<4x16x32xf32>
+ return %2#0 : vector<4x8x16x32xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
new file mode 100644
index 0000000000000..2ff84bac6142f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s
+
+// Inclusive 1-D add scan with a 0-D initial value. The initial value (ARG1)
+// is not referenced by the lowered code, and the accumulated value is
+// rebuilt from the last slice with vector.extract + vector.broadcast because
+// vector.shape_cast cannot produce a 0-D vector.
+// CHECK-LABEL: func @scan1d_inc
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
+// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK: %[[E:.*]] = arith.addi %[[B]], %[[D]] : vector<1xi32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK: %[[G:.*]] = vector.extract %[[E]][0] : vector<1xi32>
+// CHECK: %[[H:.*]] = vector.broadcast %[[G]] : i32 to vector<i32>
+// CHECK: return %[[F]], %[[H]] : vector<2xi32>, vector<i32>
+func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi32>, vector<i32>) {
+ %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+ vector<2xi32>, vector<i32>
+ return %0#0, %0#1 : vector<2xi32>, vector<i32>
+}
+
+// Exclusive 1-D add scan with a 0-D initial value: the initial value is
+// broadcast to vector<1xi32> and becomes the first output slice, and the
+// second slice is that value plus the first source slice.
+// CHECK-LABEL: func @scan1d_exc
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
+// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK: %[[C:.*]] = vector.broadcast %[[ARG1]] : vector<i32> to vector<1xi32>
+// CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK: %[[E:.*]] = arith.addi %[[C]], %[[B]] : vector<1xi32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK: %[[G:.*]] = vector.extract %[[E]][0] : vector<1xi32>
+// CHECK: %[[H:.*]] = vector.broadcast %[[G]] : i32 to vector<i32>
+// CHECK: return %[[F]], %[[H]] : vector<2xi32>, vector<i32>
+func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi32>, vector<i32>) {
+ %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = false, reduction_dim = 0} :
+ vector<2xi32>, vector<i32>
+ return %0#0, %0#1 : vector<2xi32>, vector<i32>
+}
+
+// Inclusive 2-D integer mul scan along dimension 0: each step extracts a
+// 1x3 row slice, and the accumulated value is a shape_cast of the last one.
+// CHECK-LABEL: func @scan2d_mul_dim0
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<3xi32>
+// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x3xi32> to vector<1x3xi32>
+// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x3xi32> into vector<2x3xi32>
+// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x3xi32> to vector<1x3xi32>
+// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x3xi32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x3xi32> into vector<2x3xi32>
+// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x3xi32> to vector<3xi32>
+// CHECK: return %[[F]], %[[G]] : vector<2x3xi32>, vector<3xi32>
+func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (vector<2x3xi32>, vector<3xi32>) {
+ %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+ vector<2x3xi32>, vector<3xi32>
+ return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32>
+}
+
+// Inclusive 2-D integer mul scan along dimension 1: three 2x1 column slices
+// are chained with arith.muli, and the last product becomes the accumulated
+// value through a shape_cast.
+// CHECK-LABEL: func @scan2d_mul_dim1
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<2xi32>
+// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<2x1xi32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK: %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK: %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<2x1xi32>
+// CHECK: %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK: %[[J:.*]] = vector.shape_cast %[[H]] : vector<2x1xi32> to vector<2xi32>
+// CHECK: return %[[I]], %[[J]] : vector<2x3xi32>, vector<2xi32>
+func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (vector<2x3xi32>, vector<2xi32>) {
+ %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} :
+ vector<2x3xi32>, vector<2xi32>
+ return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32>
+}
+
+// Exclusive 3-D float mul scan along the middle dimension: the 4x3 initial
+// value is shape_cast to 4x1x3 to form the first output slice, and the
+// second slice multiplies it by the first source slice with arith.mulf.
+// CHECK-LABEL: func @scan3d_mul_dim1
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32>
+// CHECK: %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x3xf32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x3xf32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x3xf32> to vector<4x3xf32>
+// CHECK: return %[[F]], %[[G]] : vector<4x2x3xf32>, vector<4x3xf32>
+func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) -> (vector<4x2x3xf32>, vector<4x3xf32>) {
+ %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = false, reduction_dim = 1} :
+ vector<4x2x3xf32>, vector<4x3xf32>
+ return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
new file mode 100644
index 0000000000000..f8b670172ce40
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -test-vector-scan-lowering -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Builds the 3x2 vector ( ( 1, 2 ), ( 3, 4 ), ( 5, 6 ) ), runs inclusive and
+// exclusive add scans along both dimensions with an all-6 initial value, and
+// prints both results of every scan.
+func @entry() {
+ %f1 = arith.constant 1.0: f32
+ %f2 = arith.constant 2.0: f32
+ %f3 = arith.constant 3.0: f32
+ %f4 = arith.constant 4.0: f32
+ %f5 = arith.constant 5.0: f32
+ %f6 = arith.constant 6.0: f32
+
+ // Construct test vector.
+ %0 = vector.broadcast %f1 : f32 to vector<3x2xf32>
+ %1 = vector.insert %f2, %0[0, 1] : f32 into vector<3x2xf32>
+ %2 = vector.insert %f3, %1[1, 0] : f32 into vector<3x2xf32>
+ %3 = vector.insert %f4, %2[1, 1] : f32 into vector<3x2xf32>
+ %4 = vector.insert %f5, %3[2, 0] : f32 into vector<3x2xf32>
+ %x = vector.insert %f6, %4[2, 1] : f32 into vector<3x2xf32>
+ vector.print %x : vector<3x2xf32>
+ // CHECK: ( ( 1, 2 ), ( 3, 4 ), ( 5, 6 ) )
+
+ // Initial values (all 6s): %y for scans along dim 0, %z along dim 1.
+ %y = vector.broadcast %f6 : f32 to vector<2xf32>
+ %z = vector.broadcast %f6 : f32 to vector<3xf32>
+ // Scan: %a/%b are inclusive, %c/%d exclusive, each along dims 0 and 1.
+ %a:2 = vector.scan <add>, %x, %y {inclusive = true, reduction_dim = 0} :
+ vector<3x2xf32>, vector<2xf32>
+ %b:2 = vector.scan <add>, %x, %z {inclusive = true, reduction_dim = 1} :
+ vector<3x2xf32>, vector<3xf32>
+ %c:2 = vector.scan <add>, %x, %y {inclusive = false, reduction_dim = 0} :
+ vector<3x2xf32>, vector<2xf32>
+ %d:2 = vector.scan <add>, %x, %z {inclusive = false, reduction_dim = 1} :
+ vector<3x2xf32>, vector<3xf32>
+
+ // Inclusive scans: the first output slice equals the first source slice,
+ // so the initial value of 6 does not contribute.
+ // CHECK: ( ( 1, 2 ), ( 4, 6 ), ( 9, 12 ) )
+ // CHECK: ( 9, 12 )
+ // CHECK: ( ( 1, 3 ), ( 3, 7 ), ( 5, 11 ) )
+ // CHECK: ( 3, 7, 11 )
+ // Exclusive scans: the first output slice is the initial value, each later
+ // slice adds the previous source slice, and the accumulated value is the
+ // last output slice (so it excludes the final source slice).
+ // CHECK: ( ( 6, 6 ), ( 7, 8 ), ( 10, 12 ) )
+ // CHECK: ( 10, 12 )
+ // CHECK: ( ( 6, 7 ), ( 6, 9 ), ( 6, 11 ) )
+ // CHECK: ( 7, 9, 11 )
+ vector.print %a#0 : vector<3x2xf32>
+ vector.print %a#1 : vector<2xf32>
+ vector.print %b#0 : vector<3x2xf32>
+ vector.print %b#1 : vector<3xf32>
+ vector.print %c#0 : vector<3x2xf32>
+ vector.print %c#1 : vector<2xf32>
+ vector.print %d#0 : vector<3x2xf32>
+ vector.print %d#1 : vector<3xf32>
+
+ return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 5e9e2fc1e3ae0..cd8ddf66270c1 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -627,6 +627,20 @@ struct TestFlattenVectorTransferPatterns
}
};
+/// Test pass (`-test-vector-scan-lowering`) that greedily applies the
+/// vector.scan lowering patterns to every function it runs on.
+struct TestVectorScanLowering
+    : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> {
+  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that lower the scan op in the vector "
+           "dialect";
+  }
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet scanPatterns(ctx);
+    populateVectorScanLoweringPatterns(scanPatterns);
+    // The driver's result is deliberately ignored by this test pass.
+    (void)applyPatternsAndFoldGreedily(getOperation(),
+                                       std::move(scanPatterns));
+  }
+};
+
} // namespace
namespace mlir {
@@ -661,6 +675,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
PassRegistration<TestFlattenVectorTransferPatterns>();
+
+ PassRegistration<TestVectorScanLowering>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list