[Mlir-commits] [mlir] c503dc1 - [mlir][linalg] Add vectorization for element-wise linalg ops

Thomas Raoux llvmlistbot at llvm.org
Thu Dec 3 15:31:59 PST 2020


Author: Thomas Raoux
Date: 2020-12-03T15:31:13-08:00
New Revision: c503dc1b8a52946e4daefa1a266e74a102382971

URL: https://github.com/llvm/llvm-project/commit/c503dc1b8a52946e4daefa1a266e74a102382971
DIFF: https://github.com/llvm/llvm-project/commit/c503dc1b8a52946e4daefa1a266e74a102382971.diff

LOG: [mlir][linalg] Add vectorization for element-wise linalg ops

Add vectorization support for linalg.generic ops that represent element-wise operations.
These are converted to transfer_read + vector ops + transfer_write.
Also reorganize the vectorization tests so they live together in one file.

Implementation derived from the work of @burmako, @agrue and
@fedelebron.

Differential Revision: https://reviews.llvm.org/D92540

Added: 
    mlir/test/Dialect/Linalg/vectorization.mlir

Modified: 
    mlir/include/mlir/EDSC/Builders.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/EDSC/Builders.cpp
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index 70c948d99cda8..83b6634bf8e22 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -30,6 +30,7 @@ namespace edsc {
 /// setting and restoring of insertion points.
 class ScopedContext {
 public:
+  ScopedContext(OpBuilder &b);
   ScopedContext(OpBuilder &b, Location location);
 
   /// Sets the insertion point of the builder to 'newInsertPt' for the duration

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8860674ef8474..a28b90b1d95c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -84,6 +84,195 @@ static LogicalResult isContraction(Operation *op) {
       hasMultiplyAddBody(genericOp.region()));
 }
 
+// Returns true if `r` has a single block whose ops are all constants,
+// linalg.yield, or ElementwiseMappable ops, and every result those ops produce
+// is an integer, index, or float scalar.
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  auto isScalarResultType = [](Type t) { return t.isIntOrIndexOrFloat(); };
+  return llvm::all_of(r.front(), [&](Operation &op) {
+    bool supportedKind = isa<ConstantOp, linalg::YieldOp>(op) ||
+                         op.hasTrait<OpTrait::ElementwiseMappable>();
+    return supportedKind &&
+           llvm::all_of(op.getResultTypes(), isScalarResultType);
+  });
+}
+
+// Returns true if `op` is a linalg.generic whose loops are all parallel, whose
+// outputs are indexed with identity maps, whose inputs are indexed with minor
+// identity maps, and whose body only holds scalar elementwise ops.
+static bool isElementwise(Operation *op) {
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp || genericOp.getNumLoops() != genericOp.getNumParallelLoops())
+    return false;
+  // TODO: relax the restrictions on indexing map.
+  for (unsigned idx = 0, numOutputs = genericOp.getNumOutputs();
+       idx != numOutputs; ++idx)
+    if (!genericOp.getOutputIndexingMap(idx).isIdentity())
+      return false;
+  // Inputs are limited to minor identity maps: any other permutation would
+  // need transpose ops to bring the vector read into the right shape.
+  for (unsigned idx = 0, numInputs = genericOp.getNumInputs();
+       idx != numInputs; ++idx)
+    if (!genericOp.getInputIndexingMap(idx).isMinorIdentity())
+      return false;
+  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
+}
+
+// Returns a vector type with the shape and element type of the memref `v`, or
+// a null VectorType when `v` is a 0-D memref (callers then fall back to scalar
+// load/store).
+static VectorType extractVectorTypeFromScalarView(Value v) {
+  auto memrefType = v.getType().cast<MemRefType>();
+  if (memrefType.getShape().empty())
+    return VectorType();
+  return VectorType::get(memrefType.getShape(), memrefType.getElementType());
+}
+
+// Reads all of `memref`, starting at index 0 in every dimension: a
+// vector.transfer_read into the memref-shaped vector type for ranked memrefs,
+// or a scalar load for a 0-D memref.
+static Value transferReadVector(OpBuilder &builder, Value memref) {
+  edsc::ScopedContext scope(builder);
+  VectorType vectorType = extractVectorTypeFromScalarView(memref);
+  if (!vectorType)
+    return std_load(memref);
+  int64_t rank = memref.getType().cast<MemRefType>().getRank();
+  SmallVector<Value, 4> zeroIndices(rank, std_constant_index(0));
+  return vector_transfer_read(vectorType, memref, zeroIndices);
+}
+
+// Writes `value` into `memref`, starting at index 0 in every dimension. For a
+// ranked memref the value is first broadcast to the memref-shaped vector type
+// when its type differs, then stored with vector.transfer_write; a 0-D memref
+// gets a plain scalar store.
+static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+  edsc::ScopedContext scope(builder);
+  VectorType vectorType = extractVectorTypeFromScalarView(memref);
+  if (!vectorType) {
+    std_store(value, memref);
+    return;
+  }
+  int64_t rank = memref.getType().cast<MemRefType>().getRank();
+  SmallVector<Value, 4> zeroIndices(rank, std_constant_index(0));
+  if (value.getType() != vectorType)
+    value = vector_broadcast(vectorType, value);
+  vector_transfer_write(value, memref, zeroIndices);
+}
+
+namespace {
+// Transforms scalar operations into their vectorized counterparts,
+// while using the provided generic op to map:
+//   * Its arguments to transfer reads from the views of the generic op.
+//   * linalg.yield ops to transfer writes to the views of the generic op.
+// All new ops are created through the provided builder.
+class GenericVectorizer {
+public:
+  GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
+      : builder(builder), generic(generic) {}
+
+  // Takes a scalar operation and builds its vectorized counterpart or
+  // counterparts using the underlying builder.
+  // If operands of the scalar operation are referring to previously vectorized
+  // operations, then in their vectorized form these operands will be referring
+  // to previous vectorization results.
+  void vectorize(Operation &scalarOp) {
+    // linalg.yield becomes one transfer write per yielded value, into the
+    // matching output buffer of the generic op.
+    if (auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp)) {
+      for (auto outputAndMemref :
+           llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
+        Value vectorValue = vectorize(std::get<0>(outputAndMemref));
+        transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+      }
+      return;
+    }
+    Operation *vectorOp = uncachedVectorize(scalarOp);
+    assert(scalarOp.getNumResults() == vectorOp->getNumResults());
+    for (auto result : llvm::zip(scalarOp.getResults(), vectorOp->getResults()))
+      valueCache[std::get<0>(result)] = std::get<1>(result);
+  }
+
+private:
+  // Returns the vectorized counterpart of `scalarValue`:
+  //   * values defined outside the generic op's region are returned as is;
+  //   * results of already-vectorized ops are taken from the cache;
+  //   * block arguments are read (once, then cached) from the matching view
+  //     of the generic op.
+  Value vectorize(Value scalarValue) {
+    // Don't vectorize values coming from outside the region.
+    if (scalarValue.getParentRegion() != &generic.region())
+      return scalarValue;
+    auto vectorValueIt = valueCache.find(scalarValue);
+    if (vectorValueIt != valueCache.end())
+      return vectorValueIt->second;
+
+    // Body ops are vectorized in order (see vectorizeElementwise), so a value
+    // from the region that is not in the cache must be a block argument.
+    auto scalarArg = scalarValue.cast<BlockArgument>();
+    assert(scalarArg.getOwner() == &generic.region().front());
+    Value vectorArg =
+        generic.getInputsAndOutputBuffers()[scalarArg.getArgNumber()];
+    Value vectorResult = transferReadVector(builder, vectorArg);
+    valueCache[scalarArg] = vectorResult;
+    return vectorResult;
+  }
+
+  // Returns the shape of the vector-typed value with the most elements among
+  // `values`; on a tie in element count the higher-rank shape wins, so e.g.
+  // vector<1x4xf32> is preferred over vector<4xf32>. The result does not
+  // depend on operand order. Returns an empty SmallVector if no value has a
+  // vector type.
+  static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
+    VectorType largest;
+    for (Value value : values) {
+      auto vecType = value.getType().dyn_cast<VectorType>();
+      if (!vecType)
+        continue;
+      // Every vector (including single-element ones) is a candidate, and the
+      // running maximum is updated so later, smaller vectors cannot win.
+      if (!largest || largest.getNumElements() < vecType.getNumElements() ||
+          (largest.getNumElements() == vecType.getNumElements() &&
+           largest.getRank() < vecType.getRank()))
+        largest = vecType;
+    }
+    if (!largest)
+      return {};
+    return SmallVector<int64_t, 4>(largest.getShape().begin(),
+                                   largest.getShape().end());
+  }
+
+  // Broadcasts `value` to a vector of the given `shape` unless it already has
+  // that shape. An empty `shape` means no operand is a vector, so `value` is
+  // returned unchanged.
+  Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
+    auto vecType = value.getType().dyn_cast<VectorType>();
+    if (shape.empty() || (vecType && vecType.getShape() == shape))
+      return value;
+    auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+                                                     : value.getType());
+    // Use the generic op's location rather than dereferencing the builder's
+    // insertion point, which is invalid when it sits at the end of a block.
+    return builder.create<vector::BroadcastOp>(generic.getLoc(), newVecType,
+                                               value);
+  }
+
+  // Takes a scalar operation and builds its vectorized counterpart using the
+  // underlying builder without involving any caches. Operands are brought to
+  // a common vector shape, and each result type becomes a vector of that
+  // shape (or stays scalar when no operand is a vector).
+  Operation *uncachedVectorize(Operation &scalarOp) {
+    SmallVector<Value, 4> vectorizedOperands;
+    for (Value operand : scalarOp.getOperands())
+      vectorizedOperands.push_back(vectorize(operand));
+    SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
+    for (Value &operand : vectorizedOperands)
+      operand = broadcastIfNeeded(operand, shape);
+    OperationState state(scalarOp.getLoc(), scalarOp.getName());
+    state.addAttributes(scalarOp.getAttrs());
+    state.addOperands(vectorizedOperands);
+    if (shape.empty()) {
+      state.addTypes(scalarOp.getResultTypes());
+    } else {
+      SmallVector<Type, 4> vectorizedTypes;
+      for (Type elementType : scalarOp.getResultTypes())
+        vectorizedTypes.push_back(VectorType::get(shape, elementType));
+      state.addTypes(vectorizedTypes);
+    }
+    return builder.createOperation(state);
+  }
+
+  OpBuilder &builder;
+  linalg::GenericOp generic;
+  llvm::DenseMap<Value, Value> valueCache;
+};
+} // namespace
+
+// Emits, through `builder`, a vectorized version of the elementwise
+// linalg.generic `op`: each scalar op of its body is vectorized in order,
+// block arguments become transfer reads of the op's views and linalg.yield
+// becomes transfer writes to its output buffers. `op` itself is not erased
+// here.
+static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
+  GenericVectorizer vectorizer(builder, op);
+  Block &body = op.region().front();
+  for (Operation &scalarOp : body)
+    vectorizer.vectorize(scalarOp);
+}
+
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -96,7 +285,8 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
 
   if (isa<linalg::FillOp, linalg::CopyOp>(op))
     return success();
-
+  if (isElementwise(op))
+    return success();
   return isContraction(op);
 }
 
@@ -108,28 +298,11 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   edsc::ScopedContext scope(builder, op->getLoc());
   // In the case of 0-D memrefs, return null and special case to scalar load or
   // store later.
-  auto extractVectorTypeFromScalarView = [](Value v) {
-    MemRefType mt = v.getType().cast<MemRefType>();
-    return mt.getShape().empty()
-               ? VectorType()
-               : VectorType::get(mt.getShape(), mt.getElementType());
-  };
   if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
     // Vectorize fill as a vector.broadcast.
     LLVM_DEBUG(dbgs() << dbgPref
                       << "Rewrite linalg.fill as vector.broadcast: " << *op);
-    Value viewOutput = fillOp.output();
-    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
-      auto vecType =
-          VectorType::get(fillOp.getOutputBufferType(0).getShape(),
-                          fillOp.getOutputBufferType(0).getElementType());
-      Value vector = vector_broadcast(vecType, fillOp.value());
-      Value zero = std_constant_index(0);
-      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
-      vector_transfer_write(vector, viewOutput, indicesOutput);
-    } else {
-      std_store(fillOp.value(), viewOutput);
-    }
+    transferWriteVector(builder, fillOp.value(), fillOp.output());
     return;
   }
   if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
@@ -138,36 +311,19 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                       << "Rewrite linalg.copy as vector.transfer_read + "
                          "vector.transfer_write: "
                       << *op);
-    Value zero = std_constant_index(0);
-    Value viewInput = copyOp.input();
-    Value viewOutput = copyOp.output();
-    Value vector;
-    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
-      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
-      if (copyOp.inputPermutation())
-        vector = vector_transfer_read(
-            extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput,
-            copyOp.inputPermutation().getValue());
-      else
-        vector =
-            vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
-                                 viewInput, indicesInput);
-    } else {
-      vector = std_load(viewInput).value;
-    }
-    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
-      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
-      if (copyOp.outputPermutation())
-        vector_transfer_write(vector, viewOutput, indicesOutput,
-                              copyOp.outputPermutation().getValue());
-      else
-        vector_transfer_write(vector, viewOutput, indicesOutput);
-    } else {
-      std_store(vector, viewOutput);
-    }
+    Value vector = transferReadVector(builder, copyOp.input());
+    transferWriteVector(builder, vector, copyOp.output());
     return;
   }
 
+  if (isElementwise(op)) {
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg op as vector.transfer_read + "
+                         "vector_op + vector.transfer_write: "
+                      << *op);
+    return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
+  }
+
   assert(succeeded(isContraction(op)) && "Expected contraction");
 
   // Vectorize other ops as vector contraction.

diff  --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp
index 54086c9263730..21a6b922d91fc 100644
--- a/mlir/lib/EDSC/Builders.cpp
+++ b/mlir/lib/EDSC/Builders.cpp
@@ -15,6 +15,9 @@
 using namespace mlir;
 using namespace mlir::edsc;
 
+// Returns the location of the operation at `b`'s insertion point. When the
+// builder is positioned at the end of a block (so there is no such operation),
+// falls back to the block's parent operation, and to an unknown location when
+// there is no block or no parent.
+static Location getInsertionPointLoc(OpBuilder &b) {
+  Block *block = b.getInsertionBlock();
+  if (!block)
+    return b.getUnknownLoc();
+  if (b.getInsertionPoint() != block->end())
+    return b.getInsertionPoint()->getLoc();
+  if (Operation *parent = block->getParentOp())
+    return parent->getLoc();
+  return b.getUnknownLoc();
+}
+
+/// Creates a scoped context whose location is derived from the builder's
+/// current insertion point via getInsertionPointLoc.
+mlir::edsc::ScopedContext::ScopedContext(OpBuilder &b)
+    : ScopedContext(b, getInsertionPointLoc(b)) {}
+
 mlir::edsc::ScopedContext::ScopedContext(OpBuilder &b, Location location)
     : builder(b), guard(builder), location(location),
       enclosingScopedContext(ScopedContext::getCurrentScopedContext()) {

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index 155247a538069..dbdf193419207 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -26,40 +25,3 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 // CHECK-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 //      CHECK: linalg.copy
-
-// VECTOR-CONTRACTION-LABEL: contraction_dot
-func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
-  // VECTOR-CONTRACTION: vector.contract
-  // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
-  linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
-            outs(%C: memref<f32>)
-  return
-}
-
-// VECTOR-CONTRACTION-LABEL: contraction_matvec
-func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
-  // VECTOR-CONTRACTION: vector.contract
-  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
-  linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
-            outs(%C: memref<1584xf32>)
-  return
-}
-
-// VECTOR-CONTRACTION-LABEL: contraction_matmul
-func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
-  // VECTOR-CONTRACTION: vector.contract
-  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
-  linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
-            outs(%C: memref<1584x1584xf32>)
-  return
-}
-
-// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
-func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
-  // VECTOR-CONTRACTION: vector.contract
-  // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
-  linalg.batch_matmul
-    ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
-   outs(%C: memref<1584x1584x1584xf32>)
-  return
-}

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 9bdc4ad548267..83cb16ba0e3eb 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -5,9 +5,7 @@
 // CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // Map corresponding to a 2D memory access where the stride along all dims are unknown.
 // CHECK-DAG: #[[$STRIDED_2D:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
 // CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
 
@@ -92,99 +90,6 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK:                                   ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>)
 // CHECK:                                  outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>)
 
-#matmul_trait = {
-  args_in = 2,
-  args_out = 1,
-  indexing_maps = [
-    affine_map<(m, n, k) -> (m, k)>,
-    affine_map<(m, n, k) -> (k, n)>,
-    affine_map<(m, n, k) -> (m, n)>
-  ],
-  iterator_types = ["parallel", "parallel", "reduction"],
-  __internal_linalg_transform__ = "VECTORIZE"
-}
-func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
-                         %C: memref<8x32xf32>) {
-  linalg.generic #matmul_trait
-    ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
-   outs(%C : memref<8x32xf32>) {
-    ^bb(%a: f32, %b: f32, %c: f32) :
-      %d = mulf %a, %b: f32
-      %e = addf %c, %d: f32
-      linalg.yield %e : f32
-  }
-  return
-}
-// CHECK-LABEL: func @vectorization_test
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
-
-func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
-                                 %C: memref<8x32xi32>) {
-  linalg.generic #matmul_trait
-    ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
-   outs(%C : memref<8x32xi32>) {
-    ^bb(%a: i32, %b: i32, %c: i32) :
-      %d = muli %a, %b: i32
-      %e = addi %c, %d: i32
-      linalg.yield %e : i32
-  }
-  return
-}
-// CHECK-LABEL: func @vectorization_test_integer
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
-//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
-//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
-
-func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
-                         %C: memref<8x32xf32>) {
-  linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"}
-    ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
-   outs(%C: memref<8x32xf32>)
-  return
-}
-// CHECK-LABEL: func @vectorization_test_2
-//       CHECK: vector.contract {{.*}} :
-//                vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-
-func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
-  linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, f32
-  return
-}
-// CHECK-LABEL: func @test_vectorize_fill
-//       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
-//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
-
-func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
-  linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<f32>, f32
-  return
-}
-// CHECK-LABEL: func @test_vectorize_fill
-//  CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
-//       CHECK:   store %[[V]], %[[M]][] : memref<f32>
-
-func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
-  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, memref<8x16xf32>
-  return
-}
-// CHECK-LABEL: func @test_vectorize_copy
-//       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
-//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
-
-func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
-  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<f32>, memref<f32>
-  return
-}
-// CHECK-LABEL: func @test_vectorize_copy_scalar
-//       CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
-//       CHECK: store %[[V]], {{.*}} : memref<f32>
-
-
 #matmul_accesses = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
new file mode 100644
index 0000000000000..1c3533275e491
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -0,0 +1,210 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns | FileCheck %s
+
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// linalg.dot on static 1-D memrefs is expected to become a vector.contract
+// reducing two vector<1584xf32> operands into an f32.
+// CHECK-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+  // CHECK: vector.contract
+  // CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32
+  linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
+            outs(%C: memref<f32>)
+  return
+}
+
+// linalg.matvec is expected to become a vector.contract of a 2-D and a 1-D
+// vector producing a vector<1584xf32>.
+// CHECK-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+  // CHECK: vector.contract
+  // CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+  linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
+            outs(%C: memref<1584xf32>)
+  return
+}
+
+// linalg.matmul on static 2-D memrefs is expected to become a vector.contract
+// over 2-D vectors.
+// CHECK-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+  // CHECK: vector.contract
+  // CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+  linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
+            outs(%C: memref<1584x1584xf32>)
+  return
+}
+
+// linalg.batch_matmul on static 3-D memrefs is expected to become a
+// vector.contract over 3-D vectors.
+// CHECK-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+  // CHECK: vector.contract
+  // CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+  linalg.batch_matmul
+    ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+   outs(%C: memref<1584x1584x1584xf32>)
+  return
+}
+
+// Trait for a matmul-shaped linalg.generic: (m, k) x (k, n) -> (m, n) with a
+// reduction over k. Used by the generic contraction tests that follow.
+#matmul_trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+// A linalg.generic with a mulf/addf body and a reduction loop is treated as a
+// contraction: all three views are read with transfer_read, combined with a
+// single vector.contract and the result is written back with transfer_write.
+func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+                         %C: memref<8x32xf32>) {
+  linalg.generic #matmul_trait
+    ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
+   outs(%C : memref<8x32xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32) :
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e : f32
+  }
+  return
+}
+// CHECK-LABEL: func @vectorization_test
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
+
+// Same as @vectorization_test but with an i32 muli/addi body, covering
+// integer contractions.
+func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
+                                 %C: memref<8x32xi32>) {
+  linalg.generic #matmul_trait
+    ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
+   outs(%C : memref<8x32xi32>) {
+    ^bb(%a: i32, %b: i32, %c: i32) :
+      %d = muli %a, %b: i32
+      %e = addi %c, %d: i32
+      linalg.yield %e : i32
+  }
+  return
+}
+// CHECK-LABEL: func @vectorization_test_integer
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
+
+// Named linalg.matmul with 8x16 and 16x32 operands vectorizes to a
+// vector.contract producing an 8x32 vector. The operand/result types are
+// matched on the same output line as the vector.contract.
+func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+                         %C: memref<8x32xf32>) {
+  linalg.matmul
+    ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
+   outs(%C: memref<8x32xf32>)
+  return
+}
+// CHECK-LABEL: func @vectorization_test_2
+//       CHECK: vector.contract {{.*}} :
+//  CHECK-SAME:   vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+
+// linalg.fill into a 2-D memref becomes a vector.broadcast of the scalar
+// followed by a transfer_write of the broadcast vector.
+func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+  linalg.fill(%A, %arg0) :  memref<8x16xf32>, f32
+  return
+}
+// CHECK-LABEL: func @test_vectorize_fill
+//       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+// For a 0-D memref, linalg.fill falls back to a scalar store of the value.
+func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+  linalg.fill(%A, %arg0) :  memref<f32>, f32
+  return
+}
+// CHECK-LABEL: func @test_vectorize_fill_scalar
+//  CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
+//       CHECK:   store %[[V]], %[[M]][] : memref<f32>
+
+// linalg.copy between 2-D memrefs becomes a transfer_read whose result feeds
+// a transfer_write.
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+  linalg.copy(%A, %B) :  memref<8x16xf32>, memref<8x16xf32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+//       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+// For 0-D memrefs, linalg.copy falls back to a scalar load followed by a
+// scalar store.
+func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+  linalg.copy(%A, %B) :  memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy_scalar
+//       CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+//       CHECK: store %[[V]], {{.*}} : memref<f32>
+
+// Elementwise linalg.generic with ten outputs. It exercises: a 1-D input read
+// as vector<256xf32> and broadcast to 4x256, values defined outside the region
+// (%i and %c1_f32), a constant defined inside the region, and a mix of
+// arithmetic, comparison, select and math ops.
+func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
+  %arg2: memref<256xf32>, %i: f32) {
+  %c1_f32 = constant 1.0 : f32
+  linalg.generic {
+    args_in = 0 : i64,
+    args_out = 10 : i64,
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+  ins(%arg1, %arg2: memref<4x256xf32>, memref<256xf32>)
+  outs(
+    %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
+    memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
+    memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
+    memref<4x256xf32>, memref<4x256xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+    %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
+    %arg14 : f32):
+    %6 = addf %arg4, %arg6 : f32
+    %7 = cmpf "ogt", %arg3, %arg6 : f32
+    %8 = constant 2.0 : f32
+    %9 = divf %arg5, %i : f32
+    %10 = exp2 %arg5 : f32
+    %11 = mulf %arg5, %8 : f32
+    %12 = rsqrt %arg5 : f32
+    %13 = select %7, %arg5, %arg6 : f32
+    %14 = subf %arg5, %arg6 : f32
+    %15 = tanh %arg5 : f32
+    linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
+      f32, f32, f32, f32, f32, f32, f32, f32
+  }
+  return
+}
+
+// Each block argument that is used is read once, on first use. Scalar values
+// written to vector outputs (the constants) are expected as dense vector
+// constants, presumably after folding of their broadcasts.
+// CHECK-LABEL: func @generic_vectorize
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
+//   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+//   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+//       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+//       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+//       CHECK:   %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
+//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> 
+//       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
+//       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
+//       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
+//       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
+//       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
+//       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
+//       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V1]] : vector<4x256xf32>
+//       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+//       CHECK:   vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+//       CHECK:   vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 52e96dc44e0bd..9e3efcf416649 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -71,7 +71,7 @@ struct TestLinalgTransforms
           "Test a fused pass that forwards linalg.copy to vector.transfer"),
       llvm::cl::init(false)};
   Option<bool> testGenericToVectorPattern{
-      *this, "test-contraction-to-vector-patterns",
+      *this, "test-linalg-to-vector-patterns",
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
@@ -464,14 +464,15 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
   applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
 }
 
-static void applyContractionToVectorPatterns(FuncOp funcOp) {
+static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   OwningRewritePatternList patterns;
-  patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
-                  LinalgVectorizationPattern<MatmulOp>,
-                  LinalgVectorizationPattern<MatvecOp>,
-                  LinalgVectorizationPattern<VecmatOp>,
-                  LinalgVectorizationPattern<DotOp>,
-                  LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+  patterns.insert<
+      LinalgVectorizationPattern<BatchMatmulOp>,
+      LinalgVectorizationPattern<MatmulOp>,
+      LinalgVectorizationPattern<MatvecOp>,
+      LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
+      LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,
+      LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
   applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
@@ -516,7 +517,7 @@ void TestLinalgTransforms::runOnFunction() {
   if (testVectorTransferForwardingPatterns)
     return applyVectorTransferForwardingPatterns(getFunction());
   if (testGenericToVectorPattern)
-    return applyContractionToVectorPatterns(getFunction());
+    return applyLinalgToVectorPatterns(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
 }


        


More information about the Mlir-commits mailing list