[Mlir-commits] [mlir] [mlir][spirv] Add a generic "convert-to-spirv" pass (PR #95942)
llvmlistbot at llvm.org
Tue Jun 18 08:19:55 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Angel Zhang (angelz913)
<details>
<summary>Changes</summary>
This PR implements an MVP version of an MLIR lowering pipeline to SPIR-V. The goal of this pipeline is to improve test coverage of SPIR-V compilation upstream and to make it possible to write simple kernels by hand. The dialects supported in this version are `arith`, `vector` (only 1-D vectors of size 2, 3, 4, 8, or 16), `scf`, `ub`, `index`, and `func`. New test cases for the pass are also included in this PR.
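As a rough illustration of the `vector` restriction above, the following C++ helper encodes the rule exactly as this description states it. `isSupportedVectorShape` is a hypothetical name, not an MLIR API; in the real pass the decision is made by `SPIRVTypeConverter` and the target environment. Note that in SPIR-V itself, 8- and 16-component vectors additionally require the `Vector16` capability, and the tests also use `vector<1xT>` operands, which this simplified rule does not cover.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): encodes the MVP restriction stated in
// the PR description, i.e. rank-1 vectors with 2, 3, 4, 8, or 16 elements.
bool isSupportedVectorShape(const std::vector<int64_t> &shape) {
  if (shape.size() != 1)
    return false; // n-D vectors are not unrolled by this pass yet
  switch (shape[0]) {
  case 2:
  case 3:
  case 4:
  case 8:
  case 16:
    return true;
  default:
    return false; // e.g. vector<5xf32> has an unsupported size
  }
}
```

For example, `{4}` (a `vector<4xf32>` shape) is accepted, while `{5}` and `{2, 4}` are rejected.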
**Relevant links**
- [Open MLIR Meeting - YouTube Video](https://www.youtube.com/watch?v=csWPOQfgLMo)
- [Discussion on LLVM Forum](https://discourse.llvm.org/t/open-mlir-meeting-12-14-2023-discussion-on-improving-handling-of-unit-dimensions-in-the-vector-dialect/75683)
**Future plans**
- Add conversion patterns for other dialects, such as `gpu` and `tensor`.
- Include vector transformations that unroll n-D vectors to 1-D and handle vectors with unsupported sizes.
- Support multiple return values. SPIR-V does not support them directly, because a `spirv.func` can return at most one value. One possible approach is to wrap the return values in a `spirv.struct`.
- Add a conversion for `scf.parallel`.
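To make the unrolling item above a little more concrete, here is a hypothetical C++ sketch that lists which 1-D slices an n-D vector would be split into. `enumerateOneDSlices` is an illustrative name, not an existing MLIR utility. It models only the index bookkeeping (each slice has the length of the innermost dimension); an actual transformation would operate on MLIR ops and is not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch (not an MLIR API): for an n-D vector shape, list the
// leading-dimension indices that select each 1-D slice. Every slice has the
// length of the innermost dimension, e.g. vector<2x3x4xf32> -> 6 x vector<4xf32>.
std::vector<std::vector<int64_t>>
enumerateOneDSlices(const std::vector<int64_t> &shape) {
  std::vector<std::vector<int64_t>> slices;
  if (shape.empty())
    return slices; // a 0-D vector has no innermost dimension to slice along
  for (std::size_t i = 0; i + 1 < shape.size(); ++i)
    if (shape[i] <= 0)
      return slices; // an empty leading dimension yields no slices
  std::vector<int64_t> idx(shape.size() - 1, 0);
  while (true) {
    slices.push_back(idx);
    // Odometer-style increment: the last leading dimension varies fastest.
    int64_t d = static_cast<int64_t>(idx.size()) - 1;
    for (; d >= 0; --d) {
      if (++idx[d] < shape[d])
        break;
      idx[d] = 0;
    }
    if (d < 0)
      break; // every index combination has been emitted
  }
  return slices;
}
```

A shape that is already 1-D, such as `{8}`, yields a single slice with an empty index list, i.e. the vector itself.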
---
Patch is 43.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95942.diff
13 Files Affected:
- (added) mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRV.h (+26)
- (modified) mlir/include/mlir/Conversion/Passes.h (+1)
- (modified) mlir/include/mlir/Conversion/Passes.td (+14-1)
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
- (added) mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt (+22)
- (added) mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp (+86)
- (added) mlir/test/Conversion/ConvertToSPIRV/arith.mlir (+276)
- (added) mlir/test/Conversion/ConvertToSPIRV/combined.mlir (+47)
- (added) mlir/test/Conversion/ConvertToSPIRV/index.mlir (+104)
- (added) mlir/test/Conversion/ConvertToSPIRV/scf.mlir (+47)
- (added) mlir/test/Conversion/ConvertToSPIRV/simple.mlir (+15)
- (added) mlir/test/Conversion/ConvertToSPIRV/ub.mlir (+9)
- (added) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+439)
``````````diff
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRV.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRV.h
new file mode 100644
index 0000000000000..b539f5059b871
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRV.h
@@ -0,0 +1,26 @@
+//===- ConvertToSPIRV.h - Conversion to SPIR-V pass ---*- C++ -*-===================//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRV_H
+#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRV_H
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mlir {
+
+/// Create a pass that performs dialect conversion to SPIR-V for all dialects
+std::unique_ptr<OperationPass<>> createConvertToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7700299b3a4f3..81daec6fd0138 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -30,6 +30,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
+#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRV.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index db67d6a5ff128..5747b8d38001d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -11,7 +11,6 @@
include "mlir/Pass/PassBase.td"
-
//===----------------------------------------------------------------------===//
// ToLLVM
//===----------------------------------------------------------------------===//
@@ -31,6 +30,20 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// ToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
+ let summary = "Convert to SPIR-V";
+ let description = [{
+ This is a generic pass to convert to SPIR-V.
+ }];
+
+ let constructor = "mlir::createConvertToSPIRVPass()";
+ let options = [];
+}
+
//===----------------------------------------------------------------------===//
// AffineToStandard
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0a03a2e133db1..e107738a4c50c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSCF)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(ConvertToLLVM)
+add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000000000..9a93301f09b48
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -0,0 +1,22 @@
+set(LLVM_OPTIONAL_SOURCES
+ ConvertToSPIRVPass.cpp
+)
+
+add_mlir_conversion_library(MLIRConvertToSPIRVPass
+ ConvertToSPIRVPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPass
+ MLIRRewrite
+ MLIRSPIRVConversion
+ MLIRSPIRVDialect
+ MLIRSupport
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
new file mode 100644
index 0000000000000..b755e7d7ffe13
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -0,0 +1,86 @@
+//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion
+//--------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
+#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRV.h"
+#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+#define DEBUG_TYPE "convert-to-spirv"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// A pass to perform the SPIR-V conversion.
+class ConvertToSPIRVPass
+ : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+
+public:
+ using impl::ConvertToSPIRVPassBase<
+ ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+
+ // Register dependent dialects for the current pass
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<spirv::SPIRVDialect>();
+ }
+
+ void runOnOperation() final {
+ MLIRContext *context = &getContext();
+ Operation *op = getOperation();
+
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ SPIRVTypeConverter typeConverter(targetAttr);
+
+ RewritePatternSet patterns(context);
+ ScfToSPIRVContext scfToSPIRVContext;
+
+ // Populate patterns.
+ arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+ arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+ populateFuncToSPIRVPatterns(typeConverter, patterns);
+ index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+ populateVectorToSPIRVPatterns(typeConverter, patterns);
+ populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
+ ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+
+ std::unique_ptr<ConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
+
+ FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ if (failed(applyPartialConversion(op, *target, frozenPatterns))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<>> mlir::createConvertToSPIRVPass() {
+ return std::make_unique<ConvertToSPIRVPass>();
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
new file mode 100644
index 0000000000000..823e9b4a6c3ab
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -0,0 +1,276 @@
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// arithmetic ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%lhs: i32, %rhs: i32) {
+ // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
+ %0 = arith.addi %lhs, %rhs: i32
+ // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
+ %1 = arith.subi %lhs, %rhs: i32
+ // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
+ %2 = arith.muli %lhs, %rhs: i32
+ // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
+ %3 = arith.divsi %lhs, %rhs: i32
+ // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
+ %4 = arith.divui %lhs, %rhs: i32
+ // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
+ %5 = arith.remui %lhs, %rhs: i32
+ return
+}
+
+// CHECK-LABEL: @int32_scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
+ // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
+ // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
+ // CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
+ // CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
+ // CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
+ // CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+ %0 = arith.remsi %lhs, %rhs: i32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std bit ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_scalar
+func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spirv.BitwiseAnd
+ %0 = arith.andi %arg0, %arg1 : i32
+ // CHECK: spirv.BitwiseOr
+ %1 = arith.ori %arg0, %arg1 : i32
+ // CHECK: spirv.BitwiseXor
+ %2 = arith.xori %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @bitwise_vector
+func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+ // CHECK: spirv.BitwiseAnd
+ %0 = arith.andi %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.BitwiseOr
+ %1 = arith.ori %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.BitwiseXor
+ %2 = arith.xori %arg0, %arg1 : vector<4xi32>
+ return
+}
+
+// CHECK-LABEL: @logical_scalar
+func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
+ // CHECK: spirv.LogicalAnd
+ %0 = arith.andi %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalOr
+ %1 = arith.ori %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalNotEqual
+ %2 = arith.xori %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @logical_vector
+func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+ // CHECK: spirv.LogicalAnd
+ %0 = arith.andi %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalOr
+ %1 = arith.ori %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %2 = arith.xori %arg0, %arg1 : vector<4xi1>
+ return
+}
+
+// CHECK-LABEL: @shift_scalar
+func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spirv.ShiftLeftLogical
+ %0 = arith.shli %arg0, %arg1 : i32
+ // CHECK: spirv.ShiftRightArithmetic
+ %1 = arith.shrsi %arg0, %arg1 : i32
+ // CHECK: spirv.ShiftRightLogical
+ %2 = arith.shrui %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @shift_vector
+func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+ // CHECK: spirv.ShiftLeftLogical
+ %0 = arith.shli %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.ShiftRightArithmetic
+ %1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.ShiftRightLogical
+ %2 = arith.shrui %arg0, %arg1 : vector<4xi32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpf
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpf
+func.func @cmpf(%arg0 : f32, %arg1 : f32) {
+ // CHECK: spirv.FOrdEqual
+ %1 = arith.cmpf oeq, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdGreaterThan
+ %2 = arith.cmpf ogt, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdGreaterThanEqual
+ %3 = arith.cmpf oge, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdLessThan
+ %4 = arith.cmpf olt, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdLessThanEqual
+ %5 = arith.cmpf ole, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdNotEqual
+ %6 = arith.cmpf one, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordEqual
+ %7 = arith.cmpf ueq, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordGreaterThan
+ %8 = arith.cmpf ugt, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordGreaterThanEqual
+ %9 = arith.cmpf uge, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordLessThan
+ %10 = arith.cmpf ult, %arg0, %arg1 : f32
+ // CHECK: FUnordLessThanEqual
+ %11 = arith.cmpf ule, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordNotEqual
+ %12 = arith.cmpf une, %arg0, %arg1 : f32
+ return
+}
+
+// CHECK-LABEL: @vec1cmpf
+func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
+ // CHECK: spirv.FOrdGreaterThan
+ %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
+ // CHECK: spirv.FUnordLessThan
+ %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpi
+func.func @cmpi(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spirv.IEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : i32
+ // CHECK: spirv.INotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : i32
+ // CHECK: spirv.SLessThan
+ %2 = arith.cmpi slt, %arg0, %arg1 : i32
+ // CHECK: spirv.SLessThanEqual
+ %3 = arith.cmpi sle, %arg0, %arg1 : i32
+ // CHECK: spirv.SGreaterThan
+ %4 = arith.cmpi sgt, %arg0, %arg1 : i32
+ // CHECK: spirv.SGreaterThanEqual
+ %5 = arith.cmpi sge, %arg0, %arg1 : i32
+ // CHECK: spirv.ULessThan
+ %6 = arith.cmpi ult, %arg0, %arg1 : i32
+ // CHECK: spirv.ULessThanEqual
+ %7 = arith.cmpi ule, %arg0, %arg1 : i32
+ // CHECK: spirv.UGreaterThan
+ %8 = arith.cmpi ugt, %arg0, %arg1 : i32
+ // CHECK: spirv.UGreaterThanEqual
+ %9 = arith.cmpi uge, %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @indexcmpi
+func.func @indexcmpi(%arg0 : index, %arg1 : index) {
+ // CHECK: spirv.IEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : index
+ // CHECK: spirv.INotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : index
+ // CHECK: spirv.SLessThan
+ %2 = arith.cmpi slt, %arg0, %arg1 : index
+ // CHECK: spirv.SLessThanEqual
+ %3 = arith.cmpi sle, %arg0, %arg1 : index
+ // CHECK: spirv.SGreaterThan
+ %4 = arith.cmpi sgt, %arg0, %arg1 : index
+ // CHECK: spirv.SGreaterThanEqual
+ %5 = arith.cmpi sge, %arg0, %arg1 : index
+ // CHECK: spirv.ULessThan
+ %6 = arith.cmpi ult, %arg0, %arg1 : index
+ // CHECK: spirv.ULessThanEqual
+ %7 = arith.cmpi ule, %arg0, %arg1 : index
+ // CHECK: spirv.UGreaterThan
+ %8 = arith.cmpi ugt, %arg0, %arg1 : index
+ // CHECK: spirv.UGreaterThanEqual
+ %9 = arith.cmpi uge, %arg0, %arg1 : index
+ return
+}
+
+// CHECK-LABEL: @vec1cmpi
+func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
+ // CHECK: spirv.ULessThan
+ %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
+ // CHECK: spirv.SGreaterThan
+ %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
+ return
+}
+
+// CHECK-LABEL: @boolcmpi_equality
+func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
+ // CHECK: spirv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @boolcmpi_unsigned
+func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : i1
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @vec1boolcmpi_equality
+func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+ // CHECK: spirv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
+ return
+}
+
+// CHECK-LABEL: @vec1boolcmpi_unsigned
+func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
+ return
+}
+
+// CHECK-LABEL: @vecboolcmpi_equality
+func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+ // CHECK: spirv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
+ return
+}
+
+// CHECK-LABEL: @vecboolcmpi_unsigned
+func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
+ return
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
new file mode 100644
index 0000000000000..9e908465cb142
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+
+// CHECK-LABEL: @combined
+// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
+// CHECK: %[[C1_F32:.*]] = spirv.Constant 1.000000e+00 : f32
+// CHECK: %[[C0_I32:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C4_I32:.*]] = spirv.Constant 4 : i32
+// CHECK: %[[C0_I32_0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C4_I32_0:.*]] = spirv.Constant 4 : i32
+// CHECK: %[[C1_I32:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[VEC:.*]] = spirv.Constant dense<1.000000e+00> : vector<4xf32>
+// CHECK: %[[VARIABLE:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
+// CHECK: spirv.mlir.loop {
+// CHECK: spirv.Branch ^[[HEADER:.*]](%[[C0_I32_0]], %[[C0_F32]] : i32, f32)
+// CHECK: ^[[HEADER]](%[[INDVAR_0:.*]]: i32, %[[INDVAR_1:.*]]: f32):
+// CHECK: %[[SLESSTHAN:.*]] = spirv.SLessThan %[[INDVAR_0]], %[[C4_I32_0]] : i32
+// CHECK: spirv.BranchConditional %[[SLESSTHAN]], ^[[BODY:.*]], ^[[MERGE:.*]]
+// CHECK: ^[[BODY]]:
+// CHECK: %[[FADD:.*]] = spirv.FAdd %[[INDVAR_1]], %[[C1_F32]] : f32
+// CHECK: %[[INSERT:.*]] = spirv.CompositeInsert %[[FADD]], %[[VEC]][0 : i32] : f32 into vector<4xf32>
+// CHECK: spirv.Store "Function" %[[VARIABLE]], %[[FADD]] : f32
+// CHECK: %[[IADD:.*]] = spirv.IAdd %[[INDVAR_0]], %[[C1_I32]] : i32
+// CHECK: spirv.Branch ^[[HEADER]](%[[IADD]], %[[FADD]] : i32, f32)
+// CHECK: ^[[MERGE]]:
+// CHECK: spirv.mlir.merge
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = spirv.Load "Function" %[[VARIABLE]] : f32
+// CHECK: %[[UNDEF:.*]] = spirv.Undef : f32
+// CHECK: spirv.ReturnValue %[[UNDEF]] : f32
+func.func @combined() -> f32 {
+ %c0_f32 = arith.constant 0.0 : f32
+ %c1_f32 = arith.constant 1.0 : f32
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %lb = index.casts %c0_i32 : i32 to index
+ %ub = index.casts %c4_i32 : i32 to index
+ %step = arith.constant 1 : index
+ %buf = vector.broadcast %c1_f32 : f32 to vector<4xf32>
+ scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %c0_f32) -> f32 {
+ %t = vector.extract %buf[0] : f32 from vector<4xf32>
+ %sum_next = arith.addf %sum_iter, %t : f32
+ vector.insert %sum_next, %buf[0] : f32 into vector<4xf32>
+ scf.yield %sum_next : f32
+ }
+ %ret = ub.poison : f32
+ return %ret : f32
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
new file mode 100644
index 0000000000000..5ad2217add0d4
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
+
+// CHECK-LABEL: @basic
+func.func @basic(%a: index, %b: index) {
+ // CHECK: spirv.IAdd
+ %0 = index.add %a, %b
+ // CHECK: spirv.ISub
+ %1 = index.sub %a, %b
+ // CHECK: spirv.IMul
+ %2 = index.mul %a, %b
+ // CHECK: spirv.SDiv
+ %3 = index.divs %a, %b
+ // CHECK: spirv.UDiv
+ %4 = index.divu %a, %b
+ // CHECK: spirv.SRem
+ %5 = index.rems %a, %b
+ // CHECK: spirv.UMod
+ %6 = index.remu %a, %b
+ // CHECK: spirv.GL.SMax
+ %7 = index.maxs %a, %b
+ // CHECK: spirv.GL.UMax
+ %8 = index.maxu...
[truncated]
``````````
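The `@int32_scalar_srem` test in `arith.mlir` checks that `arith.remsi` is lowered through absolute values, an unsigned modulo, and a sign fix-up selected by comparing `lhs` with its absolute value. The standalone C++ sketch below mirrors that op sequence on scalars. The function names are hypothetical, and the sketch assumes `rhs != 0` and `lhs != INT32_MIN` to stay clear of the absolute-value overflow edge case.

```cpp
#include <cassert>
#include <cstdint>

// Absolute value computed in unsigned arithmetic (stands in for spirv.GL.SAbs)
// so that the sketch itself never hits signed-overflow UB.
static uint32_t absBits(int32_t x) {
  uint32_t u = static_cast<uint32_t>(x);
  return x < 0 ? 0u - u : u;
}

// Hypothetical scalar model of the op sequence checked in @int32_scalar_srem.
// Assumes rhs != 0 and lhs != INT32_MIN.
int32_t sremViaUMod(int32_t lhs, int32_t rhs) {
  uint32_t labs = absBits(lhs);                  // spirv.GL.SAbs %lhs
  uint32_t rabs = absBits(rhs);                  // spirv.GL.SAbs %rhs
  uint32_t rem = labs % rabs;                    // spirv.UMod %labs, %rabs
  bool pos = static_cast<uint32_t>(lhs) == labs; // spirv.IEqual %lhs, %labs
  uint32_t neg = 0u - rem;                       // spirv.SNegate %rem
  return static_cast<int32_t>(pos ? rem : neg);  // spirv.Select %pos, ...
}
```

Because the sign is taken from `lhs` only, the result follows the dividend, matching C++ `%` and `arith.remsi`: `sremViaUMod(-7, 3)` is `-1` and `sremViaUMod(7, -3)` is `1`.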
</details>
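The `boolcmpi_unsigned` tests expect two `spirv.Select` ops before each unsigned comparison, meaning the `i1` operands are first turned into integers and then compared. A scalar C++ sketch of that idea follows. The function names are hypothetical, and the integer width and the 1/0 select constants are assumptions, since the CHECK lines do not spell them out.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of the boolean unsigned comparisons in
// @boolcmpi_unsigned: each i1 operand is first turned into an integer with a
// select (assumed 1 for true, 0 for false), then compared as unsigned.
static uint32_t selectToInt(bool b) { return b ? 1u : 0u; } // spirv.Select

bool boolUge(bool a, bool b) {
  return selectToInt(a) >= selectToInt(b); // spirv.UGreaterThanEqual
}

bool boolUlt(bool a, bool b) {
  return selectToInt(a) < selectToInt(b); // spirv.ULessThan
}
```

With `true` mapped to 1 and `false` to 0, `uge` and `ult` on booleans reduce to ordinary unsigned comparisons, which is why no dedicated logical op is needed for them.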
https://github.com/llvm/llvm-project/pull/95942
More information about the Mlir-commits mailing list