[Mlir-commits] [mlir] 5279e11 - [mlir][linalg] Retire Linalg's Vectorization Pattern
Guray Ozen
llvmlistbot at llvm.org
Thu Sep 15 02:23:55 PDT 2022
Author: Guray Ozen
Date: 2022-09-15T11:23:46+02:00
New Revision: 5279e11f063db6a0cc87ccf9e0e1c7b1b31aa7cf
URL: https://github.com/llvm/llvm-project/commit/5279e11f063db6a0cc87ccf9e0e1c7b1b31aa7cf
DIFF: https://github.com/llvm/llvm-project/commit/5279e11f063db6a0cc87ccf9e0e1c7b1b31aa7cf.diff
LOG: [mlir][linalg] Retire Linalg's Vectorization Pattern
This revision retires the LinalgCodegenStrategy vectorization pattern. Please see the context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785.
This revision improves the transform dialect's VectorizeOp in the following ways:
- Adds LinalgDialect as a dependent dialect. When `transform.structured.vectorize` vectorizes `tensor.pad`, it generates `linalg.init_tensor`, so the Linalg dialect must be registered.
- Inserts CopyVectorizationPattern in order to vectorize `memref.copy`.
- Adds two attributes, `disable_multi_reduction_to_contract_patterns` and `disable_transfer_permutation_map_lowering_patterns`. They limit the power of vectorization and are currently intended for testing purposes only.
It also removes some of the `CHECK: vector.transfer_write` lines in the vectorization.mlir test. Those writes were redundant: at the end of the code there is another write to the same location, and the transform dialect no longer generates them.
Depends on D133684, which retires the LinalgCodegenStrategy vectorization pass.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D133699
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 46d70a6561b0a..79c0e6266c827 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -767,6 +767,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
Note that this transformation is invalidating the handles to any payload IR
operation that is contained inside the vectorization target.
+ `disable_multi_reduction_to_contract_patterns` and
+ `disable_transfer_permutation_map_lowering_patterns` limit the power of
+ vectorization. They are currently intended for testing purposes.
+
#### Return modes:
This operation produces `definiteFailure` if vectorization fails for any
@@ -776,7 +780,9 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];
let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding);
+ DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,
+ DefaultValuedAttr<BoolAttr, "false">:$disable_multi_reduction_to_contract_patterns,
+ DefaultValuedAttr<BoolAttr, "false">:$disable_transfer_permutation_map_lowering_patterns);
let results = (outs PDL_Operation:$transformed);
let assemblyFormat = "$target attr-dict";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 958681557702d..43185a208af63 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -926,31 +926,6 @@ struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
- /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
- LinalgVectorizationPattern(
- MLIRContext *context,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- LinalgVectorizationOptions options = LinalgVectorizationOptions(),
- PatternBenefit benefit = 1);
-
- /// Construct a pattern specifically applied to `opName`.
- LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context,
- LinalgVectorizationOptions options = LinalgVectorizationOptions(),
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- LogicalResult matchAndRewrite(LinalgOp linalgOp,
- PatternRewriter &rewriter) const override;
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgTransformationFilter filter;
-};
-
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
@@ -1335,18 +1310,6 @@ class VectorizationPatterns<> {
const LinalgTransformationFilter &f) {}
};
-template <typename OpTy, typename... OpTypes>
-class VectorizationPatterns<OpTy, OpTypes...> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgVectorizationOptions &options,
- const LinalgTransformationFilter &f) {
- patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
- patterns.getContext(), options, f);
- VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
- }
-};
-
template <typename... OpTypes>
class TilingPatterns;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 29b13e27de7ed..93b1274d0e884 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1166,6 +1166,22 @@ LogicalResult TileToForeachThreadOp::verify() {
// VectorizeOp
//===----------------------------------------------------------------------===//
+namespace {
+/// This is a helper only used to call vectorize via a pattern inside of
+/// VectorizeOp::applyToOne.
+struct VectorizationPattern : public RewritePattern {
+ explicit VectorizationPattern(MLIRContext *context)
+ : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+ if (!linalgOp)
+ return failure();
+ return vectorize(rewriter, linalgOp);
+ }
+};
+} // namespace
+
DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
@@ -1178,15 +1194,22 @@ transform::VectorizeOp::applyToOne(Operation *target,
MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
- patterns.add<LinalgVectorizationPattern>(ctx);
+ patterns.add<VectorizationPattern>(ctx);
+
+ if (!getDisableTransferPermutationMapLoweringPatterns())
+ vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+
+ if (!getDisableMultiReductionToContractPatterns())
+ vector::populateVectorReductionToContractPatterns(patterns);
- vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
- vector::populateVectorReductionToContractPatterns(patterns);
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(ctx,
/*benefit=*/2);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+
+ patterns.add<CopyVectorizationPattern>(ctx);
+
if (getVectorizePadding())
linalg::populatePadOpVectorizationPatterns(patterns);
@@ -1212,7 +1235,7 @@ class LinalgTransformDialectExtension
void init() {
declareDependentDialect<pdl::PDLDialect>();
-
+ declareDependentDialect<LinalgDialect>();
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<arith::ArithmeticDialect>();
declareGeneratedDialect<scf::SCFDialect>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 31136dda7a918..b00d9233526d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -590,25 +590,6 @@ LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
return success();
}
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
- MLIRContext *context, LinalgTransformationFilter f,
- LinalgVectorizationOptions options, PatternBenefit benefit)
- : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
- filter(std::move(f)) {}
-
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
- LinalgTransformationFilter f, PatternBenefit benefit)
- : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
- filter(f.addOpNameFilter(opName)) {}
-
-LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
- LinalgOp linalgOp, PatternRewriter &rewriter) const {
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
- return vectorize(rewriter, linalgOp);
-}
-
LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
memref::CopyOp copyOp, PatternRewriter &rewriter) const {
return vectorizeCopy(rewriter, copyOp);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 229530587fdc5..6ac2fdbbdc572 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
// -----
@@ -12,6 +12,16 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.dot"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: contraction_matvec
@@ -24,6 +34,16 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: contraction_matmul
@@ -35,6 +55,16 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: contraction_batch_matmul
@@ -47,6 +77,16 @@ func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+ }
+}
+
// -----
#matmul_trait = {
@@ -80,6 +120,16 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
#matmul_transpose_out_trait = {
@@ -113,6 +163,16 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -133,6 +193,16 @@ func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tenso
return %1 : tensor<128x12x32xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
#matmul_trait = {
@@ -166,6 +236,16 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @vectorization_test_2
@@ -179,6 +259,16 @@ func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_scalar_input
@@ -196,6 +286,16 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
@@ -213,6 +313,16 @@ func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomp
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_fill
@@ -223,6 +333,16 @@ func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_fill
@@ -234,6 +354,16 @@ func.func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_copy
@@ -244,6 +374,16 @@ func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_copy_scalar
@@ -257,6 +397,15 @@ func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
// -----
// CHECK-LABEL: func @test_vectorize_trailing_index
@@ -278,6 +427,16 @@ func.func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_inner_index
@@ -300,6 +459,16 @@ func.func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// CHECK-LABEL: func @generic_vectorize
@@ -378,6 +547,16 @@ func.func @generic_vectorize(%arg0: memref<4x256xf32>,
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @generic_vectorize_tensor
@@ -462,6 +641,16 @@ func.func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
@@ -499,6 +688,16 @@ func.func @generic_vectorize_broadcast_transpose(
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// Test different input maps.
@@ -535,6 +734,16 @@ func.func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @matmul_tensors
@@ -560,6 +769,16 @@ func.func @matmul_tensors(
return %0 : tensor<8x12xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @pad_static(
@@ -581,6 +800,17 @@ func.func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4
return %0 : tensor<2x3x4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @pad_static_source(
@@ -602,6 +832,18 @@ func.func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tenso
return %0 : tensor<2x6x4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+ }
+}
+
+
// -----
// CHECK-LABEL: func @pad_static_dynamic(
@@ -630,6 +872,18 @@ func.func @pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: in
return %0 : tensor<6x?x?x?xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+ }
+}
+
+
// -----
// CHECK-LABEL: func @pad_and_transfer_read
@@ -652,6 +906,17 @@ func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
return %1 : vector<7x9xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+ }
+}
+
// -----
func.func private @make_vector() -> vector<7x9xf32>
@@ -678,6 +943,17 @@ func.func @pad_and_transfer_write_static(
return %3 : tensor<5x6xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+ }
+}
+
+
// -----
func.func private @make_vector() -> vector<7x9xf32>
@@ -707,6 +983,17 @@ func.func @pad_and_transfer_write_dynamic_static(
return %3 : tensor<?x6xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+ }
+}
+
+
// -----
func.func private @make_vector() -> tensor<12x13xf32>
@@ -733,6 +1020,17 @@ func.func @pad_and_insert_slice_source(
return %r : tensor<12x13xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+ }
+}
+
+
// -----
func.func private @make_vector() -> tensor<12x13xf32>
@@ -753,6 +1051,16 @@ func.func @pad_and_insert_slice_dest(
return %r : tensor<1x12x13xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-LABEL: func @pad_tensor_non_const_pad_value
@@ -782,6 +1090,17 @@ func.func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x1
return %0 : tensor<12x13xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @sum_exp
@@ -809,6 +1128,17 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
return %0 : tensor<4x16xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
@@ -846,13 +1176,23 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output
return %0 : tensor<5x2xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @red_max_2d(
func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
- // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant -3.40282e+38 : f32
@@ -869,13 +1209,23 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+ }
+}
+
// -----
// CHECK-LABEL: func @red_min_2d(
func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
- // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -893,12 +1243,22 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-LABEL: func @red_mul_2d(
func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
- // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -916,12 +1276,22 @@ func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-LABEL: func @red_or_2d(
func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
- // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
// CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -939,12 +1309,22 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
return %red : tensor<4xi1>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-LABEL: func @red_and_2d(
func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
- // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
// CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -962,12 +1342,22 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
return %red : tensor<4xi1>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-LABEL: func @red_xor_2d(
func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
- // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
// CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -985,6 +1375,17 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
return %red : tensor<4xi1>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)>
@@ -1011,6 +1412,17 @@ func.func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
return %red : tensor<4x4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
@@ -1041,6 +1453,21 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>
return %red : tensor<4xf32>
}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %4 = get_closest_isolated_parent %3
+ %5 = transform.structured.vectorize %4
+ }
+}
+
// -----
// CHECK-LABEL: func @reduce_1d(
@@ -1054,8 +1481,6 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
%0 = linalg.init_tensor [] : tensor<f32>
- // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
- // CHECK-SAME: : vector<f32>, tensor<f32>
%1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
@@ -1063,7 +1488,7 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
- // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
+ // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
// CHECK-SAME: : vector<f32>, tensor<f32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>,
@@ -1079,6 +1504,16 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
return %2 : tensor<f32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
@@ -1103,6 +1538,16 @@ func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf
return %result : tensor<6x6x3x3xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1
+ }
+}
+
// -----
// Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
@@ -1134,3 +1579,13 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>,
// CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
// CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]]
// CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]]
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0
+ %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 576082a572e10..3949544cd1c10 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -225,9 +225,6 @@ static void applyPatterns(func::FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===--------------------------------------------------------------------===//
- patterns.add<LinalgVectorizationPattern>(
- ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
- .addOpFilter<MatmulOp, FillOp, GenericOp>());
patterns.add<CopyVectorizationPattern>(ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@@ -441,9 +438,6 @@ static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
auto *ctx = funcOp.getContext();
- patterns.add<LinalgVectorizationPattern>(
- ctx, LinalgTransformationFilter()
- .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
patterns.add<CopyVectorizationPattern>(ctx);
populatePadOpVectorizationPatterns(patterns);
populateConvolutionVectorizationPatterns(patterns);
More information about the Mlir-commits
mailing list