[Mlir-commits] [mlir] c624027 - [mlir][linalg][TransformOps] Connect hoistRedundantVectorTransfers
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Feb 20 01:50:34 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-20T01:50:29-08:00
New Revision: c62402763302f929736281ca7a4344b8ae809ead
URL: https://github.com/llvm/llvm-project/commit/c62402763302f929736281ca7a4344b8ae809ead
DIFF: https://github.com/llvm/llvm-project/commit/c62402763302f929736281ca7a4344b8ae809ead.diff
LOG: [mlir][linalg][TransformOps] Connect hoistRedundantVectorTransfers
Connect the hoistRedundantVectorTransfers functionality to the transform
dialect.
Authored-by: Quentin Colombet <quentin.colombet at gmail.com>
Differential Revision: https://reviews.llvm.org/D144260
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 5473821e2927..755d7bfc0763 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6dc616bf9d95..11f3b3c634fd 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1706,4 +1706,43 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
}];
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorTransfersOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantVectorTransfersOp :
+ Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
+ enclosing scf::ForOp iteratively, if the following conditions are true:
+ 1. The 2 ops access the same memref with the same indices.
+ 2. All operands are invariant under the enclosing scf::ForOp.
+ 3. No uses of the memref either dominate the transfer_read or are
+ dominated by the transfer_write (i.e. no aliasing between the write and
+ the read across the loop)
+
+ #### Return modes:
+
+ The operation always succeeds and returns a handle to the transformed
+ function op.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>,
+ ];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::func::FuncOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index 7e72ebf0bd45..eb97c6e168e5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
+ MLIRFuncDialect
MLIRIR
MLIRLinalgDialect
MLIRLinalgTransforms
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7fd290864b3b..59ba8f84559e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
@@ -3058,6 +3059,19 @@ SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorTransfersOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorTransfersOp::applyToOne(
+ func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ linalg::hoistRedundantVectorTransfers(target);
+ linalg::hoistRedundantVectorTransfersOnTensor(target);
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index eac8ddcedf55..8830a4f42721 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -74,6 +74,14 @@ func.func @hoist_vector_transfer_pairs(
return
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
@@ -155,6 +163,14 @@ func.func @hoist_vector_transfer_pairs_disjoint(
return
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
@@ -236,6 +252,14 @@ func.func @hoist_vector_transfer_pairs_tensor(
tensor<?x?xf32>, tensor<?x?xf32>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
@@ -323,6 +347,14 @@ func.func @hoist_vector_transfer_pairs_disjoint_tensor(
return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
@@ -432,6 +464,14 @@ func.func @hoist_vector_transfer_pairs_tensor_and_slices(
return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
@@ -469,6 +509,14 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
return %1 : tensor<?x?xf32>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
@@ -505,3 +553,11 @@ func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi3
}
return
}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 4640a2c10d22..03aefc8d7117 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -4,7 +4,6 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgDecomposeOps.cpp
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
- TestLinalgHoisting.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp
deleted file mode 100644
index 40e29ff1a304..000000000000
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-//===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements logic for testing Linalg hoisting functions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-namespace {
-struct TestLinalgHoisting
- : public PassWrapper<TestLinalgHoisting, OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgHoisting)
-
- TestLinalgHoisting() = default;
- TestLinalgHoisting(const TestLinalgHoisting &pass) : PassWrapper(pass) {}
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect>();
- }
- StringRef getArgument() const final { return "test-linalg-hoisting"; }
- StringRef getDescription() const final {
- return "Test Linalg hoisting functions.";
- }
-
- void runOnOperation() override;
-
- Option<bool> testHoistRedundantTransfers{
- *this, "test-hoist-redundant-transfers",
- llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
- llvm::cl::init(false)};
-};
-} // namespace
-
-void TestLinalgHoisting::runOnOperation() {
- if (testHoistRedundantTransfers) {
- hoistRedundantVectorTransfers(getOperation());
- hoistRedundantVectorTransfersOnTensor(getOperation());
- return;
- }
-}
-
-namespace mlir {
-namespace test {
-void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); }
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 5a07004baf38..b56c883da587 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -95,7 +95,6 @@ void registerTestLastModifiedPass();
void registerTestLinalgDecomposeOps();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
-void registerTestLinalgHoisting();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
@@ -205,7 +204,6 @@ void registerTestPasses() {
mlir::test::registerTestLinalgDecomposeOps();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
- mlir::test::registerTestLinalgHoisting();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessPass();
mlir::test::registerTestLoopFusion();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c89dab4bf419..044cff23f661 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8350,6 +8350,7 @@ cc_library(
":AsmParser",
":ControlFlowDialect",
":DialectUtils",
+ ":FuncDialect",
":GPUDialect",
":IR",
":LinalgDialect",
More information about the Mlir-commits
mailing list