[Mlir-commits] [mlir] 56c638b - [mlir][Linalg] Generalize Vectorization of Linalg contractions

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 10 07:29:46 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-10T10:28:34-04:00
New Revision: 56c638b5c1caf018da3fa1a95b603267e607c89c

URL: https://github.com/llvm/llvm-project/commit/56c638b5c1caf018da3fa1a95b603267e607c89c
DIFF: https://github.com/llvm/llvm-project/commit/56c638b5c1caf018da3fa1a95b603267e607c89c.diff

LOG: [mlir][Linalg] Generalize Vectorization of Linalg contractions

This revision adds support for vectorizing named and generic contraction ops to vector.contract. Cases in which the memref is 0-D are special-cased to emit std.load/std.store instead of vector.transfer_read/vector.transfer_write. Relevant tests are added.

Differential revision: https://reviews.llvm.org/D83307

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6436bb9550e8..89dad2ec40cf 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -286,6 +286,12 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
     return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
                             attr_value_iterator<AttrTy>(end()));
   }
+  template <typename AttrTy, typename UnderlyingTy>
+  auto getAsRange() {
+    return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
+      return static_cast<UnderlyingTy>(attr.getValue());
+    });
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bba7b2a10030..bbdb8e7b46b4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -36,8 +36,7 @@ using llvm::dbgs;
 
 #define DEBUG_TYPE "linalg-vectorization"
 
-static bool hasMultiplyAddBody(linalg::GenericOp op) {
-  auto &r = op.region();
+static bool hasMultiplyAddBody(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
@@ -59,14 +58,26 @@ static bool hasMultiplyAddBody(linalg::GenericOp op) {
 }
 
 // TODO: Should be Tablegen'd from a single source that generates the op itself.
-static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
-  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
-         isRowMajorMatmul(genericOp.indexing_maps()) &&
-         hasMultiplyAddBody(genericOp);
+static LogicalResult isContraction(Operation *op) {
+  // TODO: interface for named ops.
+  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
+          linalg::DotOp>(op))
+    return success();
+
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp)
+    return failure();
+
+  auto mapRange =
+      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
+
+  return success(
+      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+      llvm::all_of(mapRange,
+                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
+      hasMultiplyAddBody(genericOp.region()));
 }
 
-// TODO: This is in fact much more general than just vectorization for matmul
-// and fill ops.
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -76,33 +87,16 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp, linalg::FillOp>(op))
-    return success();
 
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp || !::isRowMajorMatmul(genericOp))
-    return failure();
+  if (isa<linalg::FillOp>(op))
+    return success();
 
-  // TODO: non-identity layout.
-  auto isStaticMemRefWithIdentityLayout = [](Value v) {
-    auto m = v.getType().dyn_cast<MemRefType>();
-    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
-      return false;
-    return true;
-  };
-  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
-                              isStaticMemRefWithIdentityLayout));
+  return isContraction(op);
 }
 
 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   assert(succeeded(vectorizeLinalgOpPrecondition(op)));
 
-  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
-    // TODO: add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
-  }
-
   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
   (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
@@ -117,33 +111,47 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     return;
   }
 
-  // Vectorize other ops as vector contraction (currently only matmul).
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  // In the case of 0-D memrefs, return null and special case to scalar load or
+  // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
     MemRefType mt = v.getType().cast<MemRefType>();
-    return VectorType::get(mt.getShape(), mt.getElementType());
+    return mt.getShape().empty()
+               ? VectorType()
+               : VectorType::get(mt.getShape(), mt.getElementType());
   };
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);
   Value viewC = linalgOp.getOutputBuffer(0);
+  VectorType vtA = extractVectorTypeFromScalarView(viewA);
+  VectorType vtB = extractVectorTypeFromScalarView(viewB);
+  VectorType vtC = extractVectorTypeFromScalarView(viewC);
   Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
-                                 zero);
-  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
-                                 zero);
-  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
-                                 indicesA);
-  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
-                                 indicesB);
-  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
-                                 indicesC);
+  SmallVector<Value, 4> indicesA, indicesB, indicesC;
+  if (vtA)
+    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
+  if (vtB)
+    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
+  if (vtC)
+    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
+  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
+                : std_load(viewA, indicesA).value;
+  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
+                : std_load(viewB, indicesB).value;
+  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
+                : std_load(viewC, indicesC).value;
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  vector_transfer_write(res, viewC, indicesC);
+  if (vtC)
+    vector_transfer_write(res, viewC, indicesC);
+  else
+    std_store(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index cf75ee5691d0..b0702f9fdcfd 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -30,3 +31,38 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 // CHECK-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 //      CHECK: linalg.copy
+
+// VECTOR-CONTRACTION-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
+  linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref<f32>
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+  linalg.matvec %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+  linalg.matmul %A, %B, %C :
+    (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+  return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+  // VECTOR-CONTRACTION: vector.contract
+  // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+  linalg.batch_matmul %A, %B, %C :
+    (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f93cd9faa504..4fb378c5ab8a 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -54,6 +54,11 @@ struct TestLinalgTransforms
       llvm::cl::desc(
           "Test a fused pass that forwards linalg.copy to vector.transfer"),
       llvm::cl::init(false)};
+  Option<bool> testGenericToVectorPattern{
+      *this, "test-contraction-to-vector-patterns",
+      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
+                     "in vector.contract form"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -300,6 +305,16 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
   applyPatternsAndFoldGreedily(funcOp, forwardPattern);
 }
 
+static void applyContractionToVectorPatterns(FuncOp funcOp) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
+                  LinalgVectorizationPattern<MatmulOp>,
+                  LinalgVectorizationPattern<MatvecOp>,
+                  LinalgVectorizationPattern<DotOp>,
+                  LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -323,6 +338,8 @@ void TestLinalgTransforms::runOnFunction() {
                                        testMatmulToVectorPatterns2dTiling);
   if (testVectorTransferForwardingPatterns)
     return applyVectorTransferForwardingPatterns(getFunction());
+  if (testGenericToVectorPattern)
+    return applyContractionToVectorPatterns(getFunction());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list