[Mlir-commits] [mlir] 4db65e2 - [mlir][vector] Reorder elementwise(transpose)

Lei Zhang llvmlistbot at llvm.org
Fri Apr 15 06:06:21 PDT 2022


Author: Lei Zhang
Date: 2022-04-15T09:05:35-04:00
New Revision: 4db65e279b96e2af9a4ea2c1e2acc40a64de2a0e

URL: https://github.com/llvm/llvm-project/commit/4db65e279b96e2af9a4ea2c1e2acc40a64de2a0e
DIFF: https://github.com/llvm/llvm-project/commit/4db65e279b96e2af9a4ea2c1e2acc40a64de2a0e.diff

LOG: [mlir][vector] Reorder elementwise(transpose)

Similar to the existing pattern for reordering cast(transpose),
this moves transposes after elementwise ops, so that transposes end up
next to each other and to contraction ops, increasing the chance of
embedding the transposition inside the contraction op. Cast ops are
just special instances of elementwise ops, so this pattern replaces
the cast-specific one.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123596

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 3312095a62e61..50b89c1a9374f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1048,43 +1049,96 @@ struct ReorderCastOpsOnBroadcast
   }
 };
 
-/// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
-/// contraction ops closer, which kicks in CombineContractTranspose pattern when
-/// casting ops are around these operations.
-/// Ex:
+/// Reorders elementwise(transpose) to transpose(elementwise). This makes
+/// transpose ops and contraction ops closer, which kicks in
+/// CombineContractTranspose pattern when elementwise ops are between these
+/// operations. Every operand must be either a vector.transpose (all using the
+/// same permutation) or a constant; vector constants get an inverse transpose
+/// and scalar constants are forwarded unchanged. Ex:
 /// ```
-///   %0 = vector.transpose %arg0, [2, 0, 1]
-///     : vector<32x16x8xi8> to vector<8x32x16xi8>
-///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %r = arith.addf %at, %bt : vector<2x4xf32>
 /// ```
 /// Gets converted to:
 /// ```
-///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
-///   %1 = vector.transpose %arg0, [2, 0, 1]
-///     : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// %0 = arith.addf %a, %b : vector<4x2xf32>
+/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
 /// ```
-struct ReorderCastOpsOnTranspose
-    : public OpInterfaceRewritePattern<CastOpInterface> {
-
-  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(CastOpInterface op,
+struct ReorderElementwiseOpsOnTranspose final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (op->getNumOperands() != 1)
+    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
       return failure();
-    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
-    if (!transpOp)
+
+    // Make sure all operands are transpose/constant ops and collect their
+    // transposition maps.
+    SmallVector<ArrayAttr, 4> transposeMaps;
+    transposeMaps.reserve(op->getNumOperands());
+    // Record the initial type before transposition. We'll use its shape later.
+    // Any type will do here as we will check all transpose maps are the same.
+    VectorType srcType;
+    for (Value operand : op->getOperands()) {
+      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+      if (transposeOp) {
+        transposeMaps.push_back(transposeOp.getTransp());
+        srcType = transposeOp.getVectorType();
+      } else if (!matchPattern(operand, m_Constant())) {
+        return failure();
+      }
+    }
+    if (transposeMaps.empty())
       return failure();
+    // This is an elementwise op, so all transposed operands should have the
+    // same type. We need to additionally check that all transposes use the
+    // same map.
+    if (!llvm::is_splat(transposeMaps))
+      return rewriter.notifyMatchFailure(op, "different transpose map");
+
+    SmallVector<Value, 4> srcValues;
+    srcValues.reserve(op->getNumOperands());
+
+    // If there are constant operands, we need to insert inverse transposes for
+    // them. Calculate the inverse order first.
+    auto order = extractVector<unsigned>(transposeMaps.front());
+    SmallVector<int64_t> invOrder(order.size());
+    for (int i = 0, e = order.size(); i < e; ++i)
+      invOrder[order[i]] = i;
+
+    for (Value operand : op->getOperands()) {
+      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+      if (transposeOp) {
+        srcValues.push_back(transposeOp.getVector());
+        continue;
+      }
+      // This is a constant. The Elementwise trait allows scalar operands next
+      // to vector ones; a scalar applies to every element regardless of
+      // layout, so forward it as-is instead of casting it to VectorType.
+      auto constType = operand.getType().dyn_cast<VectorType>();
+      if (!constType) {
+        srcValues.push_back(operand);
+        continue;
+      }
+      // A vector constant gets a reverse transpose so that it lines up with
+      // the untransposed operands.
+      auto vectorType =
+          VectorType::get(srcType.getShape(), constType.getElementType());
+      srcValues.push_back(rewriter.create<vector::TransposeOp>(
+          operand.getLoc(), vectorType, operand,
+          rewriter.getI64ArrayAttr(invOrder)));
+    }
 
-    auto castResTy = transpOp.getVectorType();
-    castResTy = VectorType::get(castResTy.getShape(),
-                                getElementTypeOrSelf(op->getResult(0)));
-    auto *castOp =
-        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                        transpOp.getVector(), castResTy, op->getAttrs());
+    auto vectorType = VectorType::get(
+        srcType.getShape(),
+        op->getResultTypes()[0].cast<VectorType>().getElementType());
+    Operation *elementwiseOp =
+        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
+                        vectorType, op->getAttrs());
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
-        op, op->getResult(0).getType(), castOp->getResult(0),
-        transpOp.getTransp());
+        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
+        transposeMaps.front());
     return success();
   }
 };
@@ -2647,7 +2701,7 @@ void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
                CombineContractTranspose, ReorderCastOpsOnBroadcast,
-               ReorderCastOpsOnTranspose>(patterns.getContext());
+               ReorderElementwiseOpsOnTranspose>(patterns.getContext());
 }
 
 void mlir::vector::

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 1167f1eba7f86..b17771a25ea5c 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -120,3 +120,81 @@ func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
   %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
   return %r : vector<2x4xi32>
 }
+
+//===----------------------------------------------------------------------===//
+// Reorder elementwise ops and vector.transpose ops.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_same_type
+//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
+//       CHECK:   return %[[T]]
+
+func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.addf %at, %bt : vector<2x4xf32>
+  return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
+//  CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+//       CHECK:   return %[[T]]
+func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+  %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
+  return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
+//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
+//       CHECK:   return %[[T]]
+func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
+  return %r : vector<2x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_splat_constant
+//  CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
+//       CHECK:   %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
+//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+//       CHECK:   return %[[T]] : vector<6x4x2x3xf32>
+
+func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
+  %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+  %r = arith.addf %at, %b : vector<6x4x2x3xf32>
+  return %r : vector<6x4x2x3xf32>
+}
+
+// -----
+
+// Operands transposed with different permutations must not be reordered.
+// CHECK-LABEL: func @transpose_elementwise_diff_map
+//       CHECK:   vector.transpose
+//       CHECK:   vector.transpose
+//       CHECK:   arith.addf
+func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+  %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
+  %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
+  return %r : vector<6x4x2x3xf32>
+}


        


More information about the Mlir-commits mailing list