[Mlir-commits] [mlir] d1c8e17 - [mlir][vector] Add canonicalization patterns for extractMap/insertMap

Thomas Raoux llvmlistbot at llvm.org
Fri Oct 2 10:13:43 PDT 2020


Author: Thomas Raoux
Date: 2020-10-02T10:13:11-07:00
New Revision: d1c8e179d8773f82cdba818dac25667224a9e8d1

URL: https://github.com/llvm/llvm-project/commit/d1c8e179d8773f82cdba818dac25667224a9e8d1
DIFF: https://github.com/llvm/llvm-project/commit/d1c8e179d8773f82cdba818dac25667224a9e8d1.diff

LOG: [mlir][vector] Add canonicalization patterns for extractMap/insertMap

Add basic canonicalization patterns for extractMap/insertMap to allow them
to be folded into transfer ops.
Also mark transferRead as a memory read so that it can be removed by dead-code
elimination.

Differential Revision: https://reviews.llvm.org/D88622

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-distribution.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 42e947071403..137e130c4594 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -517,6 +517,8 @@ def Vector_ExtractMapOp :
     $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to`
     type(results)
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_FMAOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1a83c556d47b..663595ce161c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -923,6 +923,20 @@ static LogicalResult verify(ExtractMapOp op) {
   return success();
 }
 
+/// Folds `extract_map(insert_map(%v, %id, M), %id, M)` to `%v`: extracting
+/// with the same id and multiplicity that the insertion used returns the
+/// inserted vector.
+OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
+  auto insertOp = vector().getDefiningOp<vector::InsertMapOp>();
+  if (!insertOp)
+    return {};
+  bool sameSlice = insertOp.id() == id() &&
+                   insertOp.multiplicity() == multiplicity();
+  if (!sameSlice)
+    return {};
+  return insertOp.vector();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6a244a454e06..20b928fb9a81 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include <type_traits>
 
+#include "mlir/Dialect/Affine/EDSC/Builders.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
@@ -2452,6 +2453,89 @@ mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
   return ops;
 }
 
+/// Returns `index + id * sliceSize`, the first element handled by
+/// distribution id `id`: vector.extract_map / vector.insert_map give id `i`
+/// the i-th contiguous slice of `sliceSize` elements, so the offset has to be
+/// scaled by the slice size (adding `id` alone is only right for 1-element
+/// slices). The EDSC operators used here need an active edsc::ScopedContext.
+static Value getDistributedIndex(PatternRewriter &rewriter, Location loc,
+                                 Value index, Value id, int64_t sliceSize) {
+  using mlir::edsc::op::operator+;
+  using mlir::edsc::op::operator*;
+  Value size = rewriter.create<ConstantIndexOp>(loc, sliceSize);
+  return index + id * size;
+}
+
+/// Rewrites `extract_map(transfer_read(%mem[..., %i]), %id, M)` so that each
+/// id only reads its own slice from memory:
+///   insert_map(transfer_read(%mem[..., %i + %id * sliceSize]), %id, M)
+/// The leftover extract_map(insert_map(...)) pair is then folded away by
+/// ExtractMapOp::fold.
+struct TransferReadExtractPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadExtractPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    if (!read.getResult().hasOneUse())
+      return failure();
+    auto extract =
+        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
+    if (!extract)
+      return failure();
+    // Only the innermost memref index is offset, which is only correct when
+    // the vector dimension maps to it (minor identity permutation map).
+    if (!read.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, read.getLoc());
+    using namespace mlir::edsc::intrinsics;
+    int64_t sliceSize = extract.getType().cast<VectorType>().getNumElements();
+    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
+    indices.back() = getDistributedIndex(rewriter, read.getLoc(),
+                                         indices.back(), extract.id(),
+                                         sliceSize);
+    Value newRead = vector_transfer_read(extract.getType(), read.memref(),
+                                         indices, read.permutation_map(),
+                                         read.padding(), ArrayAttr());
+    newRead = rewriter.create<vector::InsertMapOp>(
+        read.getLoc(), newRead, extract.id(), extract.multiplicity());
+    rewriter.replaceOp(read, newRead);
+    return success();
+  }
+};
+
+/// Rewrites `transfer_write(insert_map(%v, %id, M), %mem[..., %i])` so that
+/// each id only writes its own slice to memory:
+///   transfer_write(%v, %mem[..., %i + %id * sliceSize])
+struct TransferWriteInsertPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteInsertPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
+    if (!insert)
+      return failure();
+    // Only the innermost memref index is offset, which is only correct when
+    // the vector dimension maps to it (minor identity permutation map).
+    if (!write.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, write.getLoc());
+    using namespace mlir::edsc::intrinsics;
+    int64_t sliceSize =
+        insert.vector().getType().cast<VectorType>().getNumElements();
+    SmallVector<Value, 4> indices(write.indices().begin(),
+                                  write.indices().end());
+    indices.back() = getDistributedIndex(rewriter, write.getLoc(),
+                                         indices.back(), insert.id(),
+                                         sliceSize);
+    vector_transfer_write(insert.vector(), write.memref(), indices,
+                          write.permutation_map(), ArrayAttr());
+    rewriter.eraseOp(write);
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2461,7 +2545,9 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
                   ShapeCastOpFolder,
                   SplitTransferReadOp,
                   SplitTransferWriteOp,
-                  TupleGetFolderOp>(context);
+                  TupleGetFolderOp,
+                  TransferReadExtractPattern,
+                  TransferWriteInsertPattern>(context);
   // clang-format on
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 0216a017d7af..264e0195b4ab 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -11,3 +11,26 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
   %0 = addf %A, %B : vector<32xf32>
   return %0: vector<32xf32>
 }
+
+// Every distributed transfer handles a single element, so the slice owned by
+// an id starts exactly at that id: the memref index must be the %id argument.
+// CHECK-LABEL: func @vector_add_read_write
+//  CHECK-SAME: (%[[ID:.*]]: index
+//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT:    %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+//  CHECK-NEXT:    %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT:    %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
+//  CHECK-NEXT:    vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
+//  CHECK-NEXT:    return
+func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %acc = addf %a, %b: vector<32xf32>
+  %c = vector.transfer_read %C[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %d = addf %acc, %c: vector<32xf32>
+  vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 2ffe10bc1682..c1faf23d85df 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -129,6 +129,7 @@ struct TestVectorDistributePatterns
     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<VectorDialect>();
+    registry.insert<AffineDialect>();
   }
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
@@ -143,6 +144,7 @@ struct TestVectorDistributePatterns
       op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
     });
     patterns.insert<PointwiseExtractPattern>(ctx);
+    populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };


        


More information about the Mlir-commits mailing list