[Mlir-commits] [mlir] faafd26 - [mlir][MemRef] Move transform related functions in Transforms.h
Quentin Colombet
llvmlistbot at llvm.org
Tue Mar 28 06:23:14 PDT 2023
Author: Quentin Colombet
Date: 2023-03-28T15:20:19+02:00
New Revision: faafd26c4d58e1292eeb01179d82a375ed2a47d3
URL: https://github.com/llvm/llvm-project/commit/faafd26c4d58e1292eeb01179d82a375ed2a47d3
DIFF: https://github.com/llvm/llvm-project/commit/faafd26c4d58e1292eeb01179d82a375ed2a47d3.diff
LOG: [mlir][MemRef] Move transform related functions in Transforms.h
NFC
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 9f5b095eef8cf..801555dab18b2 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -19,11 +19,6 @@ namespace mlir {
class AffineDialect;
class ModuleOp;
-class RewriterBase;
-
-namespace arith {
-class WideIntEmulationConverter;
-} // namespace arith
namespace func {
class FuncDialect;
@@ -36,82 +31,6 @@ class VectorDialect;
} // namespace vector
namespace memref {
-class AllocOp;
-//===----------------------------------------------------------------------===//
-// Patterns
-//===----------------------------------------------------------------------===//
-
-/// Collects a set of patterns to rewrite ops within the memref dialect.
-void populateExpandOpsPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns for folding memref aliasing ops into consumer load/store
-/// ops into `patterns`.
-void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns that resolve `memref.dim` operations with values that are
-/// defined by operations that implement the
-/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
-/// operands.
-void populateResolveRankedShapeTypeResultDimsPatterns(
- RewritePatternSet &patterns);
-
-/// Appends patterns that resolve `memref.dim` operations with values that are
-/// defined by operations that implement the `InferShapedTypeOpInterface`, in
-/// terms of shapes of its input operands.
-void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns for expanding memref operations that modify the metadata
-/// (sizes, offset, strides) of a memref into easier to analyze constructs.
-void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns for emulating wide integer memref operations with ops over
-/// narrower integer types.
-void populateMemRefWideIntEmulationPatterns(
- arith::WideIntEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
-
-/// Appends type converions for emulating wide integer memref operations with
-/// ops over narrowe integer types.
-void populateMemRefWideIntEmulationConversions(
- arith::WideIntEmulationConverter &typeConverter);
-
-/// Transformation to do multi-buffering/array expansion to remove dependencies
-/// on the temporary allocation between consecutive loop iterations.
-/// It returns the new allocation if the original allocation was multi-buffered
-/// and returns failure() otherwise.
-/// When `skipOverrideAnalysis`, the pass will apply the transformation
-/// without checking thwt the buffer is overrided at the beginning of each
-/// iteration. This implies that user knows that there is no data carried across
-/// loop iterations. Example:
-/// ```
-/// %0 = memref.alloc() : memref<4x128xf32>
-/// scf.for %iv = %c1 to %c1024 step %c3 {
-/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
-/// "some_use"(%0) : (memref<4x128xf32>) -> ()
-/// }
-/// ```
-/// into:
-/// ```
-/// %0 = memref.alloc() : memref<5x4x128xf32>
-/// scf.for %iv = %c1 to %c1024 step %c3 {
-/// %s = arith.subi %iv, %c1 : index
-/// %d = arith.divsi %s, %c3 : index
-/// %i = arith.remsi %d, %c5 : index
-/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
-/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
-/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
-/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
-/// }
-/// ```
-FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
- memref::AllocOp allocOp,
- unsigned multiplier,
- bool skipOverrideAnalysis = false);
-/// Call into `multiBuffer` with locally constructed IRRewriter.
-FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
- unsigned multiplier,
- bool skipOverrideAnalysis = false);
-
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 18b12d6b31dc7..81f3988c068f3 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -16,8 +16,89 @@
namespace mlir {
class RewritePatternSet;
+class RewriterBase;
+
+namespace arith {
+class WideIntEmulationConverter;
+} // namespace arith
namespace memref {
+class AllocOp;
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Collects a set of patterns to rewrite ops within the memref dialect.
+void populateExpandOpsPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for folding memref aliasing ops into consumer load/store
+/// ops into `patterns`.
+void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the
+/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
+/// operands.
+void populateResolveRankedShapeTypeResultDimsPatterns(
+ RewritePatternSet &patterns);
+
+/// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the `InferShapedTypeOpInterface`, in
+/// terms of shapes of its input operands.
+void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for expanding memref operations that modify the metadata
+/// (sizes, offset, strides) of a memref into easier to analyze constructs.
+void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for emulating wide integer memref operations with ops over
+/// narrower integer types.
+void populateMemRefWideIntEmulationPatterns(
+ arith::WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating wide integer memref operations with
+/// ops over narrower integer types.
+void populateMemRefWideIntEmulationConversions(
+ arith::WideIntEmulationConverter &typeConverter);
+
+/// Transformation to do multi-buffering/array expansion to remove dependencies
+/// on the temporary allocation between consecutive loop iterations.
+/// It returns the new allocation if the original allocation was multi-buffered
+/// and returns failure() otherwise.
+/// When `skipOverrideAnalysis` is true, the transformation is applied
+/// without checking that the buffer is overwritten at the beginning of each
+/// iteration. This implies that user knows that there is no data carried across
+/// loop iterations. Example:
+/// ```
+/// %0 = memref.alloc() : memref<4x128xf32>
+/// scf.for %iv = %c1 to %c1024 step %c3 {
+/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
+/// "some_use"(%0) : (memref<4x128xf32>) -> ()
+/// }
+/// ```
+/// into:
+/// ```
+/// %0 = memref.alloc() : memref<5x4x128xf32>
+/// scf.for %iv = %c1 to %c1024 step %c3 {
+/// %s = arith.subi %iv, %c1 : index
+/// %d = arith.divsi %s, %c3 : index
+/// %i = arith.remsi %d, %c5 : index
+/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
+/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
+///   "some_use"(%sv) : (memref<4x128xf32, strided<...>>) -> ()
+/// }
+/// ```
+FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
+ memref::AllocOp allocOp,
+ unsigned multiplier,
+ bool skipOverrideAnalysis = false);
+/// Call into `multiBuffer` with locally constructed IRRewriter.
+FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
+ unsigned multiplier,
+ bool skipOverrideAnalysis = false);
+
/// Appends patterns for extracting address computations from the instructions
/// with memory accesses such that these memory accesses use only a base
/// pointer.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index f122712b209bf..4a2c0a64fc07a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -19,7 +19,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 81b9b279a6500..6202b5730c218 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index d093653a84a5d..de0766daeff30 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 72e6a2925e83f..918055d25b6c2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index c850348c85480..ecf7bcbf997a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 660d1022e0182..64874487c3d3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 50ac04d6d45cc..8c544bbd9fb0d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
index 30cf06b866697..6996ef2dc3179 100644
--- a/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
More information about the Mlir-commits
mailing list