[Mlir-commits] [mlir] 9621c1e - [mlir][linalg] Fix vectorization bug in vector transfer indexing map calculation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 3 12:16:55 PDT 2021


Author: thomasraoux
Date: 2021-05-03T12:16:38-07:00
New Revision: 9621c1ef56c568ffe2db903af2f61137a4453430

URL: https://github.com/llvm/llvm-project/commit/9621c1ef56c568ffe2db903af2f61137a4453430
DIFF: https://github.com/llvm/llvm-project/commit/9621c1ef56c568ffe2db903af2f61137a4453430.diff

LOG: [mlir][linalg] Fix vectorization bug in vector transfer indexing map calculation

The current implementation had a bug as it was relying on the target vector
dimension sizes to calculate where to insert broadcasts. If several dimensions
have the same size, we may insert the broadcast on the wrong dimension. The
correct broadcast cannot be inferred from the types of the source and
destination vectors.

Instead, when we want to extend transfer ops, we calculate an "inverse" map of the
projected permutation and insert broadcasts in place of the projected dimensions.

Differential Revision: https://reviews.llvm.org/D101738

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 3d83bcd0aa57a..45c0ccaa09280 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1383,18 +1383,6 @@ def Vector_TransferReadOp :
       "ArrayAttr":$inBounds)>
   ];
 
-  let extraClassDeclaration = [{
-    /// Return a new `result` map with `0` inserted in the proper positions so 
-    /// that vector.transfer_read `result` produces a vector of same element 
-    /// type as `vt` and shape `targetShape.
-    /// Assume that `map` is a permutation map for a vector.transfer_read op, 
-    /// `vt` the vector type produced by the vector.transfer_read and 
-    /// `targetShape` is the desired `targetShape` for a broadcast version of 
-    /// `vt`.
-    static AffineMap insertBroadcasts(AffineMap map, VectorType vt,
-                                      ArrayRef<int64_t> targetShape);
-  }];
-
   let hasFolder = 1;
 }
 

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 31682847bb1a8..e9295650761f0 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -404,6 +404,48 @@ AffineMap removeDuplicateExprs(AffineMap map);
 /// ```
 AffineMap inversePermutation(AffineMap map);
 
+/// Return the reverse map of a projected permutation where the projected
+/// dimensions are transformed into 0s.
+///
+/// Prerequisites: `map` must be a projected permutation.
+///
+/// Example 1:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2, d3) -> (d2, d0)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0, d1) -> (d1, 0, d0, 0)>
+/// ```
+///
+/// Example 2:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0, d1) -> (d0, 0, 0, d1)>
+/// ```
+///
+/// Example 3:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2, d3) -> (d2)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0) -> (0, 0, d0, 0)>
+/// ```
+AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
+
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
 /// potentially empty maps. Assumes each of the underlying map has 0 symbols.
 /// The resulting map has a number of dims equal to the max of `maps`' dims and

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c4afed4d71f25..aee5233f4930e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -493,15 +493,18 @@ LogicalResult vectorizeAsLinalgGeneric(
       bvm.map(shapedArg, loaded);
       continue;
     }
-    AffineMap map = inversePermutation(
-        reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
-    VectorType vectorType = VectorType::get(map.compose(shapedType.getShape()),
-                                            shapedType.getElementType());
+    AffineMap map;
+    VectorType vectorType;
     if (broadcastToMaximalCommonShape) {
-      map = vector::TransferReadOp::insertBroadcasts(map, vectorType,
-                                                     commonVectorShape);
+      map = inverseAndBroadcastProjectedPermuation(
+          linalgOp.getIndexingMap(bbarg.getArgNumber()));
       vectorType =
-          VectorType::get(commonVectorShape, vectorType.getElementType());
+          VectorType::get(commonVectorShape, shapedType.getElementType());
+    } else {
+      map = inversePermutation(
+          reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
+      vectorType = VectorType::get(map.compose(shapedType.getShape()),
+                                   shapedType.getElementType());
     }
     Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 2958088b258df..5190482f576b9 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2253,29 +2253,6 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 // TransferReadOp
 //===----------------------------------------------------------------------===//
 
-AffineMap TransferReadOp::insertBroadcasts(AffineMap map, VectorType vt,
-                                           ArrayRef<int64_t> targetShape) {
-  unsigned targetRank = targetShape.size();
-  assert(vt.getShape().size() <= targetRank && "mismatching ranks");
-  if (vt.getShape().size() == targetRank)
-    return map;
-  MLIRContext *ctx = map.getContext();
-  SmallVector<AffineExpr> exprs;
-  exprs.reserve(targetRank);
-  for (unsigned idx = 0, vtidx = 0; idx < targetRank; ++idx) {
-    // If shapes match, just keep the existing indexing and advance ranks.
-    if (vtidx < vt.getShape().size() &&
-        vt.getShape()[vtidx] == targetShape[idx]) {
-      exprs.push_back(map.getResult(vtidx));
-      ++vtidx;
-      continue;
-    }
-    // Otherwise insert a broadcast.
-    exprs.push_back(getAffineConstantExpr(0, ctx));
-  }
-  return AffineMap::get(map.getNumDims(), /*numSymbols=*/0, exprs, ctx);
-}
-
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {

diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 380cdc65e4151..3c453de10fa75 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -664,6 +664,19 @@ AffineMap mlir::inversePermutation(AffineMap map) {
   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
 }
 
+AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
+  assert(map.isProjectedPermutation());
+  MLIRContext *context = map.getContext();
+  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
+  // Default every output to 0 so dims projected out of `map` become broadcasts.
+  SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
+  for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
+    // Result i of the original map reads input dim getDimPosition(i); invert it.
+    exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
+  }
+  return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
+}
+
 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
   unsigned numResults = 0, numDims = 0, numSymbols = 0;
   for (auto m : maps)

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3eafc5acd6f5e..f4caad70ae197 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -381,6 +381,43 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0, 0, 0, 0)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, 0, d0, 0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1, 0, d0, 0)>
+//     CHECK: func @generic_vectorize_broadcast_transpose
+// CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:   %[[CF:.*]] = constant 0.000000e+00 : f32
+//     CHECK:   %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP0]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
+//     CHECK:   %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP1]]} : memref<4xf32>, vector<4x4x4x4xf32>
+//     CHECK:   %[[V2:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP2]]} : memref<4xf32>, vector<4x4x4x4xf32>
+//     CHECK:   %[[V3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP3]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
+//     CHECK:   %[[SUB:.*]] = subf %[[V0]], %[[V1]] : vector<4x4x4x4xf32>
+//     CHECK:   %[[ADD0:.*]] = addf %[[V2]], %[[SUB]] : vector<4x4x4x4xf32>
+//     CHECK:   %[[ADD1:.*]] = addf %[[V3]], %[[ADD0]] : vector<4x4x4x4xf32>
+//     CHECK: vector.transfer_write %[[ADD1]], {{.*}} : vector<4x4x4x4xf32>, memref<4x4x4x4xf32>
+func @generic_vectorize_broadcast_transpose(
+  %A: memref<4xf32>, %B: memref<4x4xf32>, %C: memref<4x4x4x4xf32>) {
+  linalg.generic {
+  indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0)>,
+                   affine_map<(d0, d1, d2, d3) -> (d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d2, d0)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  ins(%B, %A, %A, %B: memref<4x4xf32>, memref<4xf32>, memref<4xf32>, memref<4x4xf32>)
+  outs(%C : memref<4x4x4x4xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+    %s = subf %arg0, %arg1 : f32
+    %a = addf %arg2, %s : f32
+    %b = addf %arg3, %a : f32
+    linalg.yield %b : f32
+  }
+  return
+}
+
+// -----
+
 // Test different input maps.
 #matmul_trait = {
   indexing_maps = [


        


More information about the Mlir-commits mailing list