[Mlir-commits] [mlir] 13f15e8 - [mlir][Vector] Fix vectorization of generic ops with transposed outputs

Diego Caballero llvmlistbot at llvm.org
Mon Jun 26 13:24:48 PDT 2023


Author: Diego Caballero
Date: 2023-06-26T20:24:29Z
New Revision: 13f15e8f14cfd0b5795ef0c0164a913fdaa5378c

URL: https://github.com/llvm/llvm-project/commit/13f15e8f14cfd0b5795ef0c0164a913fdaa5378c
DIFF: https://github.com/llvm/llvm-project/commit/13f15e8f14cfd0b5795ef0c0164a913fdaa5378c.diff

LOG: [mlir][Vector] Fix vectorization of generic ops with transposed outputs

This patch fixes a bug in the way we compute the vector type for vector
transfer writes when the value to store needs to be transposed.

Reviewed By: nicolasvasilache, mravishankar

Differential Revision: https://reviews.llvm.org/D153687

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 01cd7183e43c7..3430db2b99c3f 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -78,6 +78,20 @@ class AffineMap {
   static AffineMap getMinorIdentityMap(unsigned dims, unsigned results,
                                        MLIRContext *context);
 
+  /// Returns an identity affine map with `numDims` input dimensions and
+  /// filtered results using `keepDimFilter`. If `keepDimFilter` returns true
+  /// for a dimension, the dimension is kept in the affine map results.
+  /// Otherwise, the dimension is dropped from the results.
+  ///
+  /// Examples (bracketed list = `keepDimFilter` result per dim, ctx omitted):
+  ///   * getFilteredIdentityMap(4, [false, true, false, true])
+  ///       -> affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+  ///   * getFilteredIdentityMap(3, [false, false, true])
+  ///       -> affine_map<(d0, d1, d2) -> (d2)>
+  static AffineMap
+  getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims,
+                         llvm::function_ref<bool(AffineDimExpr)> keepDimFilter);
+
   /// Returns an AffineMap representing a permutation.
   /// The permutation is expressed as a non-empty vector of integers.
   /// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bbcde44f08618..45b35d3e01a6a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -605,8 +605,18 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+
+  // Compute the vector type of the value to store. This type should be an
+  // identity or projection of the canonical vector type without any permutation
+  // applied, given that any permutation in a transfer write happens as part of
+  // the write itself.
+  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
+      opOperandMap.getContext(), opOperandMap.getNumInputs(),
+      [&](AffineDimExpr dimExpr) -> bool {
+        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
+      });
   auto vectorType = state.getCanonicalVecType(
-      getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);
+      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
 
   Operation *write;
   if (vectorType.getRank() > 0) {
@@ -614,13 +624,14 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType);
+    assert(value.getType() == vectorType && "Incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), indices, writeMap);
   } else {
     // 0-d case is still special: do not invert the reindexing writeMap.
     if (!isa<VectorType>(value.getType()))
       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
-    assert(value.getType() == vectorType && "incorrect type");
+    assert(value.getType() == vectorType && "Incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), ValueRange{});
   }

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 4a67010b7a3a0..9cdac964710ca 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -113,6 +113,19 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
   return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
 }
 
+AffineMap AffineMap::getFilteredIdentityMap(
+    MLIRContext *ctx, unsigned numDims,
+    llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) {
+  auto identityMap = getMultiDimIdentityMap(numDims, ctx);
+
+  // Set bit `idx` for every identity result the filter rejects.
+  llvm::SmallBitVector dropDimResults(numDims);
+  for (auto [idx, resultExpr] : llvm::enumerate(identityMap.getResults()))
+    dropDimResults[idx] = !keepDimFilter(resultExpr.cast<AffineDimExpr>());
+
+  return identityMap.dropResults(dropDimResults);
+}
+
 bool AffineMap::isMinorIdentity() const {
   return getNumDims() >= getNumResults() &&
          *this ==

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 130c6bcc11abb..933c9c7864988 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1751,3 +1751,38 @@ transform.sequence failures(propagate) {
 //       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
 //       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
 
+// -----
+
+// Make sure we generate the right transfer writes for multi-output generic ops
+// with different permutation maps.
+
+func.func @multi_output_generic_different_perm_maps(%in0: tensor<4x1xf32>,
+                                                    %out0: tensor<4x1xf32>,
+                                                    %out1: tensor<1x4xf32>) -> (tensor<4x1xf32>, tensor<1x4xf32>) {
+  %13:2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d1, d0)>,
+                                            affine_map<(d0, d1) -> (d1, d0)>,
+                                            affine_map<(d0, d1) -> (d0, d1)> ],
+                          iterator_types = ["parallel", "parallel"]}
+                          ins(%in0 : tensor<4x1xf32>)
+                          outs(%out0, %out1 : tensor<4x1xf32>, tensor<1x4xf32>) {
+  ^bb0(%in: f32, %out: f32, %out_2: f32):
+    %16 = arith.addf %in, %in : f32
+    linalg.yield %16, %16 : f32, f32
+  } -> (tensor<4x1xf32>, tensor<1x4xf32>)
+  return %13#0, %13#1 : tensor<4x1xf32>, tensor<1x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @multi_output_generic_different_perm_maps
+//       CHECK:     %[[VAL_5:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+//       CHECK:     %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_5]] : vector<4x1xf32>
+//       CHECK:     %[[VAL_7:.*]] = vector.transpose %[[VAL_6]], [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+//       CHECK:     %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+//       CHECK:     vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
+//       CHECK:     vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>


        


More information about the Mlir-commits mailing list