[Mlir-commits] [mlir] afc159b - [mlir][arith][spirv] Handle i1 sign extension in arith-to-spirv
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 14 12:08:27 PST 2022
Author: Jakub Kuderski
Date: 2022-11-14T15:07:27-05:00
New Revision: afc159bbf12ac96298070f916a35321e7953a7b4
URL: https://github.com/llvm/llvm-project/commit/afc159bbf12ac96298070f916a35321e7953a7b4
DIFF: https://github.com/llvm/llvm-project/commit/afc159bbf12ac96298070f916a35321e7953a7b4.diff
LOG: [mlir][arith][spirv] Handle i1 sign extension in arith-to-spirv
Also fix some surrounding nits.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D137974
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index cf65beb924fb..a284be8ce939 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -17,8 +17,10 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
+#include <cassert>
namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTOSPIRV
@@ -118,6 +120,16 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Lowers arith.extsi with an i1 (or vector of i1) source to spirv.Select,
+/// since spirv.SConvert only accepts integer, not boolean, operands.
+struct ExtSII1Pattern final : OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
/// of i1.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
@@ -615,6 +627,42 @@ UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
return success();
}
+//===----------------------------------------------------------------------===//
+// ExtSII1Pattern
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `arith.extsi %in` whose (already converted) source is `i1` or a
+/// vector of `i1` into `spirv.Select %in, <all-ones>, <zero>`. Sign-extending
+/// a one-bit value replicates that bit into every result bit, so `true`
+/// becomes the all-ones integer (-1) and `false` becomes 0.
+///
+/// Fails to match when the source is not boolean, when the result type
+/// cannot be converted, or when the converted result type is neither an
+/// integer nor a vector type.
+LogicalResult
+ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  Value operand = adaptor.getIn();
+  // Non-boolean sources are left to the SConvert-based type casting pattern
+  // registered alongside this one.
+  if (!isBoolScalarOrVector(operand.getType()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+  // convertType yields a null Type when the result has no SPIR-V equivalent;
+  // bail out here because dyn_cast on a null Type asserts.
+  if (!dstType)
+    return rewriter.notifyMatchFailure(loc, "failed to convert result type");
+
+  Value allOnes;
+  if (auto intTy = dstType.dyn_cast<IntegerType>()) {
+    unsigned componentBitwidth = intTy.getWidth();
+    allOnes = rewriter.create<spirv::ConstantOp>(
+        loc, intTy,
+        rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
+  } else if (auto vectorTy = dstType.dyn_cast<VectorType>()) {
+    // Splat the all-ones element value across every lane of the vector.
+    unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
+    allOnes = rewriter.create<spirv::ConstantOp>(
+        loc, vectorTy,
+        SplatElementsAttr::get(vectorTy,
+                               APInt::getAllOnes(componentBitwidth)));
+  } else {
+    return rewriter.notifyMatchFailure(
+        loc, llvm::formatv("unhandled type: {0}", dstType));
+  }
+
+  // select(cond, -1, 0) is exactly the sign extension of a one-bit value.
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
+                                               zero);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// ExtUII1Pattern
//===----------------------------------------------------------------------===//
@@ -982,7 +1030,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
- TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, ExtSII1Pattern,
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 958f0fe4f2dc..c16d3d0b2d8e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2032,7 +2032,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
return builder.create<spirv::ConstantOp>(
loc, type,
DenseElementsAttr::get(vectorType,
- IntegerAttr::get(elemType, 0.0).getValue()));
+ IntegerAttr::get(elemType, 0).getValue()));
}
if (elemType.isa<FloatType>()) {
return builder.create<spirv::ConstantOp>(
@@ -2065,7 +2065,7 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
return builder.create<spirv::ConstantOp>(
loc, type,
DenseElementsAttr::get(vectorType,
- IntegerAttr::get(elemType, 1.0).getValue()));
+ IntegerAttr::get(elemType, 1).getValue()));
}
if (elemType.isa<FloatType>()) {
return builder.create<spirv::ConstantOp>(
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 2f7fb592c896..beb52c5f3402 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -756,6 +756,28 @@ func.func @sexti2(%arg0 : i32) -> i64 {
return %0 : i64
}
+// Sign-extending an i1 lowers to spirv.Select between -1 (true) and 0 (false);
+// the two constants are matched order-independently.
+// CHECK-LABEL: @sext_bool_scalar
+// CHECK-SAME: ([[ARG:%.+]]: i1) -> i32
+func.func @sext_bool_scalar(%arg0 : i1) -> i32 {
+ // CHECK-DAG: [[ONES:%.+]] = spirv.Constant -1 : i32
+ // CHECK-DAG: [[ZERO:%.+]] = spirv.Constant 0 : i32
+ // CHECK: [[SEL:%.+]] = spirv.Select [[ARG]], [[ONES]], [[ZERO]] : i1, i32
+ // CHECK-NEXT: return [[SEL]] : i32
+ %0 = arith.extsi %arg0 : i1 to i32
+ return %0 : i32
+}
+
+// Vector counterpart of @sext_bool_scalar: the -1 and 0 operands are splat
+// vector constants and the select picks between them lane by lane.
+// CHECK-LABEL: @sext_bool_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xi1>) -> vector<3xi32>
+func.func @sext_bool_vector(%arg0 : vector<3xi1>) -> vector<3xi32> {
+ // CHECK-DAG: [[ONES:%.+]] = spirv.Constant dense<-1> : vector<3xi32>
+ // CHECK-DAG: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
+ // CHECK: [[SEL:%.+]] = spirv.Select [[ARG]], [[ONES]], [[ZERO]] : vector<3xi1>, vector<3xi32>
+ // CHECK-NEXT: return [[SEL]] : vector<3xi32>
+ %0 = arith.extsi %arg0 : vector<3xi1> to vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
// CHECK-LABEL: @zexti1
func.func @zexti1(%arg0: i16) -> i64 {
// CHECK: spirv.UConvert %{{.*}} : i16 to i64
More information about the Mlir-commits
mailing list