[Mlir-commits] [mlir] a213ece - [mlir] [VectorOps, LinAlg] Remove direct LLVM lowering for vector.broadcast
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 13 11:43:01 PDT 2020
Author: aartbik
Date: 2020-03-13T11:42:51-07:00
New Revision: a213ece30bdb8b604ea0933edbd9c2ca77b9631f
URL: https://github.com/llvm/llvm-project/commit/a213ece30bdb8b604ea0933edbd9c2ca77b9631f
DIFF: https://github.com/llvm/llvm-project/commit/a213ece30bdb8b604ea0933edbd9c2ca77b9631f.diff
LOG: [mlir] [VectorOps,LinAlg] Remove direct LLVM lowering for vector.broadcast
Summary:
The direct lowering of vector.outerproduct into LLVM has been replaced by
progressive lowering into elementary vector ops. This also required a
small refactoring of an llvm.mlir test that used a direct vector.outerproduct
operation (just to define a matmul).
Reviewers: nicolasvasilache, andydavis1, rriddle
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76143
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Dialect/Linalg/llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a41a9d257417..828a964afa06 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -817,59 +817,6 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
}
};
-// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up
-class VectorOuterProductOpConversion : public ConvertToLLVMPattern {
-public:
- explicit VectorOuterProductOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(),
- context, typeConverter) {}
-
- PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
- auto *ctx = op->getContext();
- auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>();
- auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
- auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
- auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
- auto llvmArrayOfVectType = typeConverter.convertType(
- cast<vector::OuterProductOp>(op).getResult().getType());
- Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
- Value a = adaptor.lhs(), b = adaptor.rhs();
- Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
- SmallVector<Value, 8> lhs, accs;
- lhs.reserve(rankLHS);
- accs.reserve(rankLHS);
- for (unsigned d = 0, e = rankLHS; d < e; ++d) {
- // shufflevector explicitly requires i32.
- auto attr = rewriter.getI32IntegerAttr(d);
- SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
- auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
- Value aD = nullptr, accD = nullptr;
- // 1. Broadcast the element a[d] into vector aD.
- aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
- // 2. If acc is present, extract 1-d vector acc[d] into accD.
- if (acc)
- accD = rewriter.create<LLVM::ExtractValueOp>(
- loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
- // 3. Compute aD outer b (plus accD, if relevant).
- Value aOuterbD =
- accD
- ? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
- : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
- // 4. Insert as value `d` in the descriptor.
- desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
- desc, aOuterbD,
- rewriter.getI64ArrayAttr(d));
- }
- rewriter.replaceOp(op, desc);
- return matchSuccess();
- }
-};
-
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorTypeCastOpConversion(MLIRContext *context,
@@ -1160,8 +1107,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorShuffleOpConversion, VectorExtractElementOpConversion,
VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(ctx, converter);
+ VectorTypeCastOpConversion, VectorPrintOpConversion>(
+ ctx, converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 82ec950584d0..290e1a2fe4d7 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefix=LLVM-LOOPS
func @range(%arg0: index) {
%c0 = constant 0 : index
@@ -172,14 +172,22 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32, 1 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-// LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
-// LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-
+// LLVM-LOOPS-SAME: %[[A:.*0]]: memref<?x?xvector<4xf32>>,
+// LLVM-LOOPS-SAME: %[[B:.*1]]: memref<?x?xvector<4xf32>>,
+// LLVM-LOOPS-SAME: %[[C:.*2]]: memref<?x?xvector<4x4xf32>>)
+// LLVM-LOOPS: %[[C0:.*]] = constant 0 : index
+// LLVM-LOOPS: %[[C1:.*]] = constant 1 : index
+// LLVM-LOOPS: %[[T0:.*]] = dim %[[A]], 0 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T1:.*]] = dim %[[A]], 1 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T2:.*]] = dim %[[B]], 1 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: loop.for %[[I:.*]] = %[[C0]] to %[[T0]] step %[[C1]] {
+// LLVM-LOOPS: loop.for %[[J:.*]] = %[[C0]] to %[[T2]] step %[[C1]] {
+// LLVM-LOOPS: loop.for %[[K:.*]] = %[[C0]] to %[[T1]] step %[[C1]] {
+// LLVM-LOOPS: %[[T3:.*]] = load %[[A]][%[[I]], %[[K]]] : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T4:.*]] = load %[[B]][%[[K]], %[[J]]] : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T5:.*]] = load %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
+// LLVM-LOOPS: %[[T6:.*]] = vector.outerproduct %3, %4, %5 : vector<4xf32>, vector<4xf32>
+// LLVM-LOOPS: store %[[T6]], %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
#indexed_matmul_trait = {
args_in = 2,
More information about the Mlir-commits mailing list