[Mlir-commits] [mlir] 6e2c0e6 - [mlir][spirv] Add conversions from arith.bitcast, std.br, std.cond_br to spirv.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Oct 30 09:43:33 PDT 2021


Author: xndcn
Date: 2021-10-31T00:40:35+08:00
New Revision: 6e2c0e6931aff04265fc9e411b2871930cd626e8

URL: https://github.com/llvm/llvm-project/commit/6e2c0e6931aff04265fc9e411b2871930cd626e8
DIFF: https://github.com/llvm/llvm-project/commit/6e2c0e6931aff04265fc9e411b2871930cd626e8.diff

LOG: [mlir][spirv] Add conversions from arith.bitcast, std.br, std.cond_br to spirv.

Differential Revision: https://reviews.llvm.org/D112819

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index d1eed3231d6f..6fd69637df1d 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -784,6 +784,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern
   >(typeConverter, patterns.getContext());

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 48b6e99257bb..87d57080ed80 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -65,6 +65,24 @@ class SplatPattern final : public OpConversionPattern<SplatOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.br to spv.Branch. The destination block is reused as-is and
+/// the branch operands are taken from the conversion adaptor.
+class BranchOpPattern final : public OpConversionPattern<BranchOp> {
+public:
+  using OpConversionPattern<BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts std.cond_br to spv.BranchConditional, keeping both successor
+/// blocks and forwarding the successor operands from the conversion adaptor.
+class CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
+public:
+  using OpConversionPattern<CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts tensor.extract into loading using access chains from SPIR-V local
 /// variables.
 class TensorExtractPattern final
@@ -176,6 +194,31 @@ SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BranchOpPattern
+//===----------------------------------------------------------------------===//
+
+/// Replaces `op` with an spv.Branch that targets the same successor block.
+/// The successor arguments come from `adaptor`, i.e. the operand values as
+/// remapped by the dialect conversion framework.
+LogicalResult
+BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  Block *successor = op.getDest();
+  ValueRange successorArgs = adaptor.getDestOperands();
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, successor, successorArgs);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOpPattern
+//===----------------------------------------------------------------------===//
+
+/// Replaces `op` with an spv.BranchConditional that keeps the original true
+/// and false successor blocks. The condition and both lists of successor
+/// operands are read from `adaptor`, which holds the operand values as
+/// remapped by the dialect conversion framework.
+LogicalResult CondBranchOpPattern::matchAndRewrite(
+    CondBranchOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Take the condition from the adaptor, like the successor operands, rather
+  // than from `op`: the original operand is the pre-conversion value, while
+  // the adaptor carries the remapped one the new op should consume.
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      op, adaptor.getCondition(), op.getTrueDest(),
+      adaptor.getTrueDestOperands(), op.getFalseDest(),
+      adaptor.getFalseDestOperands());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -194,7 +237,8 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       spirv::UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
       spirv::UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
 
-      ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context);
+      ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
+      CondBranchOpPattern>(typeConverter, context);
 }
 
 void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 0c2aee27ca50..8a41a90a2fc0 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -572,6 +572,15 @@ func @index_cast4(%arg0: index) {
   return
 }
 
+// arith.bitcast must lower to spv.Bitcast, both for a vector operand
+// (vector<2xf32> -> vector<2xi32>) and for a scalar operand (i64 -> f64).
+// CHECK-LABEL: @bit_cast
+func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
+  // CHECK: spv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
+  %0 = arith.bitcast %arg0 : vector<2xf32> to vector<2xi32>
+  // CHECK: spv.Bitcast %{{.+}} : i64 to f64
+  %1 = arith.bitcast %arg1 : i64 to f64
+  return
+}
+
 // CHECK-LABEL: @fpext1
 func @fpext1(%arg0: f16) -> f64 {
   // CHECK: spv.FConvert %{{.*}} : f16 to f64

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 36a6d793e721..b8d9966c9a5b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -933,3 +933,45 @@ func @splat(%f : f32) -> vector<4xf32> {
   %splat = splat %f : vector<4xf32>
   return %splat : vector<4xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.br, std.cond_br
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// A counting loop written with unstructured control flow: ^bb2 compares the
+// induction value against %end and conditionally branches to ^bb3 (increment
+// by %step, branch back to ^bb2) or ^bb4 (return). The CHECK lines verify that
+// std.br/std.cond_br become spv.Branch/spv.BranchConditional and that the
+// index-typed block argument of ^bb2 appears as i32 after conversion.
+// CHECK-LABEL: func @simple_loop
+func @simple_loop(index, index, index) {
+^bb0(%begin : index, %end : index, %step : index):
+// CHECK-NEXT:  spv.Branch ^bb1
+  br ^bb1
+
+// CHECK-NEXT: ^bb1:    // pred: ^bb0
+// CHECK-NEXT:  spv.Branch ^bb2({{.*}} : i32)
+^bb1:   // pred: ^bb0
+  br ^bb2(%begin : index)
+
+// CHECK:      ^bb2({{.*}}: i32):       // 2 preds: ^bb1, ^bb3
+// CHECK-NEXT:  {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
+// CHECK-NEXT:  spv.BranchConditional {{.*}}, ^bb3, ^bb4
+^bb2(%0: index):        // 2 preds: ^bb1, ^bb3
+  %1 = arith.cmpi slt, %0, %end : index
+  cond_br %1, ^bb3, ^bb4
+
+// CHECK:      ^bb3:    // pred: ^bb2
+// CHECK-NEXT:  {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
+// CHECK-NEXT:  spv.Branch ^bb2({{.*}} : i32)
+^bb3:   // pred: ^bb2
+  %2 = arith.addi %0, %step : index
+  br ^bb2(%2 : index)
+
+// CHECK:      ^bb4:    // pred: ^bb2
+^bb4:   // pred: ^bb2
+  return
+}
+
+}


        


More information about the Mlir-commits mailing list