[Mlir-commits] [mlir] d1c8e17 - [mlir][vector] Add canonicalization patterns for extractMap/insertMap
Thomas Raoux
llvmlistbot at llvm.org
Fri Oct 2 10:13:43 PDT 2020
Author: Thomas Raoux
Date: 2020-10-02T10:13:11-07:00
New Revision: d1c8e179d8773f82cdba818dac25667224a9e8d1
URL: https://github.com/llvm/llvm-project/commit/d1c8e179d8773f82cdba818dac25667224a9e8d1
DIFF: https://github.com/llvm/llvm-project/commit/d1c8e179d8773f82cdba818dac25667224a9e8d1.diff
LOG: [mlir][vector] Add canonicalization patterns for extractMap/insertMap
Add basic canonicalization patterns for extractMap/insertMap so that they can
be folded into transfer ops.
Also mark transferRead as a memory read so that it can be removed by dead-code
elimination.
Differential Revision: https://reviews.llvm.org/D88622
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-distribution.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 42e947071403..137e130c4594 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -517,6 +517,8 @@ def Vector_ExtractMapOp :
$vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to`
type(results)
}];
+
+ let hasFolder = 1;
}
def Vector_FMAOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1a83c556d47b..663595ce161c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -923,6 +923,14 @@ static LogicalResult verify(ExtractMapOp op) {
return success();
}
+/// Folds `extract_map(insert_map(%v, %id, m), %id, m)` to `%v`: extracting
+/// with the same id and multiplicity that were used for the insertion returns
+/// the originally inserted value. Any other producer, or a mismatched id or
+/// multiplicity, leaves the op unchanged.
+OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
+  auto producer = vector().getDefiningOp<vector::InsertMapOp>();
+  if (!producer)
+    return {};
+  bool sameSlice =
+      producer.id() == id() && producer.multiplicity() == multiplicity();
+  if (!sameSlice)
+    return {};
+  return producer.vector();
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6a244a454e06..20b928fb9a81 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -12,6 +12,7 @@
#include <type_traits>
+#include "mlir/Dialect/Affine/EDSC/Builders.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
@@ -2452,6 +2453,55 @@ mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
return ops;
}
+/// Rewrites a `vector.transfer_read` whose only user is a `vector.extract_map`
+/// into a transfer_read of the extracted (narrower) vector type. The narrow
+/// result is wrapped in a `vector.insert_map` carrying the same id and
+/// multiplicity as the extract_map, so the extract_map's producer becomes a
+/// matching insert_map.
+///
+/// The innermost memref index is advanced by `id * <extracted vector size>`.
+/// NOTE(review): this assumes each id owns a contiguous slice of the original
+/// vector; for single-element slices the offset reduces to plain `id`.
+struct TransferReadExtractPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadExtractPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    // replaceOp below redirects every use of the read, so only rewrite when
+    // the extract_map is the sole consumer of the full-width value.
+    if (!read.getResult().hasOneUse())
+      return failure();
+    auto extract =
+        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
+    if (!extract)
+      return failure();
+    // The offset is applied to the last memref index, which is only the
+    // distributed dimension when the permutation map is a minor identity.
+    if (!read.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, read.getLoc());
+    using mlir::edsc::op::operator+;
+    using mlir::edsc::op::operator*;
+    using namespace mlir::edsc::intrinsics;
+    // Each id covers `sliceSize` consecutive elements, so its first element
+    // lives at `id * sliceSize` (not `id`) past the original base index.
+    int64_t sliceSize = extract.getType().cast<VectorType>().getDimSize(0);
+    Value sliceStride =
+        rewriter.create<ConstantIndexOp>(read.getLoc(), sliceSize);
+    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
+    indices.back() = indices.back() + extract.id() * sliceStride;
+    Value newRead = vector_transfer_read(extract.getType(), read.memref(),
+                                         indices, read.permutation_map(),
+                                         read.padding(), ArrayAttr());
+    newRead = rewriter.create<vector::InsertMapOp>(
+        read.getLoc(), newRead, extract.id(), extract.multiplicity());
+    rewriter.replaceOp(read, newRead);
+    return success();
+  }
+};
+
+/// Rewrites a `vector.transfer_write` of a `vector.insert_map` result into a
+/// transfer_write of the inserted (narrower) vector, then erases the original
+/// full-width write.
+///
+/// The innermost memref index is advanced by `id * <inserted vector size>`.
+/// NOTE(review): this assumes each id owns a contiguous slice of the full
+/// vector; for single-element slices the offset reduces to plain `id`.
+struct TransferWriteInsertPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteInsertPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
+    if (!insert)
+      return failure();
+    // The offset is applied to the last memref index, which is only the
+    // distributed dimension when the permutation map is a minor identity.
+    if (!write.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, write.getLoc());
+    using mlir::edsc::op::operator+;
+    using mlir::edsc::op::operator*;
+    using namespace mlir::edsc::intrinsics;
+    // Each id covers `sliceSize` consecutive elements, so its first element
+    // lives at `id * sliceSize` (not `id`) past the original base index.
+    int64_t sliceSize =
+        insert.vector().getType().cast<VectorType>().getDimSize(0);
+    Value sliceStride =
+        rewriter.create<ConstantIndexOp>(write.getLoc(), sliceSize);
+    SmallVector<Value, 4> indices(write.indices().begin(),
+                                  write.indices().end());
+    indices.back() = indices.back() + insert.id() * sliceStride;
+    vector_transfer_write(insert.vector(), write.memref(), indices,
+                          write.permutation_map(), ArrayAttr());
+    rewriter.eraseOp(write);
+    return success();
+  }
+};
+
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2461,7 +2511,9 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
ShapeCastOpFolder,
SplitTransferReadOp,
SplitTransferWriteOp,
- TupleGetFolderOp>(context);
+ TupleGetFolderOp,
+ TransferReadExtractPattern,
+ TransferWriteInsertPattern>(context);
// clang-format on
}
diff --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 0216a017d7af..264e0195b4ab 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -11,3 +11,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
%0 = addf %A, %B : vector<32xf32>
return %0: vector<32xf32>
}
+
+// Full-width vector<32xf32> transfer_read/transfer_write ops feeding the
+// distributed pointwise adds are expected to be narrowed to vector<1xf32>
+// accesses, with no extract_map/insert_map left between them in the output.
+// CHECK-LABEL: func @vector_add_read_write
+// CHECK-SAME: (%[[ID:.*]]: index
+// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
+// CHECK-NEXT: return
+func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %acc = addf %a, %b: vector<32xf32>
+  %c = vector.transfer_read %C[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %d = addf %acc, %c: vector<32xf32>
+  vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 2ffe10bc1682..c1faf23d85df 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -129,6 +129,7 @@ struct TestVectorDistributePatterns
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<VectorDialect>();
+ registry.insert<AffineDialect>();
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
@@ -143,6 +144,7 @@ struct TestVectorDistributePatterns
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
});
patterns.insert<PointwiseExtractPattern>(ctx);
+ populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
More information about the Mlir-commits
mailing list