[Mlir-commits] [mlir] 5d45f75 - [mlir][vector] Improve vector distribute integration test and fix block distribution

Thomas Raoux llvmlistbot at llvm.org
Thu Oct 29 14:55:23 PDT 2020


Author: Thomas Raoux
Date: 2020-10-29T14:54:53-07:00
New Revision: 5d45f758f0fba3174126bda24b315006b8b48f1f

URL: https://github.com/llvm/llvm-project/commit/5d45f758f0fba3174126bda24b315006b8b48f1f
DIFF: https://github.com/llvm/llvm-project/commit/5d45f758f0fba3174126bda24b315006b8b48f1f.diff

LOG: [mlir][vector] Improve vector distribute integration test and fix block distribution

Fix the semantics of the distribute integration test based on offline feedback.
This exposed a bug in block distribution: we need to make sure the id is
multiplied by the stride of the vector. Fix the transformation and unit test.

Differential Revision: https://reviews.llvm.org/D89291

Added: 
    

Modified: 
    mlir/integration_test/Dialect/Vector/CPU/test-vector-distribute.mlir
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-distribution.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-vector-distribute.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-vector-distribute.mlir
index befbcd8a87f32..b83b1f6539005 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-vector-distribute.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-vector-distribute.mlir
@@ -1,9 +1,18 @@
-// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 \
-// RUN:  -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-opt %s -test-vector-to-forloop -convert-vector-to-scf \
+// RUN:   -lower-affine -convert-scf-to-std -convert-vector-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
+// RUN: -convert-scf-to-std -convert-vector-to-llvm | mlir-cpu-runner -e main \
+// RUN: -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-vector-to-forloop | FileCheck %s -check-prefix=TRANSFORM
+
+
 func @print_memref_f32(memref<*xf32>)
 
 func @alloc_1d_filled_inc_f32(%arg0: index, %arg1: f32) -> memref<?xf32> {
@@ -19,30 +28,29 @@ func @alloc_1d_filled_inc_f32(%arg0: index, %arg1: f32) -> memref<?xf32> {
   return %0 : memref<?xf32>
 }
 
-func @vector_add_cycle(%id : index, %A: memref<?xf32>, %B: memref<?xf32>, %C: memref<?xf32>) {
-  %c0 = constant 0 : index
-  %cf0 = constant 0.0 : f32
-  %a = vector.transfer_read %A[%c0], %cf0: memref<?xf32>, vector<64xf32>
-  %b = vector.transfer_read %B[%c0], %cf0: memref<?xf32>, vector<64xf32>
-  %acc = addf %a, %b: vector<64xf32>
-  vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<?xf32>
-  return
-}
-
-// Loop over a function containinng a large add vector and distribute it so that
-// each iteration of the loop process part of the vector operation.
+// Large vector addf that can be broken down into a loop of smaller vector addf.
 func @main() {
+  %cf0 = constant 0.0 : f32
   %cf1 = constant 1.0 : f32
   %cf2 = constant 2.0 : f32
   %c0 = constant 0 : index
   %c1 = constant 1 : index
+  %c32 = constant 32 : index
   %c64 = constant 64 : index
   %out = alloc(%c64) : memref<?xf32>
   %in1 = call @alloc_1d_filled_inc_f32(%c64, %cf1) : (index, f32) -> memref<?xf32>
   %in2 = call @alloc_1d_filled_inc_f32(%c64, %cf2) : (index, f32) -> memref<?xf32>
-  scf.for %arg5 = %c0 to %c64 step %c1 {
-    call @vector_add_cycle(%arg5, %in1, %in2, %out) : (index, memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
-  }
+  // Check that the transformation correctly happened.
+  // TRANSFORM: scf.for
+  // TRANSFORM:   vector.transfer_read {{.*}} : memref<?xf32>, vector<2xf32>
+  // TRANSFORM:   vector.transfer_read {{.*}} : memref<?xf32>, vector<2xf32>
+  // TRANSFORM:   %{{.*}} = addf %{{.*}}, %{{.*}} : vector<2xf32>
+  // TRANSFORM:   vector.transfer_write {{.*}} : vector<2xf32>, memref<?xf32>
+  // TRANSFORM: }
+  %a = vector.transfer_read %in1[%c0], %cf0: memref<?xf32>, vector<64xf32>
+  %b = vector.transfer_read %in2[%c0], %cf0: memref<?xf32>, vector<64xf32>
+  %acc = addf %a, %b: vector<64xf32>
+  vector.transfer_write %acc, %out[%c0]: vector<64xf32>, memref<?xf32>
   %converted = memref_cast %out : memref<?xf32> to memref<*xf32>
   call @print_memref_f32(%converted): (memref<*xf32>) -> ()
   // CHECK:      Unranked{{.*}}data =

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1d8d0c6fc60f0..c24a1d5b85ec6 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2526,9 +2526,13 @@ struct TransferReadExtractPattern
       return failure();
     edsc::ScopedContext scope(rewriter, read.getLoc());
     using mlir::edsc::op::operator+;
+    using mlir::edsc::op::operator*;
     using namespace mlir::edsc::intrinsics;
     SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
-    indices.back() = indices.back() + extract.id();
+    indices.back() =
+        indices.back() +
+        (extract.id() *
+         std_constant_index(extract.getResultType().getDimSize(0)));
     Value newRead = vector_transfer_read(extract.getType(), read.memref(),
                                          indices, read.permutation_map(),
                                          read.padding(), ArrayAttr());
@@ -2552,10 +2556,14 @@ struct TransferWriteInsertPattern
       return failure();
     edsc::ScopedContext scope(rewriter, write.getLoc());
     using mlir::edsc::op::operator+;
+    using mlir::edsc::op::operator*;
     using namespace mlir::edsc::intrinsics;
     SmallVector<Value, 4> indices(write.indices().begin(),
                                   write.indices().end());
-    indices.back() = indices.back() + insert.id();
+    indices.back() =
+        indices.back() +
+        (insert.id() *
+         std_constant_index(insert.getSourceVectorType().getDimSize(0)));
     vector_transfer_write(insert.vector(), write.memref(), indices,
                           write.permutation_map(), ArrayAttr());
     rewriter.eraseOp(write);

diff  --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 5fb32daa4ee8a..f93e96b2b902a 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @distribute_vector_add
 //  CHECK-SAME: (%[[ID:.*]]: index
@@ -13,6 +13,8 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
   return %0: vector<32xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @vector_add_read_write
 //  CHECK-SAME: (%[[ID:.*]]: index
 //       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
@@ -34,12 +36,19 @@ func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>,
   return
 }
 
-// CHECK-LABEL: func @vector_add_cycle
+// -----
+
+// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<()[s0] -> (s0 * 2)>
+
+//       CHECK: func @vector_add_cycle
 //  CHECK-SAME: (%[[ID:.*]]: index
-//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
-//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
+//       CHECK:    %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
+//  CHECK-NEXT:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]]], %{{.*}} : memref<64xf32>, vector<2xf32>
+//  CHECK-NEXT:    %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]]], %{{.*}} : memref<64xf32>, vector<2xf32>
 //  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
-//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID]]] : vector<2xf32>, memref<64xf32>
+//  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
+//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] : vector<2xf32>, memref<64xf32>
 //  CHECK-NEXT:    return
 func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
   %c0 = constant 0 : index
@@ -51,6 +60,8 @@ func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C:
   return
 }
 
+// -----
+
 // Negative test to make sure nothing is done in case the vector size is not a
 // multiple of multiplicity.
 // CHECK-LABEL: func @vector_negative_test

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 20903b3064806..e2a507fa1dfce 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -8,6 +8,7 @@
 
 #include <type_traits>
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -185,6 +186,64 @@ struct TestVectorDistributePatterns
   }
 };
 
+struct TestVectorToLoopPatterns
+    : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
+  TestVectorToLoopPatterns() = default;
+  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+    registry.insert<AffineDialect>();
+  }
+  Option<int32_t> multiplicity{
+      *this, "distribution-multiplicity",
+      llvm::cl::desc("Set the multiplicity used for distributing vector"),
+      llvm::cl::init(32)};
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns;
+    FuncOp func = getFunction();
+    func.walk([&](AddFOp op) {
+      // Check that the operation type can be broken down into a loop.
+      VectorType type = op.getType().dyn_cast<VectorType>();
+      if (!type || type.getRank() != 1 ||
+          type.getNumElements() % multiplicity != 0)
+        return mlir::WalkResult::advance();
+      auto filterAlloc = [](Operation *op) {
+        if (isa<ConstantOp, AllocOp, CallOp>(op))
+          return false;
+        return true;
+      };
+      auto dependentOps = getSlice(op, filterAlloc);
+      // Create a loop and move instructions from the Op slice into the loop.
+      OpBuilder builder(op);
+      auto zero = builder.create<ConstantOp>(
+          op.getLoc(), builder.getIndexType(),
+          builder.getIntegerAttr(builder.getIndexType(), 0));
+      auto one = builder.create<ConstantOp>(
+          op.getLoc(), builder.getIndexType(),
+          builder.getIntegerAttr(builder.getIndexType(), 1));
+      auto numIter = builder.create<ConstantOp>(
+          op.getLoc(), builder.getIndexType(),
+          builder.getIntegerAttr(builder.getIndexType(), multiplicity));
+      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
+      for (Operation *it : dependentOps) {
+        it->moveBefore(forOp.getBody()->getTerminator());
+      }
+      // break up the original op and let the patterns propagate.
+      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
+          builder, op.getOperation(), forOp.getInductionVar(), multiplicity);
+      if (ops.hasValue()) {
+        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
+        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+      }
+      return mlir::WalkResult::interrupt();
+    });
+    patterns.insert<PointwiseExtractPattern>(ctx);
+    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 struct TestVectorTransferUnrollingPatterns
     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -264,5 +323,8 @@ void registerTestVectorConversions() {
       "test-vector-distribute-patterns",
       "Test conversion patterns to distribute vector ops in the vector "
       "dialect");
+  PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
+      "test-vector-to-forloop",
+      "Test conversion patterns to break up a vector op into a for loop");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list