[Mlir-commits] [mlir] 6e557bc - [mlir][spirv] Add Vector to SPIR-V conversion pass

Thomas Raoux llvmlistbot at llvm.org
Tue Oct 6 11:53:40 PDT 2020


Author: Thomas Raoux
Date: 2020-10-06T11:53:23-07:00
New Revision: 6e557bc40507cbc5e331179b26f7ae5fe9624294

URL: https://github.com/llvm/llvm-project/commit/6e557bc40507cbc5e331179b26f7ae5fe9624294
DIFF: https://github.com/llvm/llvm-project/commit/6e557bc40507cbc5e331179b26f7ae5fe9624294.diff

LOG: [mlir][spirv] Add Vector to SPIR-V conversion pass

Add a conversion pass from the Vector dialect to the SPIR-V dialect, along with
simple conversion patterns for vector.broadcast, vector.insert and
vector.extract.

Differential Revision: https://reviews.llvm.org/D88761

Added: 
    mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
    mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
    mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index b04498598b29..b4418bb2e0ac 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -29,6 +29,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
 
 namespace mlir {
 

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 547b952b60b4..36618384bb39 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -381,4 +381,22 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
   let dependentDialects = ["ROCDL::ROCDLDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
+  let summary = "Lower the operations from the vector dialect into the SPIR-V "
+                "dialect";
+  let description = [{
+    Converts `vector.broadcast` of a scalar and `vector.extract` /
+    `vector.insert` of a scalar element, on vectors that are valid SPIR-V
+    composites, into `spv.CompositeConstruct`, `spv.CompositeExtract` and
+    `spv.CompositeInsert` respectively. This is a full conversion: the pass
+    fails if any op remains illegal for the SPIR-V target.
+  }];
+  let constructor = "mlir::createConvertVectorToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES

diff  --git a/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
new file mode 100644
index 000000000000..de664df83e83
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
@@ -0,0 +1,29 @@
+//=- ConvertVectorToSPIRV.h - Vector Ops to SPIR-V patterns ------*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns for lowering Vector Ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H
+#define MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+
+/// Appends to `patterns` the conversion patterns that lower Vector dialect ops
+/// to SPIR-V ops; the patterns are constructed with `typeConverter`.
+void populateVectorToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H

diff  --git a/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
new file mode 100644
index 000000000000..7d4c7c1fb025
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
@@ -0,0 +1,25 @@
+//=- ConvertVectorToSPIRVPass.h - Pass converting Vector to SPIRV -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides a pass to convert Vector ops to SPIR-V ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
+#define MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+/// Creates a module-level pass that converts Vector ops to SPIR-V ops.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
index d6e66a6ee1a7..c3a867977b3e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -161,6 +161,11 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
   let results = (outs
     SPV_Composite:$result
   );
+
+  let builders = [
+    OpBuilder<[{OpBuilder &builder, OperationState &state, Value object,
+                Value composite, ArrayRef<int32_t> indices}]>
+  ];
 }
 
 #endif // SPIRV_COMPOSITE_OPS

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index fe2af07b2a6a..dbb9ed699798 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,3 +19,4 @@ add_subdirectory(StandardToSPIRV)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
+add_subdirectory(VectorToSPIRV)

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000..a6e73002de25
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRVectorToSPIRV
+  VectorToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+  intrinsics_gen
+
+  LINK_LIBS PUBLIC
+  MLIRSPIRV
+  MLIRVector
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
new file mode 100644
index 000000000000..05949fb59910
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -0,0 +1,119 @@
+//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate SPIRV operations for Vector
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers vector.broadcast of a scalar to spv.CompositeConstruct.
+struct VectorBroadcastConvert final
+    : public SPIRVOpLowering<vector::BroadcastOp> {
+  using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    bool isScalarSource = !op.source().getType().isa<VectorType>();
+    if (!isScalarSource || !spirv::CompositeType::isValid(vecType))
+      return failure();
+    vector::BroadcastOp::Adaptor adaptor(operands);
+    SmallVector<Value, 4> elements(vecType.getNumElements(), adaptor.source());
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, vecType,
+                                                             elements);
+    return success();
+  }
+};
+
+/// Lowers vector.extract of a scalar element to spv.CompositeExtract.
+struct VectorExtractOpConvert final
+    : public SPIRVOpLowering<vector::ExtractOp> {
+  using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    bool extractsScalar = !op.getType().isa<VectorType>();
+    if (!extractsScalar || !spirv::CompositeType::isValid(op.getVectorType()))
+      return failure();
+    int32_t index = op.position().begin()->cast<IntegerAttr>().getInt();
+    vector::ExtractOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(op, adaptor.vector(),
+                                                           index);
+    return success();
+  }
+};
+
+/// Lowers vector.insert of a scalar element to spv.CompositeInsert.
+struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
+  using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (insertOp.getSourceType().isa<VectorType>() ||
+        !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
+      return failure();
+    vector::InsertOp::Adaptor adaptor(operands);
+    int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+        insertOp, adaptor.source(), adaptor.dest(), id);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateVectorToSPIRVPatterns(
+    MLIRContext *context, SPIRVTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
+                  VectorInsertOpConvert>(context, typeConverter);
+}
+
+namespace {
+/// Module pass that lowers Vector ops via populateVectorToSPIRVPatterns.
+struct LowerVectorToSPIRVPass
+    : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void LowerVectorToSPIRVPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  ModuleOp moduleOp = getOperation();
+
+  // Derive the conversion target and type converter from the target env.
+  auto targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
+  std::unique_ptr<ConversionTarget> target =
+      spirv::SPIRVConversionTarget::get(targetEnv);
+  SPIRVTypeConverter typeConverter(targetEnv);
+  OwningRewritePatternList patterns;
+  populateVectorToSPIRVPatterns(ctx, typeConverter, patterns);
+
+  // Builtin module/function ops must stay legal for the full conversion.
+  target->addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp>();
+  if (failed(applyFullConversion(moduleOp, *target, patterns)))
+    return signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToSPIRVPass() {
+  return std::make_unique<LowerVectorToSPIRVPass>();
+}

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ad25ecb427a6..c17490c05e6b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1410,6 +1410,13 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
 // spv.CompositeInsert
 //===----------------------------------------------------------------------===//
 
+void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
+                                     Value object, Value composite,
+                                     ArrayRef<int32_t> indices) {
+  build(builder, state, /*result=*/composite.getType(), object, composite,
+        builder.getI32ArrayAttr(indices));
+}
+
 static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
                                           OperationState &state) {
   SmallVector<OpAsmParser::OperandType, 2> operands;

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
new file mode 100644
index 000000000000..34f1ef52c237
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv %s -o - | FileCheck %s
+
+// Broadcasting a scalar to a vector lowers to spv.CompositeConstruct.
+// CHECK-LABEL: broadcast
+//  CHECK-SAME: %[[A:.*]]: f32
+//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
+func @broadcast(%arg0 : f32) {
+  %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+  %1 = vector.broadcast %arg0 : f32 to vector<2xf32>
+  spv.Return
+}
+
+// -----
+
+// Scalar vector.extract / vector.insert lower to spv.CompositeExtract / Insert.
+// CHECK-LABEL: extract_insert
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+//       CHECK:   spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32>
+func @extract_insert(%arg0 : vector<4xf32>) {
+  %0 = vector.extract %arg0[1] : vector<4xf32>
+  %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
+  spv.Return
+}


        


More information about the Mlir-commits mailing list