[Mlir-commits] [mlir] 6e557bc - [mlir][spirv] Add Vector to SPIR-V conversion pass
Thomas Raoux
llvmlistbot at llvm.org
Tue Oct 6 11:53:40 PDT 2020
Author: Thomas Raoux
Date: 2020-10-06T11:53:23-07:00
New Revision: 6e557bc40507cbc5e331179b26f7ae5fe9624294
URL: https://github.com/llvm/llvm-project/commit/6e557bc40507cbc5e331179b26f7ae5fe9624294
DIFF: https://github.com/llvm/llvm-project/commit/6e557bc40507cbc5e331179b26f7ae5fe9624294.diff
LOG: [mlir][spirv] Add Vector to SPIR-V conversion pass
Add a conversion pass from the Vector dialect to the SPIR-V dialect, along with
simple conversion patterns for vector.broadcast, vector.insert, and vector.extract.
Differential Revision: https://reviews.llvm.org/D88761
Added:
mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index b04498598b29..b4418bb2e0ac 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -29,6 +29,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 547b952b60b4..36618384bb39 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -381,4 +381,15 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
let dependentDialects = ["ROCDL::ROCDLDialect"];
}
+//===----------------------------------------------------------------------===//
+// VectorToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
+ let summary = "Lower the operations from the vector dialect into the SPIR-V "
+ "dialect";
+ let constructor = "mlir::createConvertVectorToSPIRVPass()";
+ let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
new file mode 100644
index 000000000000..de664df83e83
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
@@ -0,0 +1,29 @@
+//=- ConvertVectorToSPIRV.h - Vector Ops to SPIR-V dialect patterns - C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns for lowering Vector Ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H_
+#define MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+
+/// Appends to `patterns` the conversion patterns that lower vector.broadcast,
+/// vector.extract and vector.insert ops to SPIR-V dialect ops.
+void populateVectorToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H_
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
new file mode 100644
index 000000000000..7d4c7c1fb025
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
@@ -0,0 +1,25 @@
+//=- ConvertVectorToSPIRVPass.h - Pass converting Vector to SPIRV -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides a pass to convert Vector ops to SPIR-V ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
+#define MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+/// Creates a module-level pass that converts Vector dialect ops to SPIR-V
+/// dialect ops.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
index d6e66a6ee1a7..c3a867977b3e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -161,6 +161,11 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
let results = (outs
SPV_Composite:$result
);
+
+ let builders = [
+ OpBuilder<[{OpBuilder &builder, OperationState &state, Value object,
+ Value composite, ArrayRef<int32_t> indices}]>
+ ];
}
#endif // SPIRV_COMPOSITE_OPS
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index fe2af07b2a6a..dbb9ed699798 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,3 +19,4 @@ add_subdirectory(StandardToSPIRV)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)
+add_subdirectory(VectorToSPIRV)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000..a6e73002de25
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRVectorToSPIRV
+ VectorToSPIRV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToSPIRV
+
+ DEPENDS
+ MLIRConversionPassIncGen
+ intrinsics_gen
+
+ LINK_LIBS PUBLIC
+ MLIRSPIRV
+ MLIRVector
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
new file mode 100644
index 000000000000..05949fb59910
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -0,0 +1,119 @@
+//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate SPIRV operations for Vector
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers a `vector.broadcast` with a non-vector source to
+/// `spirv::CompositeConstructOp`, passing the converted source once per
+/// element of the result vector.
+struct VectorBroadcastConvert final
+    : public SPIRVOpLowering<vector::BroadcastOp> {
+  using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Vector-to-vector broadcasts are not handled by this pattern.
+    if (broadcastOp.source().getType().isa<VectorType>())
+      return failure();
+
+    // The result vector must be a valid SPIR-V composite type.
+    VectorType resultType = broadcastOp.getVectorType();
+    if (!spirv::CompositeType::isValid(resultType))
+      return failure();
+
+    // Repeat the converted source operand for every element of the result.
+    vector::BroadcastOp::Adaptor adaptor(operands);
+    SmallVector<Value, 4> constituents(resultType.getNumElements(),
+                                       adaptor.source());
+    Value composite = rewriter.create<spirv::CompositeConstructOp>(
+        broadcastOp.getLoc(), resultType, constituents);
+    rewriter.replaceOp(broadcastOp, composite);
+    return success();
+  }
+};
+
+/// Lowers a `vector.extract` whose result is not a vector to
+/// `spirv::CompositeExtractOp`.
+struct VectorExtractOpConvert final
+    : public SPIRVOpLowering<vector::ExtractOp> {
+  using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Extracting a sub-vector is not handled by this pattern.
+    if (extractOp.getType().isa<VectorType>())
+      return failure();
+
+    // The source vector must be a valid SPIR-V composite type.
+    if (!spirv::CompositeType::isValid(extractOp.getVectorType()))
+      return failure();
+
+    // Only the first entry of the position attribute is read.
+    Attribute firstPosition = *extractOp.position().begin();
+    int32_t index = firstPosition.cast<IntegerAttr>().getInt();
+
+    vector::ExtractOp::Adaptor adaptor(operands);
+    Value extracted = rewriter.create<spirv::CompositeExtractOp>(
+        extractOp.getLoc(), adaptor.vector(), index);
+    rewriter.replaceOp(extractOp, extracted);
+    return success();
+  }
+};
+
+/// Lowers a `vector.insert` whose inserted value is not a vector to
+/// `spirv::CompositeInsertOp`.
+struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
+  using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Inserting a sub-vector is not handled by this pattern.
+    if (insertOp.getSourceType().isa<VectorType>())
+      return failure();
+
+    // The destination vector must be a valid SPIR-V composite type.
+    if (!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
+      return failure();
+
+    // Only the first entry of the position attribute is read.
+    Attribute firstPosition = *insertOp.position().begin();
+    int32_t index = firstPosition.cast<IntegerAttr>().getInt();
+
+    vector::InsertOp::Adaptor adaptor(operands);
+    Value inserted = rewriter.create<spirv::CompositeInsertOp>(
+        insertOp.getLoc(), adaptor.source(), adaptor.dest(), index);
+    rewriter.replaceOp(insertOp, inserted);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the broadcast, extract and insert lowering patterns, in that
+/// order, each constructed from `context` and `typeConverter`.
+void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
+                                         SPIRVTypeConverter &typeConverter,
+                                         OwningRewritePatternList &patterns) {
+  patterns.insert<VectorBroadcastConvert>(context, typeConverter);
+  patterns.insert<VectorExtractOpConvert>(context, typeConverter);
+  patterns.insert<VectorInsertOpConvert>(context, typeConverter);
+}
+
+namespace {
+/// Module pass that applies the Vector to SPIR-V conversion patterns as a full
+/// conversion.
+struct LowerVectorToSPIRVPass
+    : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // The target environment attribute is used to build both the conversion
+    // target and the type converter.
+    auto targetEnv = spirv::lookupTargetEnvOrDefault(module);
+    std::unique_ptr<ConversionTarget> target =
+        spirv::SPIRVConversionTarget::get(targetEnv);
+    SPIRVTypeConverter typeConverter(targetEnv);
+
+    OwningRewritePatternList patterns;
+    populateVectorToSPIRVPatterns(&getContext(), typeConverter, patterns);
+
+    // Module and function ops are explicitly marked legal for the conversion.
+    target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
+    target->addLegalOp<FuncOp>();
+
+    if (failed(applyFullConversion(module, *target, patterns)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory for the Vector to SPIR-V conversion pass.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToSPIRVPass() {
+  return std::make_unique<LowerVectorToSPIRVPass>();
+}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ad25ecb427a6..c17490c05e6b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1410,6 +1410,13 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
// spv.CompositeInsert
//===----------------------------------------------------------------------===//
+/// Convenience builder taking the indices as plain integers: they are wrapped
+/// in an i32 array attribute and the result type is taken from `composite`.
+void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
+                                     Value object, Value composite,
+                                     ArrayRef<int32_t> indices) {
+  ArrayAttr indicesAttr = builder.getI32ArrayAttr(indices);
+  Type resultType = composite.getType();
+  build(builder, state, resultType, object, composite, indicesAttr);
+}
+
static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
OperationState &state) {
SmallVector<OpAsmParser::OperandType, 2> operands;
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
new file mode 100644
index 000000000000..34f1ef52c237
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv %s -o - | FileCheck %s
+
+// CHECK-LABEL: broadcast
+// CHECK-SAME: %[[A:.*]]: f32
+// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+// CHECK: spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
+func @broadcast(%arg0 : f32) {
+ %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+ %1 = vector.broadcast %arg0 : f32 to vector<2xf32>
+ spv.Return
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>
+// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+// CHECK: spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32>
+func @extract_insert(%arg0 : vector<4xf32>) {
+ %0 = vector.extract %arg0[1] : vector<4xf32>
+ %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
+ spv.Return
+}
More information about the Mlir-commits
mailing list