[Mlir-commits] [mlir] 6e2c0e6 - [mlir][spirv] Add conversions from arith.bitcast, std.br, std.cond_br to spirv.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Oct 30 09:43:33 PDT 2021
Author: xndcn
Date: 2021-10-31T00:40:35+08:00
New Revision: 6e2c0e6931aff04265fc9e411b2871930cd626e8
URL: https://github.com/llvm/llvm-project/commit/6e2c0e6931aff04265fc9e411b2871930cd626e8
DIFF: https://github.com/llvm/llvm-project/commit/6e2c0e6931aff04265fc9e411b2871930cd626e8.diff
LOG: [mlir][spirv] Add conversions from arith.bitcast, std.br, std.cond_br to spirv.
Differential Revision: https://reviews.llvm.org/D112819
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index d1eed3231d6f..6fd69637df1d 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -784,6 +784,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern
>(typeConverter, patterns.getContext());
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 48b6e99257bb..87d57080ed80 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -65,6 +65,24 @@ class SplatPattern final : public OpConversionPattern<SplatOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Lowers an unconditional std.br into spv.Branch targeting the same
+/// successor block.
+struct BranchOpPattern final : OpConversionPattern<BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers a two-way std.cond_br into spv.BranchConditional targeting the same
+/// true/false successor blocks.
+struct CondBranchOpPattern final : OpConversionPattern<CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts tensor.extract into loading using access chains from SPIR-V local
/// variables.
class TensorExtractPattern final
@@ -176,6 +194,31 @@ SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
return success();
}
+//===----------------------------------------------------------------------===//
+// BranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  // The successor block is reused unchanged; the values passed to it are the
+  // adaptor's remapped operands rather than the original op's operands.
+  auto *successor = op.getDest();
+  auto successorOperands = adaptor.getDestOperands();
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, successor,
+                                               successorOperands);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CondBranchOpPattern::matchAndRewrite(
+    CondBranchOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Every operand, including the condition, is taken from the adaptor. The
+  // new spv.BranchConditional then consumes the values already remapped by
+  // the conversion framework, not the pre-conversion SSA values. This matches
+  // how the true/false successor operands are already handled. The successor
+  // blocks themselves are reused from the original op.
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      op, adaptor.getCondition(), op.getTrueDest(),
+      adaptor.getTrueDestOperands(), op.getFalseDest(),
+      adaptor.getFalseDestOperands());
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -194,7 +237,8 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
spirv::UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
- ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context);
+ ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
+ CondBranchOpPattern>(typeConverter, context);
}
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 0c2aee27ca50..8a41a90a2fc0 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -572,6 +572,15 @@ func @index_cast4(%arg0: index) {
return
}
+// Exercises arith.bitcast lowering for both a vector value
+// (vector<2xf32> -> vector<2xi32>) and a scalar value (i64 -> f64); each is
+// expected to become a single spv.Bitcast with the same source/result types.
+// CHECK-LABEL: @bit_cast
+func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
+ // CHECK: spv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
+ %0 = arith.bitcast %arg0 : vector<2xf32> to vector<2xi32>
+ // CHECK: spv.Bitcast %{{.+}} : i64 to f64
+ %1 = arith.bitcast %arg1 : i64 to f64
+ return
+}
+
// CHECK-LABEL: @fpext1
func @fpext1(%arg0: f16) -> f64 {
// CHECK: spv.FConvert %{{.*}} : f16 to f64
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 36a6d793e721..b8d9966c9a5b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -933,3 +933,45 @@ func @splat(%f : f32) -> vector<4xf32> {
%splat = splat %f : vector<4xf32>
return %splat : vector<4xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.br, std.cond_br
+//===----------------------------------------------------------------------===//
+
+// Target environment: SPIR-V v1.0 with empty capability and extension lists.
+// Under it, the index-typed loop values below are expected to appear as i32
+// after conversion (as the block-argument patterns in this test match).
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// A counted loop built from std.br / std.cond_br that exercises both new
+// branch patterns, including successor operands of converted type.
+// CHECK-LABEL: func @simple_loop
+func @simple_loop(index, index, index) {
+^bb0(%begin : index, %end : index, %step : index):
+// Entry: unconditional branch with no successor operands.
+// CHECK-NEXT: spv.Branch ^bb1
+ br ^bb1
+
+// Pre-header: branch that forwards %begin as the loop-carried value.
+// CHECK-NEXT: ^bb1: // pred: ^bb0
+// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
+^bb1: // pred: ^bb0
+ br ^bb2(%begin : index)
+
+// Header: the comparison result drives the two-way conditional branch.
+// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
+// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
+// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
+^bb2(%0: index): // 2 preds: ^bb1, ^bb3
+ %1 = arith.cmpi slt, %0, %end : index
+ cond_br %1, ^bb3, ^bb4
+
+// Body: increment the induction value and branch back to the header.
+// CHECK: ^bb3: // pred: ^bb2
+// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
+// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
+^bb3: // pred: ^bb2
+ %2 = arith.addi %0, %step : index
+ br ^bb2(%2 : index)
+
+// Exit block.
+// CHECK: ^bb4: // pred: ^bb2
+^bb4: // pred: ^bb2
+ return
+}
+
+}
More information about the Mlir-commits mailing list