[Mlir-commits] [mlir] 1538bd5 - [mlir][Vector] Add patterns to reorder elementwise ops and broadcast/transpose ops.
Hanhan Wang
llvmlistbot at llvm.org
Mon Mar 7 12:52:41 PST 2022
Author: Hanhan Wang
Date: 2022-03-07T12:52:12-08:00
New Revision: 1538bd518cd236f4321695e9c5f0dd24601db366
URL: https://github.com/llvm/llvm-project/commit/1538bd518cd236f4321695e9c5f0dd24601db366
DIFF: https://github.com/llvm/llvm-project/commit/1538bd518cd236f4321695e9c5f0dd24601db366.diff
LOG: [mlir][Vector] Add patterns to reorder elementwise ops and broadcast/transpose ops.
In quantized computation, there are casting ops around computation ops.
Reorder the ops to make reduce-to-contract actually work.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D120760
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 4421b55fb40d7..ee6429d0abd77 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1009,6 +1009,84 @@ struct CombineContractBroadcast
}
};
+/// Reorders cast(broadcast) to broadcast(cast). This brings broadcast ops and
+/// contraction ops closer together, which lets the CombineContractBroadcast
+/// pattern kick in when casting ops surround these operations.
+/// Ex:
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
+/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
+/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
+/// ```
+struct ReorderCastOpsOnBroadcast
+ : public OpInterfaceRewritePattern<CastOpInterface> {
+ using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(CastOpInterface op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumOperands() != 1)
+ return failure();
+ auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ Type castResTy = getElementTypeOrSelf(op->getResult(0));
+ if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
+ castResTy = VectorType::get(vecTy.getShape(), castResTy);
+ OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
+ castResTy, op->getAttrs());
+ auto castOp = rewriter.createOperation(state);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ op, op->getResult(0).getType(), castOp->getResult(0));
+ return success();
+ }
+};
+
+/// Reorders cast(transpose) to transpose(cast). This brings transpose ops and
+/// contraction ops closer together, which lets the CombineContractTranspose
+/// pattern kick in when casting ops surround these operations.
+/// Ex:
+/// ```
+/// %0 = vector.transpose %arg0, [2, 0, 1]
+/// : vector<32x16x8xi8> to vector<8x32x16xi8>
+/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
+/// %1 = vector.transpose %0, [2, 0, 1]
+/// : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// ```
+struct ReorderCastOpsOnTranspose
+ : public OpInterfaceRewritePattern<CastOpInterface> {
+
+ using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(CastOpInterface op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumOperands() != 1)
+ return failure();
+ auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
+ if (!transpOp)
+ return failure();
+
+ auto castResTy = transpOp.getVectorType();
+ castResTy = VectorType::get(castResTy.getShape(),
+ getElementTypeOrSelf(op->getResult(0)));
+ OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
+ castResTy, op->getAttrs());
+ auto castOp = rewriter.createOperation(state);
+ rewriter.replaceOpWithNewOp<vector::TransposeOp>(
+ op, op->getResult(0).getType(), castOp->getResult(0),
+ transpOp.getTransp());
+ return success();
+ }
+};
+
} // namespace
/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -2585,7 +2663,8 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
- CombineContractTranspose>(patterns.getContext());
+ CombineContractTranspose, ReorderCastOpsOnBroadcast,
+ ReorderCastOpsOnTranspose>(patterns.getContext());
}
void mlir::vector::
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 85389b2a767d6..1167f1eba7f86 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -85,3 +85,38 @@ func @contract_broadcast(
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
return %1 : vector<8x32xf32>
}
+
+//===----------------------------------------------------------------------===//
+// Reorder casting ops and vector ops. All casting ops follow an almost
+// identical pattern, so only the arith.extsi op is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+ // CHECK: vector.broadcast %[[EXT]] : vector<4xi32> to vector<2x4xi32>
+ %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+ %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+ return %r : vector<2x4xi32>
+}
+
+// -----
+
+func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+ // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+ %b = vector.broadcast %a : i8 to vector<2x4xi8>
+ %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+ return %r : vector<2x4xi32>
+}
+
+// -----
+
+func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
+ // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+ %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
+ %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+ return %r : vector<2x4xi32>
+}
More information about the Mlir-commits
mailing list