[Mlir-commits] [mlir] 1538bd5 - [mlir][Vector] Add patterns to reorder elementwise ops and broadcast/transpose ops.

Hanhan Wang llvmlistbot at llvm.org
Mon Mar 7 12:52:41 PST 2022


Author: Hanhan Wang
Date: 2022-03-07T12:52:12-08:00
New Revision: 1538bd518cd236f4321695e9c5f0dd24601db366

URL: https://github.com/llvm/llvm-project/commit/1538bd518cd236f4321695e9c5f0dd24601db366
DIFF: https://github.com/llvm/llvm-project/commit/1538bd518cd236f4321695e9c5f0dd24601db366.diff

LOG: [mlir][Vector] Add patterns to reorder elementwise ops and broadcast/transpose ops.

In quantized computation, there are casting ops around computation ops.
Reorder the ops to make reduce-to-contract actually work.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D120760

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 4421b55fb40d7..ee6429d0abd77 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1009,6 +1009,84 @@ struct CombineContractBroadcast
   }
 };
 
+/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
+/// contraction ops closer, which lets the CombineContractBroadcast pattern kick
+/// in when casting ops are around these operations.
+/// Ex:
+/// ```
+///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
+///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
+///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
+/// ```
+struct ReorderCastOpsOnBroadcast
+    : public OpInterfaceRewritePattern<CastOpInterface> {
+  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return failure();
+    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    Type castResTy = getElementTypeOrSelf(op->getResult(0));
+    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
+      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+    OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
+                         castResTy, op->getAttrs());
+    auto castOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, op->getResult(0).getType(), castOp->getResult(0));
+    return success();
+  }
+};
+
+/// Reorders cast(transpose) to transpose(cast). This makes transpose ops and
+/// contraction ops closer, which lets the CombineContractTranspose pattern kick
+/// in when casting ops are around these operations.
+/// Ex:
+/// ```
+///   %0 = vector.transpose %arg0, [2, 0, 1]
+///     : vector<32x16x8xi8> to vector<8x32x16xi8>
+///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
+///   %1 = vector.transpose %0, [2, 0, 1]
+///     : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// ```
+struct ReorderCastOpsOnTranspose
+    : public OpInterfaceRewritePattern<CastOpInterface> {
+
+  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return failure();
+    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
+    if (!transpOp)
+      return failure();
+
+    auto castResTy = transpOp.getVectorType();
+    castResTy = VectorType::get(castResTy.getShape(),
+                                getElementTypeOrSelf(op->getResult(0)));
+    OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
+                         castResTy, op->getAttrs());
+    auto castOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
+        op, op->getResult(0).getType(), castOp->getResult(0),
+        transpOp.getTransp());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -2585,7 +2663,8 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractTranspose>(patterns.getContext());
+               CombineContractTranspose, ReorderCastOpsOnBroadcast,
+               ReorderCastOpsOnTranspose>(patterns.getContext());
 }
 
 void mlir::vector::

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 85389b2a767d6..1167f1eba7f86 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -85,3 +85,38 @@ func @contract_broadcast(
     kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
   return %1 : vector<8x32xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// Reorder casting ops and vector ops. The casting ops have almost identical
+// pattern, so only arith.extsi op is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+  %b = vector.broadcast %a : i8 to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
+  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+  %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}


        


More information about the Mlir-commits mailing list