[llvm] [mlir] Reland "[mlir][spirv] Add a generic convert-to-spirv pass" (PR #96359)
Angel Zhang via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 21 14:16:29 PDT 2024
https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/96359
This PR relands #95942, which was reverted in #96332 due to link failures. It fixes the issue by updating CMake dependencies. The bazel support, originally introduced in #96334, is also included in this PR.
>From dc0bdbffdcc2cb428969f7cb2a59fc1e956fb8a4 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Fri, 21 Jun 2024 12:31:16 -0400
Subject: [PATCH 1/2] [mlir][spirv] Add a generic `convert-to-spirv` pass
(#95942)
This PR implements an MVP version of an MLIR lowering pipeline to SPIR-V.
The goal of adding this pipeline is to have better test coverage of
SPIR-V compilation upstream, and to enable writing simple kernels by hand.
The dialects supported in this version include `arith`, `vector` (only
1-D vectors of size 2, 3, 4, 8, or 16), `scf`, `ub`, `index`, `func` and
`math`. New test cases for the pass are also included in this PR.
**Relevant links**
- [Open MLIR Meeting - YouTube
Video](https://www.youtube.com/watch?v=csWPOQfgLMo)
- [Discussion on LLVM
Forum](https://discourse.llvm.org/t/open-mlir-meeting-12-14-2023-discussion-on-improving-handling-of-unit-dimensions-in-the-vector-dialect/75683)
**Future plans**
- Add conversion patterns for other dialects, e.g. `gpu`, `tensor`, etc.
- Include vector transformation to unroll vectors to 1-D, and handle
those with unsupported sizes.
- Implement multiple return values. SPIR-V does not support multiple return
values since a `spirv.func` can only return zero or one value. It might
be possible to wrap the return values in a `spirv.struct`.
- Add a conversion for `scf.parallel`.
---
.../ConvertToSPIRV/ConvertToSPIRVPass.h | 22 +
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 12 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/ConvertToSPIRV/CMakeLists.txt | 32 ++
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 71 +++
.../test/Conversion/ConvertToSPIRV/arith.mlir | 218 +++++++++
.../Conversion/ConvertToSPIRV/combined.mlir | 47 ++
.../test/Conversion/ConvertToSPIRV/index.mlir | 63 +++
mlir/test/Conversion/ConvertToSPIRV/scf.mlir | 47 ++
.../Conversion/ConvertToSPIRV/simple.mlir | 15 +
mlir/test/Conversion/ConvertToSPIRV/ub.mlir | 9 +
.../Conversion/ConvertToSPIRV/vector.mlir | 439 ++++++++++++++++++
13 files changed, 977 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
create mode 100644 mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/arith.mlir
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/combined.mlir
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/index.mlir
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/scf.mlir
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/simple.mlir
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/ub.mlir
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector.mlir
diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
new file mode 100644
index 0000000000000..3852782247527
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
@@ -0,0 +1,22 @@
+//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
+#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7700299b3a4f3..8c6f85d461aea 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -30,6 +30,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
+#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2315686839c20..560b088dbe5cd 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// ToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
+ let summary = "Convert to SPIR-V";
+ let description = [{
+ This is a generic pass to convert to SPIR-V.
+ }];
+ let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
//===----------------------------------------------------------------------===//
// AffineToStandard
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0a03a2e133db1..e107738a4c50c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSCF)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(ConvertToLLVM)
+add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000000000..f7b090acf33af
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -0,0 +1,32 @@
+set(LLVM_OPTIONAL_SOURCES
+ ConvertToSPIRVPass.cpp
+)
+
+add_mlir_conversion_library(MLIRConvertToSPIRVPass
+ ConvertToSPIRVPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithToSPIRV
+ MLIRArithTransforms
+ MLIRFuncToSPIRV
+ MLIRIndexToSPIRV
+ MLIRIR
+ MLIRPass
+ MLIRRewrite
+ MLIRSCFToSPIRV
+ MLIRSPIRVConversion
+ MLIRSPIRVDialect
+ MLIRSPIRVTransforms
+ MLIRSupport
+ MLIRTransforms
+ MLIRTransformUtils
+ MLIRUBToSPIRV
+ MLIRVectorToSPIRV
+ MLIRVectorTransforms
+ )
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
new file mode 100644
index 0000000000000..b5be4654bcb25
--- /dev/null
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -0,0 +1,71 @@
+//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
+#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
+#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+#define DEBUG_TYPE "convert-to-spirv"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Pass that rewrites ops to SPIR-V using one partial dialect conversion.
+struct ConvertToSPIRVPass final
+ : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ Operation *op = getOperation();
+
+ spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ SPIRVTypeConverter typeConverter(targetAttr);
+
+ RewritePatternSet patterns(context);
+ ScfToSPIRVContext scfToSPIRVContext;
+
+ // Collect arith expansion patterns plus the per-dialect SPIR-V patterns.
+ arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+ arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+ populateFuncToSPIRVPatterns(typeConverter, patterns);
+ index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+ populateVectorToSPIRVPatterns(typeConverter, patterns);
+ populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
+ ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+
+ std::unique_ptr<ConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
+
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
new file mode 100644
index 0000000000000..a2adc0ad9c7a5
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -0,0 +1,218 @@
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// arithmetic ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%lhs: i32, %rhs: i32) {
+ // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
+ %0 = arith.addi %lhs, %rhs: i32
+ // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
+ %1 = arith.subi %lhs, %rhs: i32
+ // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
+ %2 = arith.muli %lhs, %rhs: i32
+ // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
+ %3 = arith.divsi %lhs, %rhs: i32
+ // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
+ %4 = arith.divui %lhs, %rhs: i32
+ // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
+ %5 = arith.remui %lhs, %rhs: i32
+ return
+}
+
+// CHECK-LABEL: @int32_scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
+ // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
+ // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
+ // CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
+ // CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
+ // CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
+ // CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+ %0 = arith.remsi %lhs, %rhs: i32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith bit ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_scalar
+func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spirv.BitwiseAnd
+ %0 = arith.andi %arg0, %arg1 : i32
+ // CHECK: spirv.BitwiseOr
+ %1 = arith.ori %arg0, %arg1 : i32
+ // CHECK: spirv.BitwiseXor
+ %2 = arith.xori %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @bitwise_vector
+func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+ // CHECK: spirv.BitwiseAnd
+ %0 = arith.andi %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.BitwiseOr
+ %1 = arith.ori %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.BitwiseXor
+ %2 = arith.xori %arg0, %arg1 : vector<4xi32>
+ return
+}
+
+// CHECK-LABEL: @logical_scalar
+func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
+ // CHECK: spirv.LogicalAnd
+ %0 = arith.andi %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalOr
+ %1 = arith.ori %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalNotEqual
+ %2 = arith.xori %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @logical_vector
+func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+ // CHECK: spirv.LogicalAnd
+ %0 = arith.andi %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalOr
+ %1 = arith.ori %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %2 = arith.xori %arg0, %arg1 : vector<4xi1>
+ return
+}
+
+// CHECK-LABEL: @shift_scalar
+func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spirv.ShiftLeftLogical
+ %0 = arith.shli %arg0, %arg1 : i32
+ // CHECK: spirv.ShiftRightArithmetic
+ %1 = arith.shrsi %arg0, %arg1 : i32
+ // CHECK: spirv.ShiftRightLogical
+ %2 = arith.shrui %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @shift_vector
+func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+ // CHECK: spirv.ShiftLeftLogical
+ %0 = arith.shli %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.ShiftRightArithmetic
+ %1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
+ // CHECK: spirv.ShiftRightLogical
+ %2 = arith.shrui %arg0, %arg1 : vector<4xi32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpf
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpf
+func.func @cmpf(%arg0 : f32, %arg1 : f32) {
+ // CHECK: spirv.FOrdEqual
+ %1 = arith.cmpf oeq, %arg0, %arg1 : f32
+ return
+}
+
+// CHECK-LABEL: @vec1cmpf
+func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
+ // CHECK: spirv.FOrdGreaterThan
+ %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
+ // CHECK: spirv.FUnordLessThan
+ %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpi
+func.func @cmpi(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spirv.IEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @indexcmpi
+func.func @indexcmpi(%arg0 : index, %arg1 : index) {
+ // CHECK: spirv.IEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : index
+ return
+}
+
+// CHECK-LABEL: @vec1cmpi
+func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
+ // CHECK: spirv.ULessThan
+ %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
+ // CHECK: spirv.SGreaterThan
+ %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
+ return
+}
+
+// CHECK-LABEL: @boolcmpi_equality
+func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
+ // CHECK: spirv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @boolcmpi_unsigned
+func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : i1
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @vec1boolcmpi_equality
+func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+ // CHECK: spirv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
+ return
+}
+
+// CHECK-LABEL: @vec1boolcmpi_unsigned
+func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
+ return
+}
+
+// CHECK-LABEL: @vecboolcmpi_equality
+func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+ // CHECK: spirv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
+ return
+}
+
+// CHECK-LABEL: @vecboolcmpi_unsigned
+func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
+ // CHECK-COUNT-2: spirv.Select
+ // CHECK: spirv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
+ return
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
new file mode 100644
index 0000000000000..9e908465cb142
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+
+// CHECK-LABEL: @combined
+// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
+// CHECK: %[[C1_F32:.*]] = spirv.Constant 1.000000e+00 : f32
+// CHECK: %[[C0_I32:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C4_I32:.*]] = spirv.Constant 4 : i32
+// CHECK: %[[C0_I32_0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C4_I32_0:.*]] = spirv.Constant 4 : i32
+// CHECK: %[[C1_I32:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[VEC:.*]] = spirv.Constant dense<1.000000e+00> : vector<4xf32>
+// CHECK: %[[VARIABLE:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
+// CHECK: spirv.mlir.loop {
+// CHECK: spirv.Branch ^[[HEADER:.*]](%[[C0_I32_0]], %[[C0_F32]] : i32, f32)
+// CHECK: ^[[HEADER]](%[[INDVAR_0:.*]]: i32, %[[INDVAR_1:.*]]: f32):
+// CHECK: %[[SLESSTHAN:.*]] = spirv.SLessThan %[[INDVAR_0]], %[[C4_I32_0]] : i32
+// CHECK: spirv.BranchConditional %[[SLESSTHAN]], ^[[BODY:.*]], ^[[MERGE:.*]]
+// CHECK: ^[[BODY]]:
+// CHECK: %[[FADD:.*]] = spirv.FAdd %[[INDVAR_1]], %[[C1_F32]] : f32
+// CHECK: %[[INSERT:.*]] = spirv.CompositeInsert %[[FADD]], %[[VEC]][0 : i32] : f32 into vector<4xf32>
+// CHECK: spirv.Store "Function" %[[VARIABLE]], %[[FADD]] : f32
+// CHECK: %[[IADD:.*]] = spirv.IAdd %[[INDVAR_0]], %[[C1_I32]] : i32
+// CHECK: spirv.Branch ^[[HEADER]](%[[IADD]], %[[FADD]] : i32, f32)
+// CHECK: ^[[MERGE]]:
+// CHECK: spirv.mlir.merge
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = spirv.Load "Function" %[[VARIABLE]] : f32
+// CHECK: %[[UNDEF:.*]] = spirv.Undef : f32
+// CHECK: spirv.ReturnValue %[[UNDEF]] : f32
+func.func @combined() -> f32 {
+ %c0_f32 = arith.constant 0.0 : f32
+ %c1_f32 = arith.constant 1.0 : f32
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %lb = index.casts %c0_i32 : i32 to index
+ %ub = index.casts %c4_i32 : i32 to index
+ %step = arith.constant 1 : index
+ %buf = vector.broadcast %c1_f32 : f32 to vector<4xf32>
+ scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %c0_f32) -> f32 {
+ %t = vector.extract %buf[0] : f32 from vector<4xf32>
+ %sum_next = arith.addf %sum_iter, %t : f32
+ vector.insert %sum_next, %buf[0] : f32 into vector<4xf32>
+ scf.yield %sum_next : f32
+ }
+ %ret = ub.poison : f32
+ return %ret : f32
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
new file mode 100644
index 0000000000000..db747625bc7b3
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
+
+// CHECK-LABEL: @basic
+func.func @basic(%a: index, %b: index) {
+ // CHECK: spirv.IAdd
+ %0 = index.add %a, %b
+ // CHECK: spirv.ISub
+ %1 = index.sub %a, %b
+ // CHECK: spirv.IMul
+ %2 = index.mul %a, %b
+ // CHECK: spirv.SDiv
+ %3 = index.divs %a, %b
+ // CHECK: spirv.UDiv
+ %4 = index.divu %a, %b
+ // CHECK: spirv.SRem
+ %5 = index.rems %a, %b
+ // CHECK: spirv.UMod
+ %6 = index.remu %a, %b
+ // CHECK: spirv.GL.SMax
+ %7 = index.maxs %a, %b
+ // CHECK: spirv.GL.UMax
+ %8 = index.maxu %a, %b
+ // CHECK: spirv.GL.SMin
+ %9 = index.mins %a, %b
+ // CHECK: spirv.GL.UMin
+ %10 = index.minu %a, %b
+ // CHECK: spirv.ShiftLeftLogical
+ %11 = index.shl %a, %b
+ // CHECK: spirv.ShiftRightArithmetic
+ %12 = index.shrs %a, %b
+ // CHECK: spirv.ShiftRightLogical
+ %13 = index.shru %a, %b
+ // CHECK: spirv.BitwiseAnd
+ %14 = index.and %a, %b
+ // CHECK: spirv.BitwiseOr
+ %15 = index.or %a, %b
+ // CHECK: spirv.BitwiseXor
+ %16 = index.xor %a, %b
+ return
+}
+
+// CHECK-LABEL: @cmp
+func.func @cmp(%a : index, %b : index) {
+ // CHECK: spirv.IEqual
+ %0 = index.cmp eq(%a, %b)
+ return
+}
+
+// CHECK-LABEL: @ceildivs
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+// CHECK: spirv.ReturnValue %{{.*}} : i32
+func.func @ceildivs(%n: index, %m: index) -> index {
+ %result = index.ceildivs %n, %m
+ return %result : index
+}
+
+// CHECK-LABEL: @ceildivu
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+// CHECK: spirv.ReturnValue %{{.*}} : i32
+func.func @ceildivu(%n: index, %m: index) -> index {
+ %result = index.ceildivu %n, %m
+ return %result : index
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
new file mode 100644
index 0000000000000..f619ca5771824
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+
+// CHECK-LABEL: @if_yield
+// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
+// CHECK: spirv.mlir.selection {
+// CHECK-NEXT: spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+// CHECK-NEXT: [[TRUE]]:
+// CHECK: %[[C0TRUE:.*]] = spirv.Constant 0.000000e+00 : f32
+// CHECK: %[[RETTRUE:.*]] = spirv.Constant 0.000000e+00 : f32
+// CHECK-DAG: spirv.Store "Function" %[[VAR]], %[[RETTRUE]] : f32
+// CHECK: spirv.Branch ^[[MERGE:.*]]
+// CHECK-NEXT: [[FALSE]]:
+// CHECK: %[[C0FALSE:.*]] = spirv.Constant 1.000000e+00 : f32
+// CHECK: %[[RETFALSE:.*]] = spirv.Constant 2.71828175 : f32
+// CHECK-DAG: spirv.Store "Function" %[[VAR]], %[[RETFALSE]] : f32
+// CHECK: spirv.Branch ^[[MERGE]]
+// CHECK-NEXT: ^[[MERGE]]:
+// CHECK: spirv.mlir.merge
+// CHECK-NEXT: }
+// CHECK-DAG: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : f32
+// CHECK: spirv.ReturnValue %[[OUT]] : f32
+func.func @if_yield(%arg0: i1) -> f32 {
+ %0 = scf.if %arg0 -> f32 {
+ %c0 = arith.constant 0.0 : f32
+ %res = math.sqrt %c0 : f32
+ scf.yield %res : f32
+ } else {
+ %c0 = arith.constant 1.0 : f32
+ %res = math.exp %c0 : f32
+ scf.yield %res : f32
+ }
+ return %0 : f32
+}
+
+// CHECK-LABEL: @while
+func.func @while(%arg0: i32, %arg1: i32) -> i32 {
+ %c2_i32 = arith.constant 2 : i32
+ %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : i32
+ scf.condition(%1) %arg3 : i32
+ } do {
+ ^bb0(%arg5: i32):
+ %1 = arith.muli %arg5, %c2_i32 : i32
+ scf.yield %1 : i32
+ }
+ return %0 : i32
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
new file mode 100644
index 0000000000000..20b2a42bc3975
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+
+// CHECK-LABEL: @return_scalar
+// CHECK-SAME: %[[ARG0:.*]]: i32
+// CHECK: spirv.ReturnValue %[[ARG0]]
+func.func @return_scalar(%arg0 : i32) -> i32 {
+ return %arg0 : i32
+}
+
+// CHECK-LABEL: @return_vector
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi32>
+// CHECK: spirv.ReturnValue %[[ARG0]]
+func.func @return_vector(%arg0 : vector<4xi32>) -> vector<4xi32> {
+ return %arg0 : vector<4xi32>
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
new file mode 100644
index 0000000000000..66528b68f58cf
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+
+// CHECK-LABEL: @ub
+// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
+// CHECK: spirv.ReturnValue %[[UNDEF]] : i32
+func.func @ub() -> index {
+ %0 = ub.poison : index
+ return %0 : index
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
new file mode 100644
index 0000000000000..336f0fe10c27e
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -0,0 +1,439 @@
+// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s
+
+// CHECK-LABEL: @extract
+// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
+// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
+// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
+func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
+ %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+ %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
+ return %0, %1: vector<1xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_size1_vector
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: spirv.ReturnValue %[[ARG0]] : f32
+func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
+ %0 = vector.extract %arg0[0] : f32 from vector<1xf32>
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @insert
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
+// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
+ %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32>
+ return %1: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_index_vector
+// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+ %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+ return %1: vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_size1_vector
+// CHECK-SAME: %[[V:.*]]: f32, %[[S:.*]]: f32
+// CHECK: spirv.ReturnValue %[[S]] : f32
+func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
+ %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
+ %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_cst
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 {
+ %idx = arith.constant 1 : i32
+ %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32>
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_index
+func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
+ // CHECK: spirv.VectorExtractDynamic
+ %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_size1_vector
+// CHECK-SAME:(%[[S:.+]]: f32,
+func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 {
+ %bcast = vector.broadcast %arg0 : f32 to vector<1xf32>
+ %0 = vector.extractelement %bcast[%i : index] : vector<1xf32>
+ // CHECK: spirv.ReturnValue %[[S]]
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_0d_vector
+// CHECK-SAME: (%[[S:.+]]: f32)
+func.func @extract_element_0d_vector(%arg0 : f32) -> f32 {
+ %bcast = vector.broadcast %arg0 : f32 to vector<f32>
+ %0 = vector.extractelement %bcast[] : vector<f32>
+ // CHECK: spirv.ReturnValue %[[S]]
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> {
+ %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_cst
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+ %idx = arith.constant 2 : i32
+ %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_index
+func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+ // CHECK: spirv.VectorInsertDynamic
+ %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_size1_vector
+// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> {
+ %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32>
+ // CHECK: spirv.ReturnValue %[[S]]
+ return %0: vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_0d_vector
+// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> {
+ %0 = vector.insertelement %scalar, %vector[] : vector<f32>
+ // CHECK: spirv.ReturnValue %[[S]]
+ return %0: vector<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_size1_vector
+// CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32>
+// CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
+// CHECK: spirv.ReturnValue %[[RET]] : vector<3xf32>
+func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
+ %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fma
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+// CHECK: spirv.GL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.fma %a, %b, %c: vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fma_size1_vector
+// CHECK: spirv.GL.Fma %{{.+}} : f32
+func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+ %0 = vector.fma %a, %b, %c: vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @splat
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
+// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32>
+func.func @splat(%f : f32) -> vector<4xf32> {
+ %splat = vector.splat %f : vector<4xf32>
+ return %splat : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @splat_size1_vector
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: spirv.ReturnValue %[[A]] : f32
+func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
+ %splat = vector.splat %f : vector<1xf32>
+ return %splat : vector<1xf32>
+}
+
+// -----
+
+// vector.shuffle of two vector<1xf32> operands with mask [0, 1, 1, 0]: both
+// operands show up as f32 scalars and the result is expected to be built by a
+// single spirv.CompositeConstruct listing them in mask order.
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32
+// CHECK: spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG0]] : (f32, f32, f32, f32) -> vector<4xf32>
+func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
+ return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// vector.shuffle of two vector<3xf32> operands: expected to map directly to
+// spirv.VectorShuffle carrying the same mask [3, 2, 5, 1].
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
+// CHECK: spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]], %[[V1]] : vector<3xf32>, vector<3xf32> -> vector<4xf32>
+func.func @shuffle(%v0 : vector<3xf32>, %v1: vector<3xf32>) -> vector<4xf32> {
+ %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xf32>, vector<3xf32>
+ return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// vector.shuffle mixing a vector<1xi32> (seen as scalar i32 after conversion)
+// with a vector<3xi32>, mask [0, 2, 3]. Mask entry 0 selects the scalar; entries
+// 2 and 3 select elements 1 and 2 of the second operand, which are expected to
+// be extracted and then combined with the scalar via spirv.CompositeConstruct.
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[ARG1]][2 : i32] : vector<3xi32>
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[S1]], %[[S2]] : (i32, i32, i32) -> vector<3xi32>
+// CHECK: spirv.ReturnValue %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<3xi32>) -> vector<3xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<1xi32>, vector<3xi32>
+ return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// vector.shuffle of two vector<1xi32> operands with mask [0, 1]: expected to be
+// a spirv.CompositeConstruct of the two scalar i32 arguments into vector<2xi32>.
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]] : (i32, i32) -> vector<2xi32>
+// CHECK: spirv.ReturnValue %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 1] : vector<1xi32>, vector<1xi32>
+ return %shuffle : vector<2xi32>
+}
+
+// -----
+
+// vector.interleave of two vector<2xf32> operands is expected to become a
+// spirv.VectorShuffle with mask [0, 2, 1, 3], i.e. elements taken alternately
+// from the first and second input.
+// CHECK-LABEL: func @interleave
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
+// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
+// CHECK: spirv.ReturnValue %[[SHUFFLE]]
+func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
+ %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// vector.interleave of two vector<1xf32> operands (seen as f32 scalars after
+// conversion): expected to be a spirv.CompositeConstruct producing vector<2xf32>.
+// CHECK-LABEL: func @interleave_size1
+// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32)
+// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]] : (f32, f32) -> vector<2xf32>
+// CHECK: spirv.ReturnValue %[[RES]]
+func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
+ %0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// -----
+
+// vector.reduction <add> over vector<4xi32> without an accumulator: expected to
+// extract all four elements and sum them with a left-to-right chain of
+// spirv.IAdd, returning the final sum.
+// CHECK-LABEL: func @reduction_add
+// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+// CHECK: %[[ADD0:.+]] = spirv.IAdd %[[S0]], %[[S1]]
+// CHECK: %[[ADD1:.+]] = spirv.IAdd %[[ADD0]], %[[S2]]
+// CHECK: %[[ADD2:.+]] = spirv.IAdd %[[ADD1]], %[[S3]]
+// CHECK: spirv.ReturnValue %[[ADD2]]
+func.func @reduction_add(%v : vector<4xi32>) -> i32 {
+ %reduce = vector.reduction <add>, %v : vector<4xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// vector.reduction <add> over vector<1xf32> without an accumulator: the
+// expected output returns the (scalar f32) argument itself.
+// CHECK-LABEL: func @reduction_addf_one_elem
+// CHECK-SAME: (%[[ARG0:.+]]: f32)
+// CHECK: spirv.ReturnValue %[[ARG0]] : f32
+func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
+ %red = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+ return %red : f32
+}
+
+// -----
+
+// vector.reduction <add> over vector<1xf32> with an accumulator: expected to be
+// a single spirv.FAdd of the accumulator and the scalar argument.
+// CHECK-LABEL: func @reduction_addf_one_elem_acc
+// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ACC:.+]]: f32)
+// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[ARG0]] : f32
+// CHECK: spirv.ReturnValue %[[RES]] : f32
+func.func @reduction_addf_one_elem_acc(%arg0: vector<1xf32>, %acc: f32) -> f32 {
+ %red = vector.reduction <add>, %arg0, %acc : vector<1xf32> into f32
+ return %red : f32
+}
+
+// -----
+
+// vector.reduction <mul> over vector<3xf32> with accumulator %s: expected to
+// extract the three elements, multiply them with a chain of spirv.FMul, and
+// fold the accumulator in as the last multiplication.
+// CHECK-LABEL: func @reduction_mul
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MUL0:.+]] = spirv.FMul %[[S0]], %[[S1]]
+// CHECK: %[[MUL1:.+]] = spirv.FMul %[[MUL0]], %[[S2]]
+// CHECK: %[[MUL2:.+]] = spirv.FMul %[[MUL1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MUL2]]
+func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
+// vector.reduction <maximumf> over vector<3xf32> with accumulator %s: expected
+// to extract the three elements and combine them, then the accumulator, with a
+// chain of spirv.GL.FMax.
+// CHECK-LABEL: func @reduction_maximumf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MAX2]]
+func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
+// vector.reduction <minimumf> over vector<3xf32> with accumulator %s: expected
+// to extract the three elements and combine them, then the accumulator, with a
+// chain of spirv.GL.FMin.
+// CHECK-LABEL: func @reduction_minimumf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MIN2]]
+func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
+// vector.reduction <maxsi> (signed integer max) over vector<3xi32> with
+// accumulator %s: expected to extract the three elements and combine them, then
+// the accumulator, with a chain of spirv.GL.SMax.
+// CHECK-LABEL: func @reduction_maxsi
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MAX0:.+]] = spirv.GL.SMax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spirv.GL.SMax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spirv.GL.SMax %[[MAX1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MAX2]]
+func.func @reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// vector.reduction <minsi> (signed integer min) over vector<3xi32> with
+// accumulator %s: expected to extract the three elements and combine them, then
+// the accumulator, with a chain of spirv.GL.SMin.
+// CHECK-LABEL: func @reduction_minsi
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MIN0:.+]] = spirv.GL.SMin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spirv.GL.SMin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spirv.GL.SMin %[[MIN1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MIN2]]
+func.func @reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// vector.reduction <maxui> (unsigned integer max) over vector<3xi32> with
+// accumulator %s: expected to extract the three elements and combine them, then
+// the accumulator, with a chain of spirv.GL.UMax.
+// CHECK-LABEL: func @reduction_maxui
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MAX0:.+]] = spirv.GL.UMax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spirv.GL.UMax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spirv.GL.UMax %[[MAX1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MAX2]]
+func.func @reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// vector.reduction <minui> (unsigned integer min) over vector<3xi32> with
+// accumulator %s: expected to extract the three elements and combine them, then
+// the accumulator, with a chain of spirv.GL.UMin.
+// CHECK-LABEL: func @reduction_minui
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[MIN0:.+]] = spirv.GL.UMin %[[S0]], %[[S1]]
+// CHECK: %[[MIN1:.+]] = spirv.GL.UMin %[[MIN0]], %[[S2]]
+// CHECK: %[[MIN2:.+]] = spirv.GL.UMin %[[MIN1]], %[[S]]
+// CHECK: spirv.ReturnValue %[[MIN2]]
+func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// vector.shape_cast whose source and result types are identical: the function
+// returns the cast result, and the expected output returns the function
+// argument directly, i.e. the cast must be replaced by its source operand.
+// Returning the cast result (rather than the argument) keeps the cast live so
+// the expectation actually exercises its conversion.
+// CHECK-LABEL: @shape_cast_same_type
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>)
+// CHECK: spirv.ReturnValue %[[ARG0]]
+func.func @shape_cast_same_type(%arg0 : vector<2xf32>) -> vector<2xf32> {
+ %1 = vector.shape_cast %arg0 : vector<2xf32> to vector<2xf32>
+ return %1 : vector<2xf32>
+}
+
+// -----
+
+// vector.shape_cast from a 0-D vector<f32> to vector<1xf32>: the argument shows
+// up as scalar f32 in the converted signature and the expected output returns
+// it unchanged.
+// CHECK-LABEL: @shape_cast_size1_vector
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+// CHECK: spirv.ReturnValue %[[ARG0]]
+func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
+ %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
>From 6eaab4a29cac49f6ef86b61ecd78405236df50f1 Mon Sep 17 00:00:00 2001
From: Keith Smiley <keithbsmiley at gmail.com>
Date: Fri, 21 Jun 2024 09:57:46 -0700
Subject: [PATCH 2/2] [bazel] Port #95942 (#96334)
---
.../llvm-project-overlay/mlir/BUILD.bazel | 28 +++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index d8369853e22f8..14203eb9e7060 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -824,6 +824,7 @@ mlir_c_api_cc_library(
includes = ["include"],
deps = [
":ConversionPasses",
+ ":ConvertToSPIRV",
":Pass",
],
)
@@ -4200,6 +4201,7 @@ cc_library(
":ControlFlowToSPIRV",
":ConversionPassIncGen",
":ConvertToLLVM",
+ ":ConvertToSPIRV",
":FuncToEmitC",
":FuncToLLVM",
":FuncToSPIRV",
@@ -8209,6 +8211,32 @@ cc_library(
],
)
+# Bazel target for the generic convert-to-spirv pass, built from
+# lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp. The deps list the
+# per-dialect *ToSPIRV conversion libraries (arith, func, index, scf, ub,
+# vector) together with the SPIR-V dialect and the conversion/transform
+# utilities.
+cc_library(
+ name = "ConvertToSPIRV",
+ srcs = ["lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp"],
+ hdrs = ["include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"],
+ includes = ["include"],
+ deps = [
+ ":ArithToSPIRV",
+ ":ArithTransforms",
+ ":ConversionPassIncGen",
+ ":FuncToSPIRV",
+ ":IR",
+ ":IndexToSPIRV",
+ ":Pass",
+ ":Rewrite",
+ ":SCFToSPIRV",
+ ":SPIRVConversion",
+ ":SPIRVDialect",
+ ":SPIRVTransforms",
+ ":TransformUtils",
+ ":Transforms",
+ ":UBToSPIRV",
+ ":VectorToSPIRV",
+ ":VectorTransforms",
+ ],
+)
+
cc_library(
name = "ControlFlowToSCF",
srcs = [
More information about the llvm-commits
mailing list