[Mlir-commits] [mlir] 7d0426d - [mlir] Move ComposeSubView+ExpandOps from Standard to MemRef

River Riddle llvmlistbot at llvm.org
Wed Jan 26 23:19:53 PST 2022


Author: River Riddle
Date: 2022-01-26T23:11:02-08:00
New Revision: 7d0426dd95440e087bf3026fc8cffa2d1f2eb5a8

URL: https://github.com/llvm/llvm-project/commit/7d0426dd95440e087bf3026fc8cffa2d1f2eb5a8
DIFF: https://github.com/llvm/llvm-project/commit/7d0426dd95440e087bf3026fc8cffa2d1f2eb5a8.diff

LOG: [mlir] Move ComposeSubView+ExpandOps from Standard to MemRef

These transformations already operate on memref operations, which were
moved out of the standard dialect as part of splitting it up. Now that the
operations have moved, it's time for these transformations to move as well.

Differential Revision: https://reviews.llvm.org/D118285

Added: 
    mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h
    mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
    mlir/test/Dialect/MemRef/expand-ops.mlir
    mlir/test/lib/Dialect/MemRef/CMakeLists.txt
    mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/PassDetail.h
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/test/lib/Dialect/StandardOps/CMakeLists.txt
    mlir/test/mlir-cpu-runner/memref-reshape.mlir
    mlir/tools/mlir-opt/CMakeLists.txt

Removed: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
    mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
    mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Standard/expand-ops.mlir
    mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h
similarity index 71%
rename from mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
rename to mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h
index 7a5ae3e8417b7..20aa1c02db178 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h
@@ -1,4 +1,4 @@
-//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===//
+//===- ComposeSubView.h - Combining composed memref ops ---------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,19 +10,20 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
-#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_
 
 namespace mlir {
-
-// Forward declarations.
 class MLIRContext;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
+namespace memref {
+
 void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
                                     MLIRContext *context);
 
+} // namespace memref
 } // namespace mlir
 
-#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 23d12508b65cb..6e9f6fb10a665 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -18,6 +18,7 @@
 namespace mlir {
 
 class AffineDialect;
+class StandardOpsDialect;
 namespace tensor {
 class TensorDialect;
 } // namespace tensor
@@ -31,6 +32,9 @@ namespace memref {
 // Patterns
 //===----------------------------------------------------------------------===//
 
+/// Collects a set of patterns to rewrite ops within the memref dialect.
+void populateExpandOpsPatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for folding memref.subview ops into consumer load/store ops
 /// into `patterns`.
 void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
@@ -51,6 +55,11 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 // Passes
 //===----------------------------------------------------------------------===//
 
+/// Creates an instance of the ExpandOps pass that legalizes memref dialect ops
+/// to be convertible to LLVM. For example, `memref.reshape` gets converted to
+/// `memref.reinterpret_cast`.
+std::unique_ptr<Pass> createExpandOpsPass();
+
 /// Creates an operation pass to fold memref.subview ops into consumer
 /// load/store ops into `patterns`.
 std::unique_ptr<Pass> createFoldSubViewOpsPass();

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d67746b9c6033..0f2e3a91a2554 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -11,6 +11,12 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ExpandOps : Pass<"memref-expand", "FuncOp"> {
+  let summary = "Legalize memref operations to be convertible to LLVM.";
+  let constructor = "mlir::memref::createExpandOpsPass()";
+  let dependentDialects = ["StandardOpsDialect"];
+}
+
 def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
   let summary = "Fold memref.subview ops into consumer load/store ops";
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index b47303c70250e..3f723133b2a6e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -45,16 +45,6 @@ void populateTensorConstantBufferizePatterns(
 /// Creates an instance of tensor constant bufferization pass.
 std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
 
-/// Creates an instance of the StdExpand pass that legalizes Std
-/// dialect ops to be convertible to LLVM. For example,
-/// `std.arith.ceildivsi` gets transformed to a number of std operations,
-/// which can be lowered to LLVM; `memref.reshape` gets converted to
-/// `memref_reinterpret_cast`.
-std::unique_ptr<Pass> createStdExpandOpsPass();
-
-/// Collects a set of patterns to rewrite ops within the Std dialect.
-void populateStdExpandOpsPatterns(RewritePatternSet &patterns);
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 8ac2b64a8006d..339c1b1194cce 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -18,11 +18,6 @@ def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
                            "memref::MemRefDialect", "scf::SCFDialect"];
 }
 
-def StdExpandOps : Pass<"std-expand", "FuncOp"> {
-  let summary = "Legalize std operations to be convertible to LLVM.";
-  let constructor = "mlir::createStdExpandOpsPass()";
-}
-
 def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
   let summary = "Bufferize func/call/return ops";
   let description = [{

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 319f9bbb95a37..99bc552548f90 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,4 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
+  ComposeSubView.cpp
+  ExpandOps.cpp
   FoldSubViewOps.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
similarity index 96%
rename from mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
rename to mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index cabaa614fa2a9..29ba5060d167d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -11,8 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
-
+#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -21,7 +20,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-namespace mlir {
+using namespace mlir;
 
 namespace {
 
@@ -128,9 +127,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
 } // namespace
 
-void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
-                                    MLIRContext *context) {
+void mlir::memref::populateComposeSubViewPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ComposeSubViewOpPattern>(context);
 }
-
-} // namespace mlir

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
similarity index 93%
rename from mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
rename to mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index e5cf08da4904a..293fb58d4e701 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -17,8 +17,8 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -120,13 +120,13 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
   }
 };
 
-struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
+struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
   void runOnOperation() override {
     MLIRContext &ctx = getContext();
 
     RewritePatternSet patterns(&ctx);
-    populateStdExpandOpsPatterns(patterns);
-    ConversionTarget target(getContext());
+    memref::populateExpandOpsPatterns(patterns);
+    ConversionTarget target(ctx);
 
     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
                            StandardOpsDialect>();
@@ -146,11 +146,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
 
 } // namespace
 
-void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
+void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
   patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
       patterns.getContext());
 }
 
-std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
-  return std::make_unique<StdExpandOpsPass>();
+std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
+  return std::make_unique<ExpandOpsPass>();
 }

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h
index d15631526817f..d1e5baa798fd1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h
@@ -14,6 +14,7 @@
 namespace mlir {
 
 class AffineDialect;
+class StandardOpsDialect;
 
 // Forward declaration from Dialect.h
 template <typename ConcreteDialect>

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index 82e25923840dd..f8082601b48b3 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,8 +1,6 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   Bufferize.cpp
-  ComposeSubView.cpp
   DecomposeCallGraphTypes.cpp
-  ExpandOps.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
   TensorConstantBufferize.cpp

diff  --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
similarity index 96%
rename from mlir/test/Dialect/Standard/expand-ops.mlir
rename to mlir/test/Dialect/MemRef/expand-ops.mlir
index 2a1c367ff80f0..bcf83042184f3 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -std-expand %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @atomic_rmw_to_generic
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index d8219057f9391..d02dc61045fae 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(DLTI)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(Math)
+add_subdirectory(MemRef)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)

diff  --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
new file mode 100644
index 0000000000000..c43dec48bdabf
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRMemRefTestPasses
+  TestComposeSubView.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRPass
+  MLIRMemRefTransforms
+  MLIRTestDialect
+  )
+
+target_include_directories(MLIRMemRefTestPasses
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../Test
+  )

diff  --git a/mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp
similarity index 92%
rename from mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp
rename to mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp
index 1638ee3debe0a..20add4cc94c8b 100644
--- a/mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -35,7 +35,7 @@ void TestComposeSubViewPass::getDependentDialects(
 
 void TestComposeSubViewPass::runOnOperation() {
   OwningRewritePatternList patterns(&getContext());
-  populateComposeSubViewPatterns(patterns, &getContext());
+  memref::populateComposeSubViewPatterns(patterns, &getContext());
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 } // namespace

diff  --git a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt
index 22d5818e34133..b85de09e10bf4 100644
--- a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt
@@ -1,7 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRStandardOpsTestPasses
   TestDecomposeCallGraphTypes.cpp
-  TestComposeSubView.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/mlir-cpu-runner/memref-reshape.mlir b/mlir/test/mlir-cpu-runner/memref-reshape.mlir
index 6d0397399ccaf..e74d6219a1f33 100644
--- a/mlir/test/mlir-cpu-runner/memref-reshape.mlir
+++ b/mlir/test/mlir-cpu-runner/memref-reshape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-std -std-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \
+// RUN: mlir-opt %s -convert-scf-to-std -memref-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \
 // RUN: | mlir-cpu-runner -e main -entry-point-result=void \
 // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
 // RUN: | FileCheck %s

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 7a16a497b6f86..c03d6403a74eb 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -18,6 +18,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRGPUTestPasses
     MLIRLinalgTestPasses
     MLIRMathTestPasses
+    MLIRMemRefTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
     MLIRSPIRVTestPasses


        


More information about the Mlir-commits mailing list