[Mlir-commits] [mlir] afc159b - [mlir][arith][spirv] Handle i1 sign extension in arith-to-spirv

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 14 12:08:27 PST 2022


Author: Jakub Kuderski
Date: 2022-11-14T15:07:27-05:00
New Revision: afc159bbf12ac96298070f916a35321e7953a7b4

URL: https://github.com/llvm/llvm-project/commit/afc159bbf12ac96298070f916a35321e7953a7b4
DIFF: https://github.com/llvm/llvm-project/commit/afc159bbf12ac96298070f916a35321e7953a7b4.diff

LOG: [mlir][arith][spirv] Handle i1 sign extension in arith-to-spirv

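Also fix some surrounding nits (e.g., pass integer instead of floating-point
literals to IntegerAttr::get in spirv::ConstantOp::getZero and getOne).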

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D137974

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index cf65beb924fb..a284be8ce939 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -17,8 +17,10 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
+#include <cassert>
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
@@ -118,6 +120,16 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.extsi to spirv.Select(src, all-ones, zero) when the source
+/// type is i1 or a vector of i1; other source types fail to match.
+struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern<arith::ExtSIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
 /// of i1.
 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
@@ -615,6 +627,42 @@ UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExtSII1Pattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  Value operand = adaptor.getIn();
+  if (!isBoolScalarOrVector(operand.getType()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  if (!dstType)
+    return rewriter.notifyMatchFailure(loc, "type conversion failed");
+
+  // Sign-extending an i1 yields all ones (-1) for true and zero for false.
+  Value allOnes;
+  if (auto intTy = dstType.dyn_cast<IntegerType>()) {
+    allOnes = rewriter.create<spirv::ConstantOp>(
+        loc, intTy,
+        rewriter.getIntegerAttr(intTy, APInt::getAllOnes(intTy.getWidth())));
+  } else if (auto vectorTy = dstType.dyn_cast<VectorType>()) {
+    APInt ones = APInt::getAllOnes(vectorTy.getElementTypeBitWidth());
+    allOnes = rewriter.create<spirv::ConstantOp>(
+        loc, vectorTy, SplatElementsAttr::get(vectorTy, ones));
+  } else {
+    return rewriter.notifyMatchFailure(
+        loc, llvm::formatv("unhandled type: {0}", dstType));
+  }
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
+                                               zero);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ExtUII1Pattern
 //===----------------------------------------------------------------------===//
@@ -982,7 +1030,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
-    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, ExtSII1Pattern,
     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 958f0fe4f2dc..c16d3d0b2d8e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2032,7 +2032,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
       return builder.create<spirv::ConstantOp>(
           loc, type,
           DenseElementsAttr::get(vectorType,
-                                 IntegerAttr::get(elemType, 0.0).getValue()));
+                                 IntegerAttr::get(elemType, 0).getValue()));
     }
     if (elemType.isa<FloatType>()) {
       return builder.create<spirv::ConstantOp>(
@@ -2065,7 +2065,7 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
       return builder.create<spirv::ConstantOp>(
           loc, type,
           DenseElementsAttr::get(vectorType,
-                                 IntegerAttr::get(elemType, 1.0).getValue()));
+                                 IntegerAttr::get(elemType, 1).getValue()));
     }
     if (elemType.isa<FloatType>()) {
       return builder.create<spirv::ConstantOp>(

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 2f7fb592c896..beb52c5f3402 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -756,6 +756,28 @@ func.func @sexti2(%arg0 : i32) -> i64 {
   return %0 : i64
 }
 
+// CHECK-LABEL: @sext_bool_scalar
+// CHECK-SAME:  ([[ARG:%.+]]: i1) -> i32
+func.func @sext_bool_scalar(%arg0 : i1) -> i32 {
+  // CHECK-DAG:  [[ONES:%.+]] = spirv.Constant -1 : i32
+  // CHECK-DAG:  [[ZERO:%.+]] = spirv.Constant 0 : i32
+  // CHECK:      [[SEL:%.+]]  = spirv.Select [[ARG]], [[ONES]], [[ZERO]] : i1, i32
+  // CHECK-NEXT: return [[SEL]] : i32
+  %0 = arith.extsi %arg0 : i1 to i32  // true -> -1 (all ones), false -> 0
+  return %0 : i32
+}
+
+// CHECK-LABEL: @sext_bool_vector
+// CHECK-SAME:  ([[ARG:%.+]]: vector<3xi1>) -> vector<3xi32>
+func.func @sext_bool_vector(%arg0 : vector<3xi1>) -> vector<3xi32> {
+  // CHECK-DAG:  [[ONES:%.+]] = spirv.Constant dense<-1> : vector<3xi32>
+  // CHECK-DAG:  [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
+  // CHECK:      [[SEL:%.+]]  = spirv.Select [[ARG]], [[ONES]], [[ZERO]] : vector<3xi1>, vector<3xi32>
+  // CHECK-NEXT: return [[SEL]] : vector<3xi32>
+  %0 = arith.extsi %arg0 : vector<3xi1> to vector<3xi32>  // per lane: true -> -1, false -> 0
+  return %0 : vector<3xi32>
+}
+
 // CHECK-LABEL: @zexti1
 func.func @zexti1(%arg0: i16) -> i64 {
   // CHECK: spirv.UConvert %{{.*}} : i16 to i64


        


More information about the Mlir-commits mailing list