[Mlir-commits] [mlir] 8568921 - [mlir][spirv] Convert `ub.poison` to `spirv.undef`
Ivan Butygin
llvmlistbot at llvm.org
Mon Jul 24 15:30:12 PDT 2023
Author: Ivan Butygin
Date: 2023-07-25T00:23:09+02:00
New Revision: 8568921d43b1dc6e273e89397d273aeba375a513
URL: https://github.com/llvm/llvm-project/commit/8568921d43b1dc6e273e89397d273aeba375a513
DIFF: https://github.com/llvm/llvm-project/commit/8568921d43b1dc6e273e89397d273aeba375a513.diff
LOG: [mlir][spirv] Convert `ub.poison` to `spirv.undef`
SPIR-V has no poison value, so `ub.poison` is lowered to `spirv.Undef` instead; this is sound because undef is a valid refinement of poison.
Differential Revision: https://reviews.llvm.org/D156163
Added:
mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt
mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ab1d38603017f0..7cef25a6f4c7f1 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -58,6 +58,7 @@
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f3a5a9edab4e59..39517a67915cc4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1054,6 +1054,18 @@ def UBToLLVMConversionPass : Pass<"convert-ub-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// UBToSPIRV
+//===----------------------------------------------------------------------===//
+
+def UBToSPIRVConversionPass : Pass<"convert-ub-to-spirv"> {
+  let summary = "Convert UB dialect to SPIR-V dialect";
+  let description = [{
+    This pass converts supported UB ops to SPIR-V dialect ops.
+
+    SPIR-V has no poison value, so a `ub.poison` producing a scalar integer,
+    index, or floating-point value is lowered to `spirv.Undef` of the
+    converted type. Other result types (e.g. vectors) are currently left
+    unconverted.
+  }];
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
//===----------------------------------------------------------------------===//
// VectorToGPU
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
new file mode 100644
index 00000000000000..3843f2707a520d
--- /dev/null
+++ b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
@@ -0,0 +1,29 @@
+//===- UBToSPIRV.h - UB to SPIR-V dialect conversion ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_UBTOSPIRV_UBTOSPIRV_H
+#define MLIR_CONVERSION_UBTOSPIRV_UBTOSPIRV_H
+
+#include <memory>
+
+namespace mlir {
+
+// Forward declarations; the generated pass declarations below presumably
+// refer to `Pass` (TODO confirm against Passes.h.inc).
+class SPIRVTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_UBTOSPIRVCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace ub {
+/// Appends to `patterns` the patterns that lower UB dialect ops to SPIR-V.
+///
+/// Currently this lowers `ub.poison` with a scalar integer, index, or
+/// floating-point result to `spirv.Undef`, using `converter` to compute the
+/// SPIR-V result type. Other result types are reported as match failures
+/// and left untouched.
+void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+} // namespace ub
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_UBTOSPIRV_UBTOSPIRV_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 40ef939023607d..938254816ccd33 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -48,6 +48,7 @@ add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
+add_subdirectory(UBToSPIRV)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt
new file mode 100644
index 00000000000000..a7a26b0a8c606a
--- /dev/null
+++ b/mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Conversion library providing the UB -> SPIR-V patterns and the
+# `convert-ub-to-spirv` pass implemented in UBToSPIRV.cpp.
+add_mlir_conversion_library(MLIRUBToSPIRV
+ UBToSPIRV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/UBToSPIRV
+
+ # UBToSPIRV.cpp includes the generated "mlir/Conversion/Passes.h.inc";
+ # presumably this target produces it, so it has to be built first.
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ # Presumably these provide, respectively, the SPIR-V type converter and
+ # conversion target, the SPIR-V dialect ops, and the UB dialect ops used
+ # by UBToSPIRV.cpp (verify against the library definitions).
+ LINK_LIBS PUBLIC
+ MLIRSPIRVConversion
+ MLIRSPIRVDialect
+ MLIRUBDialect
+ )
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
new file mode 100644
index 00000000000000..001b7fefb175df
--- /dev/null
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -0,0 +1,84 @@
+//===- UBToSPIRV.cpp - UB to SPIR-V dialect conversion --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_UBTOSPIRVCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Lowers `ub.poison` to `spirv.Undef`.
+///
+/// SPIR-V has no notion of poison, so the op is replaced by an undefined
+/// value of the converted result type. Only scalar integer, index, and
+/// floating-point results are handled; for any other type the pattern
+/// reports a match failure and leaves the op in place.
+struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp poisonOp, OpAdaptor /*adaptor*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = poisonOp.getType();
+
+    // Anything that is not a scalar int/index/float (e.g. a vector) is not
+    // supported yet.
+    if (!srcType.isIntOrIndexOrFloat()) {
+      return rewriter.notifyMatchFailure(poisonOp, [&](Diagnostic &diag) {
+        diag << "unsupported type " << srcType;
+      });
+    }
+
+    // The SPIR-V type converter decides the target type, e.g. the integer
+    // width used for `index`.
+    Type dstType = getTypeConverter()->convertType(srcType);
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(poisonOp, [&](Diagnostic &diag) {
+        diag << "failed to convert result type " << srcType;
+      });
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::UndefOp>(poisonOp, dstType);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Applies the UB-to-SPIR-V patterns to the anchor operation.
+///
+/// The SPIR-V target environment is obtained through
+/// `spirv::lookupTargetEnvOrDefault` and drives both the conversion target
+/// (legality) and the type converter.
+struct UBToSPIRVConversionPass final
+    : impl::UBToSPIRVConversionPassBase<UBToSPIRVConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    MLIRContext *context = &getContext();
+
+    spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(root);
+    std::unique_ptr<SPIRVConversionTarget> conversionTarget =
+        SPIRVConversionTarget::get(targetEnv);
+
+    SPIRVConversionOptions conversionOptions;
+    SPIRVTypeConverter converter(targetEnv, conversionOptions);
+
+    RewritePatternSet patternSet(context);
+    ub::populateUBToSPIRVConversionPatterns(converter, patternSet);
+
+    // Partial conversion: ops without an applicable pattern (such as vector
+    // `ub.poison`) may remain in the IR.
+    LogicalResult result = applyPartialConversion(
+        root, *conversionTarget, std::move(patternSet));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::ub::populateUBToSPIRVConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  // `ub.poison` is currently the only UB op with a SPIR-V lowering.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<PoisonOpLowering>(converter, ctx);
+}
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
new file mode 100644
index 00000000000000..771b53ad123b92
--- /dev/null
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -split-input-file -convert-ub-to-spirv -verify-diagnostics %s | FileCheck %s
+
+// Target environment read by the pass through `spirv.target_env`. Its
+// capability list includes Int16 and Float64, which the i16 and f64 cases
+// below presumably rely on.
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+
+// Scalar poison values are expected to become `spirv.Undef` of the
+// converted type.
+// CHECK-LABEL: @check_poison
+func.func @check_poison() {
+// `index` is expected to map to i32 here (presumably the default index width
+// of the SPIR-V type converter; confirm if conversion options change).
+// CHECK: {{.*}} = spirv.Undef : i32
+ %0 = ub.poison : index
+// CHECK: {{.*}} = spirv.Undef : i16
+ %1 = ub.poison : i16
+// CHECK: {{.*}} = spirv.Undef : f64
+ %2 = ub.poison : f64
+// TODO: vector is not covered yet
+// The pattern only accepts scalar int/index/float results, so vector poison
+// is expected to stay unconverted.
+// CHECK: {{.*}} = ub.poison : vector<4xf32>
+ %3 = ub.poison : vector<4xf32>
+ return
+}
+
+}
More information about the Mlir-commits mailing list