[Mlir-commits] [mlir] c624027 - [mlir][linalg][TransformOps] Connect hoistRedundantVectorTransfers

Nicolas Vasilache llvmlistbot at llvm.org
Mon Feb 20 01:50:34 PST 2023


Author: Nicolas Vasilache
Date: 2023-02-20T01:50:29-08:00
New Revision: c62402763302f929736281ca7a4344b8ae809ead

URL: https://github.com/llvm/llvm-project/commit/c62402763302f929736281ca7a4344b8ae809ead
DIFF: https://github.com/llvm/llvm-project/commit/c62402763302f929736281ca7a4344b8ae809ead.diff

LOG: [mlir][linalg][TransformOps] Connect hoistRedundantVectorTransfers

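Connect the hoistRedundantVectorTransfers functionality to the transform
dialect through the new `transform.structured.hoist_redundant_vector_transfers`
op, and migrate the hoisting tests from the now-removed `-test-linalg-hoisting`
test pass to the transform dialect interpreter.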

Authored-by: Quentin Colombet <quentin.colombet at gmail.com>

Differential Revision: https://reviews.llvm.org/D144260

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 5473821e2927..755d7bfc0763 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 #define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6dc616bf9d95..11f3b3c634fd 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1706,4 +1706,51 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorTransfersOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantVectorTransfersOp :
+  Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Hoist vector.transfer_read / vector.transfer_write pairs out of the
+    immediately enclosing `scf.for` iteratively, if the following conditions
+    are true:
+       1. The two ops access the same memref with the same indices.
+       2. All operands are invariant under the enclosing `scf.for`.
+       3. No uses of the memref either dominate the transfer_read or are
+          dominated by the transfer_write (i.e. no aliasing between the write
+          and the read across the loop).
+
+    The tensor-based variant of this hoisting is applied as well, so transfer
+    pairs operating on tensor values are handled too.
+
+    #### Return modes:
+
+    The target handle is expected to point to `func.func` ops only; any other
+    payload op is reported as a silenceable failure. Otherwise, the operation
+    always succeeds and returns a handle to the transformed function op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+  let builders = [
+    // The result handle has the same type as the target handle.
+    OpBuilder<(ins "Value":$target), [{
+      build($_builder, $_state, target.getType(), target);
+    }]>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::func::FuncOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index 7e72ebf0bd45..eb97c6e168e5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
   LINK_LIBS PUBLIC
   MLIRAffineDialect
   MLIRArithDialect
+  MLIRFuncDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLinalgTransforms

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7fd290864b3b..59ba8f84559e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -3058,6 +3059,23 @@ SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
   return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorTransfersOp
+//===----------------------------------------------------------------------===//
+
+/// Runs the memref-based and then the tensor-based redundant vector transfer
+/// hoisting on `target`, and forwards the same (modified in place) function
+/// op to the result handle. This never reports a failure by itself.
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorTransfersOp::applyToOne(
+    func::FuncOp target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  linalg::hoistRedundantVectorTransfers(target);
+  linalg::hoistRedundantVectorTransfersOnTensor(target);
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index eac8ddcedf55..8830a4f42721 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
 // CHECK-LABEL: func @hoist_vector_transfer_pairs(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -74,6 +74,14 @@ func.func @hoist_vector_transfer_pairs(
   return
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
@@ -155,6 +163,14 @@ func.func @hoist_vector_transfer_pairs_disjoint(
   return
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
@@ -236,6 +252,14 @@ func.func @hoist_vector_transfer_pairs_tensor(
         tensor<?x?xf32>, tensor<?x?xf32>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
@@ -323,6 +347,14 @@ func.func @hoist_vector_transfer_pairs_disjoint_tensor(
   return %0#0,  %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
@@ -432,6 +464,14 @@ func.func @hoist_vector_transfer_pairs_tensor_and_slices(
   return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
@@ -469,6 +509,14 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
   return %1 : tensor<?x?xf32>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
@@ -505,3 +553,11 @@ func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi3
   }
   return
 }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 4640a2c10d22..03aefc8d7117 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -4,7 +4,6 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
-  TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp
deleted file mode 100644
index 40e29ff1a304..000000000000
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-//===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements logic for testing Linalg hoisting functions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-namespace {
-struct TestLinalgHoisting
-    : public PassWrapper<TestLinalgHoisting, OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgHoisting)
-
-  TestLinalgHoisting() = default;
-  TestLinalgHoisting(const TestLinalgHoisting &pass) : PassWrapper(pass) {}
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect>();
-  }
-  StringRef getArgument() const final { return "test-linalg-hoisting"; }
-  StringRef getDescription() const final {
-    return "Test Linalg hoisting functions.";
-  }
-
-  void runOnOperation() override;
-
-  Option<bool> testHoistRedundantTransfers{
-      *this, "test-hoist-redundant-transfers",
-      llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
-      llvm::cl::init(false)};
-};
-} // namespace
-
-void TestLinalgHoisting::runOnOperation() {
-  if (testHoistRedundantTransfers) {
-    hoistRedundantVectorTransfers(getOperation());
-    hoistRedundantVectorTransfersOnTensor(getOperation());
-    return;
-  }
-}
-
-namespace mlir {
-namespace test {
-void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); }
-} // namespace test
-} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 5a07004baf38..b56c883da587 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -95,7 +95,6 @@ void registerTestLastModifiedPass();
 void registerTestLinalgDecomposeOps();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
-void registerTestLinalgHoisting();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -205,7 +204,6 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgDecomposeOps();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
-  mlir::test::registerTestLinalgHoisting();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c89dab4bf419..044cff23f661 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8350,6 +8350,7 @@ cc_library(
         ":AsmParser",
         ":ControlFlowDialect",
         ":DialectUtils",
+        ":FuncDialect",
         ":GPUDialect",
         ":IR",
         ":LinalgDialect",


        


More information about the Mlir-commits mailing list