[Mlir-commits] [mlir] faafd26 - [mlir][MemRef] Move transform related functions in Transforms.h

Quentin Colombet llvmlistbot at llvm.org
Tue Mar 28 06:23:14 PDT 2023


Author: Quentin Colombet
Date: 2023-03-28T15:20:19+02:00
New Revision: faafd26c4d58e1292eeb01179d82a375ed2a47d3

URL: https://github.com/llvm/llvm-project/commit/faafd26c4d58e1292eeb01179d82a375ed2a47d3
DIFF: https://github.com/llvm/llvm-project/commit/faafd26c4d58e1292eeb01179d82a375ed2a47d3.diff

LOG: [mlir][MemRef] Move transform related functions in Transforms.h

NFC

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
    mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 9f5b095eef8cf..801555dab18b2 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -19,11 +19,6 @@ namespace mlir {
 
 class AffineDialect;
 class ModuleOp;
-class RewriterBase;
-
-namespace arith {
-class WideIntEmulationConverter;
-} // namespace arith
 
 namespace func {
 class FuncDialect;
@@ -36,82 +31,6 @@ class VectorDialect;
 } // namespace vector
 
 namespace memref {
-class AllocOp;
-//===----------------------------------------------------------------------===//
-// Patterns
-//===----------------------------------------------------------------------===//
-
-/// Collects a set of patterns to rewrite ops within the memref dialect.
-void populateExpandOpsPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns for folding memref aliasing ops into consumer load/store
-/// ops into `patterns`.
-void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns that resolve `memref.dim` operations with values that are
-/// defined by operations that implement the
-/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
-/// operands.
-void populateResolveRankedShapeTypeResultDimsPatterns(
-    RewritePatternSet &patterns);
-
-/// Appends patterns that resolve `memref.dim` operations with values that are
-/// defined by operations that implement the `InferShapedTypeOpInterface`, in
-/// terms of shapes of its input operands.
-void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns for expanding memref operations that modify the metadata
-/// (sizes, offset, strides) of a memref into easier to analyze constructs.
-void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
-
-/// Appends patterns for emulating wide integer memref operations with ops over
-/// narrower integer types.
-void populateMemRefWideIntEmulationPatterns(
-    arith::WideIntEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
-
-/// Appends type converions for emulating wide integer memref operations with
-/// ops over narrowe integer types.
-void populateMemRefWideIntEmulationConversions(
-    arith::WideIntEmulationConverter &typeConverter);
-
-/// Transformation to do multi-buffering/array expansion to remove dependencies
-/// on the temporary allocation between consecutive loop iterations.
-/// It returns the new allocation if the original allocation was multi-buffered
-/// and returns failure() otherwise.
-/// When `skipOverrideAnalysis`, the pass will apply the transformation
-/// without checking thwt the buffer is overrided at the beginning of each
-/// iteration. This implies that user knows that there is no data carried across
-/// loop iterations. Example:
-/// ```
-/// %0 = memref.alloc() : memref<4x128xf32>
-/// scf.for %iv = %c1 to %c1024 step %c3 {
-///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
-///   "some_use"(%0) : (memref<4x128xf32>) -> ()
-/// }
-/// ```
-/// into:
-/// ```
-/// %0 = memref.alloc() : memref<5x4x128xf32>
-/// scf.for %iv = %c1 to %c1024 step %c3 {
-///   %s = arith.subi %iv, %c1 : index
-///   %d = arith.divsi %s, %c3 : index
-///   %i = arith.remsi %d, %c5 : index
-///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
-///     memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
-///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
-///   "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
-/// }
-/// ```
-FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
-                                       memref::AllocOp allocOp,
-                                       unsigned multiplier,
-                                       bool skipOverrideAnalysis = false);
-/// Call into `multiBuffer` with  locally constructed IRRewriter.
-FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
-                                       unsigned multiplier,
-                                       bool skipOverrideAnalysis = false);
-
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 18b12d6b31dc7..81f3988c068f3 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -16,8 +16,89 @@
 
 namespace mlir {
 class RewritePatternSet;
+class RewriterBase;
+
+namespace arith {
+class WideIntEmulationConverter;
+} // namespace arith
 
 namespace memref {
+class AllocOp;
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Collects a set of patterns to rewrite ops within the memref dialect.
+void populateExpandOpsPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for folding memref aliasing ops into consumer load/store
+/// ops into `patterns`.
+void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the
+/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
+/// operands.
+void populateResolveRankedShapeTypeResultDimsPatterns(
+    RewritePatternSet &patterns);
+
+/// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the `InferShapedTypeOpInterface`, in
+/// terms of shapes of its input operands.
+void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for expanding memref operations that modify the metadata
+/// (sizes, offset, strides) of a memref into easier to analyze constructs.
+void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
+
+/// Appends patterns for emulating wide integer memref operations with ops over
+/// narrower integer types.
+void populateMemRefWideIntEmulationPatterns(
+    arith::WideIntEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating wide integer memref operations with
+/// ops over narrower integer types.
+void populateMemRefWideIntEmulationConversions(
+    arith::WideIntEmulationConverter &typeConverter);
+
+/// Transformation to do multi-buffering/array expansion to remove dependencies
+/// on the temporary allocation between consecutive loop iterations.
+/// It returns the new allocation if the original allocation was multi-buffered
+/// and returns failure() otherwise.
+/// When `skipOverrideAnalysis` is true, the pass will apply the transformation
+/// without checking that the buffer is overwritten at the beginning of each
+/// iteration. This implies that user knows that there is no data carried across
+/// loop iterations. Example:
+/// ```
+/// %0 = memref.alloc() : memref<4x128xf32>
+/// scf.for %iv = %c1 to %c1024 step %c3 {
+///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
+///   "some_use"(%0) : (memref<4x128xf32>) -> ()
+/// }
+/// ```
+/// into:
+/// ```
+/// %0 = memref.alloc() : memref<5x4x128xf32>
+/// scf.for %iv = %c1 to %c1024 step %c3 {
+///   %s = arith.subi %iv, %c1 : index
+///   %d = arith.divsi %s, %c3 : index
+///   %i = arith.remsi %d, %c5 : index
+///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
+///     memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
+///   "some_use"(%sv) : (memref<4x128xf32, strided<...>>) -> ()
+/// }
+/// ```
+FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
+                                       memref::AllocOp allocOp,
+                                       unsigned multiplier,
+                                       bool skipOverrideAnalysis = false);
+/// Call into `multiBuffer` with a locally constructed IRRewriter.
+FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
+                                       unsigned multiplier,
+                                       bool skipOverrideAnalysis = false);
+
 /// Appends patterns for extracting address computations from the instructions
 /// with memory accesses such that these memory accesses use only a base
 /// pointer.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index f122712b209bf..4a2c0a64fc07a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -19,7 +19,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 81b9b279a6500..6202b5730c218 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index d093653a84a5d..de0766daeff30 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 72e6a2925e83f..918055d25b6c2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index c850348c85480..ecf7bcbf997a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 660d1022e0182..64874487c3d3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 50ac04d6d45cc..8c544bbd9fb0d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

diff  --git a/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
index 30cf06b866697..6996ef2dc3179 100644
--- a/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"


        


More information about the Mlir-commits mailing list