[Mlir-commits] [mlir] 1694765 - [mlir][linalg] Extend linalg vectorization to support non-identity input maps

llvmlistbot at llvm.org
Thu Mar 18 12:33:00 PDT 2021


Author: thomasraoux
Date: 2021-03-18T12:32:35-07:00
New Revision: 16947650d5ca602d63d5cd64e68bb0bb0f3674b7

URL: https://github.com/llvm/llvm-project/commit/16947650d5ca602d63d5cd64e68bb0bb0f3674b7
DIFF: https://github.com/llvm/llvm-project/commit/16947650d5ca602d63d5cd64e68bb0bb0f3674b7.diff

LOG: [mlir][linalg] Extend linalg vectorization to support non-identity input maps

This propagates the affine map to the generated transfer_read op when the
input indexing map is not a minor identity map.
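
For example, with this change an input whose indexing map is
affine_map<(d0, d1, d2, d3) -> (d1, d0)> is read directly into the output
vector shape: the permutation map transposes the two memref dimensions and
broadcasts the two missing ones with 0s. A sketch of the resulting IR, based
on the new test case at the end of this patch (%c0 and %cst stand for the
zero index and padding value the vectorizer builds):

    %0 = vector.transfer_read %A[%c0, %c0], %cst
        {permutation_map = affine_map<(d0, d1) -> (d1, d0, 0, 0)>}
        : memref<14x7xf32>, vector<7x14x8x16xf32>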

Differential Revision: https://reviews.llvm.org/D98523

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 880e7f385724..dab32d2e2727 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -87,11 +87,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
 /// Build a vector.transfer_read from `source` at indices set to all `0`.
 /// If source has rank zero, build a memref.load.
 /// Return the produced value.
-static Value buildVectorRead(OpBuilder &builder, Value source) {
+static Value buildVectorRead(OpBuilder &builder, Value source,
+                             VectorType vectorType, AffineMap map) {
   edsc::ScopedContext scope(builder);
   auto shapedType = source.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+  if (vectorType) {
     SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+    if (map)
+      return vector_transfer_read(vectorType, source, indices, map);
     return vector_transfer_read(vectorType, source, indices);
   }
   return memref_load(source);
@@ -238,6 +241,51 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
                              builder.createOperation(state)};
 }
 
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  for (Operation &op : r.front()) {
+    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
+        llvm::any_of(op.getResultTypes(),
+                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+      return false;
+  }
+  return true;
+}
+
+// Return true if the op is an element-wise linalg op.
+static bool isElementwise(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return false;
+  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+    return false;
+  // TODO: relax the restrictions on indexing map.
+  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
+    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+      return false;
+  }
+  if (linalgOp->getNumRegions() != 1)
+    return false;
+  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
+}
+
+// Calculate the map to apply to transfer_read to convert the input shape into
+// the output shape.
+static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
+  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
+  MLIRContext *context = linalgMap.getContext();
+  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
+  SmallVector<AffineExpr, 4> exprs(linalgMap.getNumInputs(), zero);
+  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
+    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
+  }
+  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
+                        context);
+}
+
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 ///   1. The region for the linalg op is created if necessary.
@@ -282,7 +330,19 @@ LogicalResult vectorizeAsLinalgGeneric(
   SmallVector<AffineMap> indexings;
   for (auto bbarg : block->getArguments()) {
     Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
-    Value vectorRead = buildVectorRead(builder, vectorArg);
+    AffineMap map;
+    VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
+    if (isElementwise(linalgOp) &&
+        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
+      // Output permutations are currently not supported.
+      assert(linalgOp.getNumOutputs() > 0 &&
+             linalgOp.getOutputIndexingMap(0).isIdentity());
+      ArrayRef<int64_t> outputShape =
+          linalgOp.getOutputShapedType(0).getShape();
+      vectorType = VectorType::get(outputShape, vectorType.getElementType());
+      map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
+    }
+    Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
     bvm.map(bbarg, vectorRead);
@@ -316,44 +376,6 @@ LogicalResult vectorizeAsLinalgGeneric(
   return success();
 }
 
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  for (Operation &op : r.front()) {
-    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
-          OpTrait::hasElementwiseMappableTraits(&op)) ||
-        llvm::any_of(op.getResultTypes(),
-                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
-      return false;
-  }
-  return true;
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
-    return false;
-  // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
-      return false;
-  }
-  // Currently bound the input indexing map to minor identity as other
-  // permutations might require adding transpose ops to convert the vector read
-  // to the right shape.
-  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
-    if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
-      return false;
-  }
-  if (linalgOp->getNumRegions() != 1)
-    return false;
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
 static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
                                           SmallVectorImpl<Value> &newResults) {
   assert(isaContractionOpInterface(linalgOp) &&

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 6ca28ba681ef..08bf7628e8c0 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2294,8 +2294,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
 
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 2> elidedAttrs;
-  if (op.permutation_map() ==
-      getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
+  if (op.permutation_map().isMinorIdentity())
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideMasked = true;
   if (auto maybeMasked = op.masked()) {

diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9de80e96d451..98ca45bbb6f6 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -106,8 +106,9 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
 }
 
 bool AffineMap::isMinorIdentity() const {
-  return *this ==
-         getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
+  return getNumDims() >= getNumResults() &&
+         *this ==
+             getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
 /// Returns true if this affine map is a minor identity up to broadcasted

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c43bf07d775d..74ff4367724e 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -341,6 +341,42 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
 
 // -----
 
+// Test different input maps.
+#matmul_trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (0, d1, 0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
+//       CHECK: func @vectorization_transpose
+//       CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14x8x16xf32>
+//       CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP1]]} : memref<16x14xf32>, vector<7x14x8x16xf32>
+//       CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP2]]} : memref<16x14x7x8xf32>, vector<7x14x8x16xf32>
+//       CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+//       CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+//       CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32>
+func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
+                         %C: memref<16x14x7x8xf32>, %D: memref<7x14x8x16xf32>) {
+  linalg.generic #matmul_trait
+    ins(%A, %B, %C : memref<14x7xf32>, memref<16x14xf32>, memref<16x14x7x8xf32>)
+   outs(%D : memref<7x14x8x16xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32, %d: f32) :
+      %e = addf %a, %b: f32
+      %f = addf %e, %c: f32
+      linalg.yield %f : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
 //  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
