[Mlir-commits] [mlir] 0deeaac - [mlir] Move memref.subview patterns to MemRef/Transforms/
Lei Zhang
llvmlistbot at llvm.org
Mon Apr 12 13:43:24 PDT 2021
Author: Lei Zhang
Date: 2021-04-12T16:38:22-04:00
New Revision: 0deeaaca399b381ddccffde71c921e7636be7fdc
URL: https://github.com/llvm/llvm-project/commit/0deeaaca399b381ddccffde71c921e7636be7fdc
DIFF: https://github.com/llvm/llvm-project/commit/0deeaaca399b381ddccffde71c921e7636be7fdc.diff
LOG: [mlir] Move memref.subview patterns to MemRef/Transforms/
These patterns have been used as a prerequisite step for lowering
to SPIR-V. But they don't involve SPIR-V dialect ops; they are
pure memref/vector op transformations. Now that we have a dedicated
MemRef dialect, move them to MemRef/Transforms/, which is a more
suitable place to host them and allows them to be used by others.
This commit just moves code around and renames patterns/passes
accordingly. The CMakeLists.txt files for the existing MemRef libraries
are also improved along the way.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D100326
Added:
mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
mlir/test/Dialect/MemRef/fold-subview-ops.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
mlir/lib/Dialect/MemRef/CMakeLists.txt
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
Removed:
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/legalization.mlir
mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d569af5235d62..6eb5abdefe552 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -410,17 +410,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
// StandardToSPIRV
//===----------------------------------------------------------------------===//
-def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
- let summary = "Legalize standard ops for SPIR-V lowering";
- let description = [{
- The pass contains certain intra standard op conversions that are meant for
- lowering to SPIR-V ops, e.g., folding subviews loads/stores to the original
- loads/stores from/to the original memref.
- }];
- let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
- let dependentDialects = ["spirv::SPIRVDialect"];
-}
-
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
let summary = "Convert Standard dialect to SPIR-V dialect";
let constructor = "mlir::createConvertStandardToSPIRVPass()";
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
index 165ba0081b776..9091e28c9a7ea 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -40,11 +40,6 @@ void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
RewritePatternSet &patterns);
-/// Appends to a pattern list patterns to legalize ops that are not directly
-/// lowered to SPIR-V.
-void populateStdLegalizationPatternsForSPIRVLowering(
- RewritePatternSet &patterns);
-
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
index e987527e56f5d..de8b474230fa2 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h
@@ -20,9 +20,6 @@ namespace mlir {
/// Creates a pass to convert standard ops to SPIR-V ops.
std::unique_ptr<OperationPass<ModuleOp>> createConvertStandardToSPIRVPass();
-/// Creates a pass to legalize ops that are not directly lowered to SPIR-V.
-std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass();
-
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
index f33061b2d87cf..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..27a79d8ed3216
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name MemRef)
+add_public_tablegen_target(MLIRMemRefPassIncGen)
+add_dependencies(mlir-headers MLIRMemRefPassIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc MemRefPasses ./)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
new file mode 100644
index 0000000000000..1eae023a6df19
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -0,0 +1,47 @@
+//===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares patterns and passes on MemRef operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace memref {
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Appends to `patterns` the patterns that fold memref.subview ops into their
+/// consumer load/store ops (memref.load/store, vector.transfer_read/write).
+void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+/// Creates a pass that folds memref.subview ops into their consumer load/store
+/// ops by greedily applying the patterns from populateFoldSubViewOpPatterns.
+std::unique_ptr<Pass> createFoldSubViewOpsPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
new file mode 100644
index 0000000000000..18be136ac6d18
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - MemRef transformation definition file --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
+ let summary = "Fold memref.subview ops into consumer load/store ops";
+ let description = [{
+ The pass folds loads/stores from/to subview ops into loads/stores
+ from/to the original memref.
+ }];
+ let constructor = "mlir::memref::createFoldSubViewOpsPass()";
+ let dependentDialects = ["memref::MemRefDialect", "StandardOpsDialect", "vector::VectorDialect"];
+}
+
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
+
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index ab9629ac86c15..cd22a96b24d05 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -20,6 +20,7 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -55,6 +56,7 @@ inline void registerAllPasses() {
registerGpuSerializeToHsacoPass();
registerLinalgPasses();
LLVM::registerLLVMPasses();
+ memref::registerMemRefPasses();
quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
index 094934dd53adf..b296a2ef0ce83 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_conversion_library(MLIRStandardToSPIRV
- LegalizeStandardForSPIRV.cpp
StandardToSPIRV.cpp
StandardToSPIRVPass.cpp
diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt
index dc79a5087f8ec..31167e6af908b 100644
--- a/mlir/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,23 +1,3 @@
-add_mlir_dialect_library(MLIRMemRef
- IR/MemRefDialect.cpp
- IR/MemRefOps.cpp
- Utils/MemRefUtils.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
-
- DEPENDS
- MLIRStandardOpsIncGen
- MLIRMemRefOpsIncGen
-
- LINK_COMPONENTS
- Core
-
- LINK_LIBS PUBLIC
- MLIRDialect
- MLIRInferTypeOpInterface
- MLIRIR
- MLIRStandard
- MLIRTensor
- MLIRViewLikeInterface
-)
+add_subdirectory(IR)
+add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index f4ec6ea05f740..6ac47b11996a3 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRMemRef
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
+ MLIRMemRefUtils
MLIRStandard
MLIRTensor
MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..cb27354f36dfe
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRMemRefTransforms
+ FoldSubViewOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
+
+ DEPENDS
+ MLIRMemRefPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRMemRef
+ MLIRPass
+ MLIRStandard
+ MLIRTransforms
+ MLIRVector
+)
+
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
similarity index 84%
rename from mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
rename to mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 02dc56c4054be..bd1fef099e668 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -1,4 +1,4 @@
-//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,16 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
-// This transformation pass legalizes operations before the conversion to SPIR-V
-// dialect to handle ops that cannot be lowered directly.
+// This transformation pass folds loading/storing from/to subview ops into
+// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
-#include "../PassDetail.h"
-#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
-#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -23,6 +20,49 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t. the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
+/// memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
+/// memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+ memref::SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ // getOrCreateRanges materializes static offsets/strides as SSA values, so
+ // static and dynamic subviews are handled alike and this always succeeds;
+ // the LogicalResult return leaves room for rejecting subviews later.
+ SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
+ auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
+ auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
+ assert(opRanges.size() == indices.size() &&
+ "expected as many indices as rank of subview op result type");
+
+ // Each source index is subview_offset + index * subview_stride; this is
+ // shared by the load and store folders.
+ sourceIndices.resize(indices.size());
+ for (auto index : llvm::enumerate(indices)) {
+ auto offset = *(opOffsets.begin() + index.index());
+ auto stride = *(opStrides.begin() + index.index());
+ auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
+ sourceIndices[index.index()] =
+ rewriter.create<AddIOp>(loc, offset, mul).getResult();
+ }
+ return success();
+}
+
/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
@@ -34,6 +74,10 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.source();
}
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
@@ -101,62 +145,15 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
}
} // namespace
-//===----------------------------------------------------------------------===//
-// Utility functions for op legalization.
-//===----------------------------------------------------------------------===//
-
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-/// memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndices(Location loc, PatternRewriter &rewriter,
- memref::SubViewOp subViewOp, ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // TODO: Aborting when the offsets are static. There might be a way to fold
- // the subview op with load even if the offsets have been canonicalized
- // away.
- SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
- auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
- auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
- assert(opRanges.size() == indices.size() &&
- "expected as many indices as rank of subview op result type");
-
- // New indices for the load are the current indices * subview_stride +
- // subview_offset.
- sourceIndices.resize(indices.size());
- for (auto index : llvm::enumerate(indices)) {
- auto offset = *(opOffsets.begin() + index.index());
- auto stride = *(opStrides.begin() + index.index());
- auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
- sourceIndices[index.index()] =
- rewriter.create<AddIOp>(loc, offset, mul).getResult();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Folding SubViewOp and LoadOp/TransferReadOp.
-//===----------------------------------------------------------------------===//
-
template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
- if (!subViewOp) {
+ if (!subViewOp)
return failure();
- }
+
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
loadOp.indices(), sourceIndices)))
@@ -166,19 +163,15 @@ LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
return success();
}
-//===----------------------------------------------------------------------===//
-// Folding SubViewOp and StoreOp/TransferWriteOp.
-//===----------------------------------------------------------------------===//
-
template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
- if (!subViewOp) {
+ if (!subViewOp)
return failure();
- }
+
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
storeOp.indices(), sourceIndices)))
@@ -188,12 +181,7 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
return success();
}
-//===----------------------------------------------------------------------===//
-// Hook for adding patterns.
-//===----------------------------------------------------------------------===//
-
-void mlir::populateStdLegalizationPatternsForSPIRVLowering(
- RewritePatternSet &patterns) {
+void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<memref::StoreOp>,
@@ -202,23 +190,28 @@ void mlir::populateStdLegalizationPatternsForSPIRVLowering(
}
//===----------------------------------------------------------------------===//
-// Pass for testing just the legalization patterns.
+// Pass registration
//===----------------------------------------------------------------------===//
namespace {
-struct SPIRVLegalization final
- : public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+struct FoldSubViewOpsPass final
+ : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
void runOnOperation() override;
};
+
} // namespace
-void SPIRVLegalization::runOnOperation() {
+void FoldSubViewOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- populateStdLegalizationPatternsForSPIRVLowering(patterns);
+ memref::populateFoldSubViewOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
-std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
- return std::make_unique<SPIRVLegalization>();
+std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
+ return std::make_unique<FoldSubViewOpsPass>();
}
diff --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..17a6ecba09106
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRMemRefUtils
MemRefUtils.cpp
+
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
+
LINK_LIBS PUBLIC
MLIRIR
MLIRSideEffectInterfaces
+)
+
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 26a9a217134e2..eb9817014c2ff 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
using namespace mlir;
diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
deleted file mode 100644
index 942f827275f62..0000000000000
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ /dev/null
@@ -1,38 +0,0 @@
-// RUN: mlir-opt -legalize-std-for-spirv %s -o - | FileCheck %s
-
-module {
-
-//===----------------------------------------------------------------------===//
-// memref.subview
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: @fold_static_stride_subview
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<12x32xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: index
-func @fold_static_stride_subview
- (%arg0 : memref<12x32xf32>, %arg1 : index,
- %arg2 : index, %arg3 : index, %arg4 : index) {
- // CHECK-DAG: %[[C2:.*]] = constant 2
- // CHECK-DAG: %[[C3:.*]] = constant 3
- // CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C3]]
- // CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]]
- // CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]]
- // CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]]
- // CHECK: %[[LOADVAL:.*]] = memref.load %[[ARG0]][%[[T1]], %[[T3]]]
- // CHECK: %[[STOREVAL:.*]] = math.sqrt %[[LOADVAL]]
- // CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]]
- // CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]]
- // CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]]
- // CHECK: %[[T9:.*]] = addi %[[T8]], %[[C2]]
- // CHECK: memref.store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
- %0 = memref.subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]>
- %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
- %2 = math.sqrt %1 : f32
- memref.store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
- return
-}
-
-} // end module
diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
similarity index 98%
rename from mlir/test/Conversion/StandardToSPIRV/legalization.mlir
rename to mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index 8077edcc23689..2cddeb93dc301 100644
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -legalize-std-for-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s
// CHECK-LABEL: @fold_static_stride_subview_with_load
// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 1ef697768ba25..bfb63a63ef3ab 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -40,7 +41,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
applyPassManagerCLOptions(passManager);
passManager.addPass(createGpuKernelOutliningPass());
- passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
+ passManager.addPass(memref::createFoldSubViewOpsPass());
passManager.addPass(createConvertGPUToSPIRVPass());
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
modulePM.addPass(spirv::createLowerABIAttributesPass());
More information about the Mlir-commits
mailing list