[Mlir-commits] [mlir] cf402a1 - [mlir][vector] Add unit test for vector distribute by block

Thomas Raoux llvmlistbot at llvm.org
Thu Oct 8 14:45:11 PDT 2020


Author: Thomas Raoux
Date: 2020-10-08T14:44:03-07:00
New Revision: cf402a1987591923492fe697b2e84b1affbae6dd

URL: https://github.com/llvm/llvm-project/commit/cf402a1987591923492fe697b2e84b1affbae6dd
DIFF: https://github.com/llvm/llvm-project/commit/cf402a1987591923492fe697b2e84b1affbae6dd.diff

LOG: [mlir][vector] Add unit test for vector distribute by block

When distributing a vector larger than the given multiplicity, we can
distribute it by block, where each id gets a chunk of consecutive elements
along the distributed dimension. This adds a test for this case and adds extra
checks to make sure we don't distribute vectors whose size is not a multiple
of the multiplicity.

Differential Revision: https://reviews.llvm.org/D89061

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-distribution.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 20b928fb9a81..08ee5c64af09 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2444,7 +2444,14 @@ mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointAfter(op);
   Location loc = op->getLoc();
+  if (op->getNumResults() != 1)
+    return {};
   Value result = op->getResult(0);
+  VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
+  // Currently only support distributing 1-D vectors of size multiple of the
+  // given multiplicity. To handle more sizes we would need to support masking.
+  if (!type || type.getRank() != 1 || type.getNumElements() % multiplicity != 0)
+    return {};
   DistributeOps ops;
   ops.extract =
       builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);

diff  --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 264e0195b4ab..79f9b7871dfb 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 | FileCheck %s
 
 // CHECK-LABEL: func @distribute_vector_add
 //  CHECK-SAME: (%[[ID:.*]]: index
@@ -14,12 +14,12 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
 
 // CHECK-LABEL: func @vector_add_read_write
 //  CHECK-SAME: (%[[ID:.*]]: index
-//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
-//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
 //  CHECK-NEXT:    %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
-//  CHECK-NEXT:    %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT:    %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
 //  CHECK-NEXT:    %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
-//  CHECK-NEXT:    vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
+//  CHECK-NEXT:    vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
 //  CHECK-NEXT:    return
 func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
   %c0 = constant 0 : index
@@ -32,3 +32,41 @@ func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>,
   vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
   return
 }
+
+// CHECK-LABEL: func @vector_add_cycle
+//  CHECK-SAME: (%[[ID:.*]]: index
+//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
+//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID]]] : vector<2xf32>, memref<64xf32>
+//  CHECK-NEXT:    return
+func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<64xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<64xf32>
+  %acc = addf %a, %b: vector<64xf32>
+  vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<64xf32>
+  return
+}
+
+// Negative test to make sure nothing is done in case the vector size is not a
+// multiple of multiplicity.
+// CHECK-LABEL: func @vector_negative_test
+//       CHECK:    %[[C0:.*]] = constant 0 : index
+//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<16xf32>
+//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]]] {{.*}} : vector<16xf32>, memref<64xf32>
+//  CHECK-NEXT:    return
+func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<16xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<16xf32>
+  %acc = addf %a, %b: vector<16xf32>
+  vector.transfer_write %acc, %C[%c0]: vector<16xf32>, memref<64xf32>
+  return
+}
+
+

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index c1faf23d85df..fe0947d0ac30 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -127,10 +127,16 @@ struct TestVectorUnrollingPatterns
 
 struct TestVectorDistributePatterns
     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
+  TestVectorDistributePatterns() = default;
+  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<VectorDialect>();
     registry.insert<AffineDialect>();
   }
+  Option<int32_t> multiplicity{
+      *this, "distribution-multiplicity",
+      llvm::cl::desc("Set the multiplicity used for distributing vector"),
+      llvm::cl::init(32)};
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
@@ -138,10 +144,11 @@ struct TestVectorDistributePatterns
     func.walk([&](AddFOp op) {
       OpBuilder builder(op);
       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
-          builder, op.getOperation(), func.getArgument(0), 32);
-      assert(ops.hasValue());
-      SmallPtrSet<Operation *, 1> extractOp({ops->extract});
-      op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+          builder, op.getOperation(), func.getArgument(0), multiplicity);
+      if (ops.hasValue()) {
+        SmallPtrSet<Operation *, 1> extractOp({ops->extract});
+        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+      }
     });
     patterns.insert<PointwiseExtractPattern>(ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);


        


More information about the Mlir-commits mailing list