[Mlir-commits] [mlir] 0deeaac - [mlir] Move memref.subview patterns to MemRef/Transforms/

Lei Zhang llvmlistbot at llvm.org
Mon Apr 12 13:43:24 PDT 2021


Author: Lei Zhang
Date: 2021-04-12T16:38:22-04:00
New Revision: 0deeaaca399b381ddccffde71c921e7636be7fdc

URL: https://github.com/llvm/llvm-project/commit/0deeaaca399b381ddccffde71c921e7636be7fdc
DIFF: https://github.com/llvm/llvm-project/commit/0deeaaca399b381ddccffde71c921e7636be7fdc.diff

LOG: [mlir] Move memref.subview patterns to MemRef/Transforms/

These patterns have been used as a prerequisite step for lowering
to SPIR-V, but they don't involve SPIR-V dialect ops; they are
pure memref/vector op transformations. Now that we have a dedicated
MemRef dialect, this commit moves them to MemRef/Transforms/, which
is a more suitable place to host them and allows them to be used by
others.

This commit just moves code around and renames patterns/passes
accordingly. The CMakeLists.txt files for the existing MemRef
libraries are also improved along the way.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D100326

Added: 
    mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
    mlir/test/Dialect/MemRef/fold-subview-ops.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
    mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
    mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
    mlir/lib/Dialect/MemRef/CMakeLists.txt
    mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Removed: 
    mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/legalization.mlir
    mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d569af5235d62..6eb5abdefe552 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -410,17 +410,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
 // StandardToSPIRV
 //===----------------------------------------------------------------------===//
 
-def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
-  let summary = "Legalize standard ops for SPIR-V lowering";
-  let description = [{
-    The pass contains certain intra standard op conversions that are meant for
-    lowering to SPIR-V ops, e.g., folding subviews loads/stores to the original
-    loads/stores from/to the original memref.
-  }];
-  let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
-  let dependentDialects = ["spirv::SPIRVDialect"];
-}
-
 def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
   let summary = "Convert Standard dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertStandardToSPIRVPass()";

diff  --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
index 165ba0081b776..9091e28c9a7ea 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -40,11 +40,6 @@ void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    int64_t byteCountThreshold,
                                    RewritePatternSet &patterns);
 
-/// Appends to a pattern list patterns to legalize ops that are not directly
-/// lowered to SPIR-V.
-void populateStdLegalizationPatternsForSPIRVLowering(
-    RewritePatternSet &patterns);
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H

diff  --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
index e987527e56f5d..de8b474230fa2 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
@@ -20,9 +20,6 @@ namespace mlir {
 /// Creates a pass to convert standard ops to SPIR-V ops.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertStandardToSPIRVPass();
 
-/// Creates a pass to legalize ops that are not directly lowered to SPIR-V.
-std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass();
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H

diff  --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
index f33061b2d87cf..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..27a79d8ed3216
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name MemRef)
+add_public_tablegen_target(MLIRMemRefPassIncGen)
+add_dependencies(mlir-headers MLIRMemRefPassIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc MemRefPasses ./)

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
new file mode 100644
index 0000000000000..1eae023a6df19
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -0,0 +1,47 @@
+//===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares patterns and passes on MemRef operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace memref {
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Appends to `patterns` the patterns that fold memref.subview ops into their
+/// consumer memref.load/store and vector.transfer_read/write ops.
+void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+/// Creates a pass that folds memref.subview ops into their consumer
+/// memref.load/store and vector.transfer_read/write ops.
+std::unique_ptr<Pass> createFoldSubViewOpsPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
new file mode 100644
index 0000000000000..18be136ac6d18
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - MemRef transformation definition file --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
+  let summary = "Fold memref.subview ops into consumer load/store ops";
+  let description = [{
+    The pass folds loading/storing from/to subview ops to loading/storing
+    from/to the original memref.
+  }];
+  let constructor = "mlir::memref::createFoldSubViewOpsPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "StandardOpsDialect", "vector::VectorDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index ab9629ac86c15..cd22a96b24d05 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Quant/Passes.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -55,6 +56,7 @@ inline void registerAllPasses() {
   registerGpuSerializeToHsacoPass();
   registerLinalgPasses();
   LLVM::registerLLVMPasses();
+  memref::registerMemRefPasses();
   quant::registerQuantPasses();
   registerSCFPasses();
   registerShapePasses();

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
index 094934dd53adf..b296a2ef0ce83 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_conversion_library(MLIRStandardToSPIRV
-  LegalizeStandardForSPIRV.cpp
   StandardToSPIRV.cpp
   StandardToSPIRVPass.cpp
 

diff  --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt
index dc79a5087f8ec..31167e6af908b 100644
--- a/mlir/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,23 +1,3 @@
-add_mlir_dialect_library(MLIRMemRef
-  IR/MemRefDialect.cpp
-  IR/MemRefOps.cpp
-  Utils/MemRefUtils.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
-
-  DEPENDS
-  MLIRStandardOpsIncGen
-  MLIRMemRefOpsIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRDialect
-  MLIRInferTypeOpInterface
-  MLIRIR
-  MLIRStandard
-  MLIRTensor
-  MLIRViewLikeInterface
-)
+add_subdirectory(IR)
+add_subdirectory(Transforms)
+add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index f4ec6ea05f740..6ac47b11996a3 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRMemRef
   MLIRDialect
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemRefUtils
   MLIRStandard
   MLIRTensor
   MLIRViewLikeInterface

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..cb27354f36dfe
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRMemRefTransforms
+  FoldSubViewOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
+
+  DEPENDS
+  MLIRMemRefPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRMemRef
+  MLIRPass
+  MLIRStandard
+  MLIRTransforms
+  MLIRVector
+)
+

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
similarity index 84%
rename from mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
rename to mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 02dc56c4054be..bd1fef099e668 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -1,4 +1,4 @@
-//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,16 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This transformation pass legalizes operations before the conversion to SPIR-V
-// dialect to handle ops that cannot be lowered directly.
+// This transformation pass folds loading/storing from/to subview ops into
+// loading/storing from/to the original memref.
 //
 //===----------------------------------------------------------------------===//
 
-#include "../PassDetail.h"
-#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
-#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -23,6 +20,49 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a subview op, computes the indices w.r.t. the source memref of the
+/// subview op into `sourceIndices`. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
+///          memref<12x42xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+/// %2 = memref.load %1[%i1, %i2] : memref<4x4xf32, offset:?, strides: [?, ?]>
+///
+/// could be folded into
+///
+/// %2 = memref.load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
+///          memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+                     memref::SubViewOp subViewOp, ValueRange indices,
+                     SmallVectorImpl<Value> &sourceIndices) {
+  // Rank-reducing subviews are unsupported; bail out before creating any IR.
+  MemRefType sourceType = subViewOp.getSourceType();
+  if (sourceType.getRank() != static_cast<int64_t>(indices.size()))
+    return failure();
+  // Static offsets/strides are materialized as constants by getOrCreateRanges.
+  SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
+  auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
+  auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
+
+  // New indices for the load are the current indices * subview_stride +
+  // subview_offset.
+  sourceIndices.resize(indices.size());
+  for (auto index : llvm::enumerate(indices)) {
+    auto offset = *(opOffsets.begin() + index.index());
+    auto stride = *(opStrides.begin() + index.index());
+    auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
+    sourceIndices[index.index()] =
+        rewriter.create<AddIOp>(loc, offset, mul).getResult();
+  }
+  return success();
+}
+
 /// Helpers to access the memref operand for each op.
 static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
 
@@ -34,6 +74,10 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.source();
 }
 
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// Merges subview operation with load/transferRead operation.
 template <typename OpTy>
@@ -101,62 +145,15 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
 }
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// Utility functions for op legalization.
-//===----------------------------------------------------------------------===//
-
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-///          memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndices(Location loc, PatternRewriter &rewriter,
-                     memref::SubViewOp subViewOp, ValueRange indices,
-                     SmallVectorImpl<Value> &sourceIndices) {
-  // TODO: Aborting when the offsets are static. There might be a way to fold
-  // the subview op with load even if the offsets have been canonicalized
-  // away.
-  SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
-  auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
-  auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-  assert(opRanges.size() == indices.size() &&
-         "expected as many indices as rank of subview op result type");
-
-  // New indices for the load are the current indices * subview_stride +
-  // subview_offset.
-  sourceIndices.resize(indices.size());
-  for (auto index : llvm::enumerate(indices)) {
-    auto offset = *(opOffsets.begin() + index.index());
-    auto stride = *(opStrides.begin() + index.index());
-    auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
-    sourceIndices[index.index()] =
-        rewriter.create<AddIOp>(loc, offset, mul).getResult();
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Folding SubViewOp and LoadOp/TransferReadOp.
-//===----------------------------------------------------------------------===//
-
 template <typename OpTy>
 LogicalResult
 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                              PatternRewriter &rewriter) const {
   auto subViewOp =
       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
-  if (!subViewOp) {
+  if (!subViewOp)
     return failure();
-  }
+
   SmallVector<Value, 4> sourceIndices;
   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                   loadOp.indices(), sourceIndices)))
@@ -166,19 +163,15 @@ LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Folding SubViewOp and StoreOp/TransferWriteOp.
-//===----------------------------------------------------------------------===//
-
 template <typename OpTy>
 LogicalResult
 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                               PatternRewriter &rewriter) const {
   auto subViewOp =
       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
-  if (!subViewOp) {
+  if (!subViewOp)
     return failure();
-  }
+
   SmallVector<Value, 4> sourceIndices;
   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                   storeOp.indices(), sourceIndices)))
@@ -188,12 +181,7 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Hook for adding patterns.
-//===----------------------------------------------------------------------===//
-
-void mlir::populateStdLegalizationPatternsForSPIRVLowering(
-    RewritePatternSet &patterns) {
+void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
   patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
                LoadOpOfSubViewFolder<vector::TransferReadOp>,
                StoreOpOfSubViewFolder<memref::StoreOp>,
@@ -202,23 +190,28 @@ void mlir::populateStdLegalizationPatternsForSPIRVLowering(
 }
 
 //===----------------------------------------------------------------------===//
-// Pass for testing just the legalization patterns.
+// Pass registration
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct SPIRVLegalization final
-    : public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+struct FoldSubViewOpsPass final
+    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
   void runOnOperation() override;
 };
+
 } // namespace
 
-void SPIRVLegalization::runOnOperation() {
+void FoldSubViewOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  populateStdLegalizationPatternsForSPIRVLowering(patterns);
+  memref::populateFoldSubViewOpPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                      std::move(patterns));
 }
 
-std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
-  return std::make_unique<SPIRVLegalization>();
+std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
+  return std::make_unique<FoldSubViewOpsPass>();
 }

diff  --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..17a6ecba09106
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRMemRefUtils
+  MemRefUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+)
+

diff  --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 26a9a217134e2..eb9817014c2ff 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 using namespace mlir;
 

diff  --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
deleted file mode 100644
index 942f827275f62..0000000000000
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ /dev/null
@@ -1,38 +0,0 @@
-// RUN: mlir-opt -legalize-std-for-spirv %s -o - | FileCheck %s
-
-module {
-
-//===----------------------------------------------------------------------===//
-// memref.subview
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: @fold_static_stride_subview
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<12x32xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: index
-func @fold_static_stride_subview
-       (%arg0 : memref<12x32xf32>, %arg1 : index,
-        %arg2 : index, %arg3 : index, %arg4 : index) {
-  // CHECK-DAG: %[[C2:.*]] = constant 2
-  // CHECK-DAG: %[[C3:.*]] = constant 3
-  //     CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C3]]
-  //     CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]]
-  //     CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]]
-  //     CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]]
-  //     CHECK: %[[LOADVAL:.*]] = memref.load %[[ARG0]][%[[T1]], %[[T3]]]
-  //     CHECK: %[[STOREVAL:.*]] = math.sqrt %[[LOADVAL]]
-  //     CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]]
-  //     CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]]
-  //     CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]]
-  //     CHECK: %[[T9:.*]] = addi %[[T8]], %[[C2]]
-  //     CHECK: memref.store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
-  %0 = memref.subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]>
-  %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
-  %2 = math.sqrt %1 : f32
-  memref.store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
-  return
-}
-
-} // end module

diff  --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
similarity index 98%
rename from mlir/test/Conversion/StandardToSPIRV/legalization.mlir
rename to mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index 8077edcc23689..2cddeb93dc301 100644
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -legalize-std-for-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s
 
 // CHECK-LABEL: @fold_static_stride_subview_with_load
 // CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 1ef697768ba25..bfb63a63ef3ab 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -40,7 +41,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   applyPassManagerCLOptions(passManager);
 
   passManager.addPass(createGpuKernelOutliningPass());
-  passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
+  passManager.addPass(memref::createFoldSubViewOpsPass());
   passManager.addPass(createConvertGPUToSPIRVPass());
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createLowerABIAttributesPass());


        


More information about the Mlir-commits mailing list