[Mlir-commits] [mlir] 2b06650 - [mlir][spirv] Fix integer type emulation with extension/truncation
Lei Zhang
llvmlistbot at llvm.org
Sat Aug 12 18:42:41 PDT 2023
Author: Lei Zhang
Date: 2023-08-12T18:41:34-07:00
New Revision: 2b066501b1bcb21c408310e6cfca31ba02068736
URL: https://github.com/llvm/llvm-project/commit/2b066501b1bcb21c408310e6cfca31ba02068736
DIFF: https://github.com/llvm/llvm-project/commit/2b066501b1bcb21c408310e6cfca31ba02068736.diff
LOG: [mlir][spirv] Fix integer type emulation with extension/truncation
For integer extension or truncation with type emulation, we need
to make sure we perform masking or shifting to discard unwanted
bits and avoid polluting consumer ops.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D157788
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a8692a281366ba..9a1b7ade788e68 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -20,6 +20,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <memory>
@@ -108,6 +109,22 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
+/// Creates a scalar/vector integer constant holding `value`.
+///
+/// `type` is expected to be an integer type or a vector of integers; for a
+/// vector type the constant is a splat of `value` across all elements.
+/// Returns a null Value for any other type (including vectors whose element
+/// type is not an integer) so callers can bail out instead of crashing.
+static Value getScalarOrVectorConstInt(Type type, uint64_t value,
+                                       OpBuilder &builder, Location loc) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    // Only integer element types can carry an IntegerAttr splat.
+    if (!isa<IntegerType>(vectorType.getElementType()))
+      return nullptr;
+    Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
+    auto attr = SplatElementsAttr::get(vectorType, element);
+    return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+  }
+
+  // Only the kind of `type` matters here, so `isa` avoids binding an unused
+  // variable.
+  if (isa<IntegerType>(type))
+    return builder.create<spirv::ConstantOp>(
+        loc, type, builder.getIntegerAttr(type, value));
+
+  return nullptr;
+}
+
/// Returns true if scalar/vector type `a` and `b` have the same number of
/// bitwidth.
static bool hasSameBitwidth(Type a, Type b) {
@@ -525,6 +542,53 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
}
};
+/// Converts arith.extsi for cases where the type of source is neither i1 nor
+/// vector of i1.
+///
+/// If the converted source and result types differ, this lowers to
+/// spirv.SConvert. If type emulation maps both to the same SPIR-V type (e.g.
+/// i4 -> i32 where i4 is represented as i32), the value is sign-extended in
+/// place with a spirv.ShiftLeftLogical / spirv.ShiftRightArithmetic pair.
+struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getIn().getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (dstType != srcType) {
+      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
+                                                     adaptor.getOperands());
+      return success();
+    }
+
+    // We can have the same source and destination type due to type emulation.
+    // Only the low `srcBW` bits of the emulated value are meaningful; all bits
+    // above them must become copies of the original sign bit.
+    unsigned srcBW =
+        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+    assert(srcBW <
+               getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth() &&
+           "arith.extsi must widen");
+
+    // The shift amount has to be derived from the *emulated* element bitwidth,
+    // not the original result bitwidth: the original sign bit must land in the
+    // top bit of the SPIR-V type for the arithmetic right shift to replicate
+    // it. (If i4 and i8 were both represented as i32, shifting by 8 - 4 for
+    // i4 -> i8 would put the sign bit only at bit 7 and the right shift would
+    // not extend it at all.)
+    unsigned emulatedBW =
+        getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+
+    // The original source already fills the emulated type (e.g. i32 -> i64
+    // if i64 is represented as i32), so there is nothing to extend. Shifting
+    // by the full bitwidth would yield an undefined result in SPIR-V.
+    if (srcBW >= emulatedBW) {
+      rewriter.replaceOp(op, adaptor.getIn());
+      return success();
+    }
+
+    Value shiftSize = getScalarOrVectorConstInt(dstType, emulatedBW - srcBW,
+                                                rewriter, op.getLoc());
+    if (!shiftSize)
+      return rewriter.notifyMatchFailure(op, "unsupported type for shift");
+
+    // First shift left to squeeze out all leading bits beyond the original
+    // source bitwidth, moving the original sign bit into the top bit.
+    auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
+        op.getLoc(), dstType, adaptor.getIn(), shiftSize);
+
+    // Then perform an arithmetic right shift so negative values get the
+    // proper leading set bits.
+    rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
+        op, dstType, shiftLOp, shiftSize);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
@@ -554,6 +618,42 @@ struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
}
};
+/// Converts arith.extui for cases where the type of source is neither i1 nor
+/// vector of i1.
+///
+/// If the converted source and result types differ, this lowers to
+/// spirv.UConvert. If type emulation maps both to the same SPIR-V type, the
+/// bits above the original source bitwidth are cleared with a
+/// spirv.BitwiseAnd mask instead.
+struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getIn().getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (dstType != srcType) {
+      rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
+                                                     adaptor.getOperands());
+      return success();
+    }
+
+    // We can have the same source and destination type due to type emulation.
+    // Perform bit masking to make sure we don't pollute downstream consumers
+    // with unwanted bits. Here we need to use the original source type's
+    // bitwidth.
+    unsigned bitwidth =
+        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+    unsigned emulatedBW =
+        getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+
+    // If the original source already fills the emulated element type, the
+    // mask would be all ones and the AND a no-op; forward the input directly.
+    // This also keeps `maskTrailingOnes` within its supported range.
+    if (bitwidth >= emulatedBW) {
+      rewriter.replaceOp(op, adaptor.getIn());
+      return success();
+    }
+
+    Value mask = getScalarOrVectorConstInt(
+        dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
+        op.getLoc());
+    if (!mask)
+      return rewriter.notifyMatchFailure(op, "unsupported type for mask");
+    rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
+                                                     adaptor.getIn(), mask);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -588,6 +688,41 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
}
};
+/// Converts arith.trunci for cases where the type of result is neither i1
+/// nor vector of i1.
+///
+/// If the converted source and result types differ, this lowers to
+/// spirv.SConvert. If type emulation maps both to the same SPIR-V type, the
+/// bits above the original result bitwidth are cleared with a
+/// spirv.BitwiseAnd mask instead.
+struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getIn().getType();
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (isBoolScalarOrVector(dstType))
+      return failure();
+
+    if (dstType != srcType) {
+      // Given this is truncation, either SConvertOp or UConvertOp works.
+      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
+                                                     adaptor.getOperands());
+      return success();
+    }
+
+    // We can have the same source and destination type due to type emulation.
+    // Perform bit masking to make sure we don't pollute downstream consumers
+    // with unwanted bits. Here we need to use the original result type's
+    // bitwidth.
+    unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
+    unsigned emulatedBW =
+        getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+
+    // If the original result already fills the emulated element type, the
+    // mask would be all ones and the AND a no-op; forward the input directly.
+    if (bw >= emulatedBW) {
+      rewriter.replaceOp(op, adaptor.getIn());
+      return success();
+    }
+
+    Value mask = getScalarOrVectorConstInt(
+        dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
+    if (!mask)
+      return rewriter.notifyMatchFailure(op, "unsupported type for mask");
+    rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
+                                                     adaptor.getIn(), mask);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// TypeCastingOp
//===----------------------------------------------------------------------===//
@@ -981,10 +1116,10 @@ void mlir::arith::populateArithToSPIRVPatterns(
spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
- TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
- TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, ExtSII1Pattern,
+ ExtUIPattern, ExtUII1Pattern,
+ ExtSIPattern, ExtSII1Pattern,
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
- TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
+ TruncIPattern, TruncII1Pattern,
TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index d70df982c366ad..604f85757537de 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -990,6 +990,38 @@ func.func @fpext2(%arg0 : f32) -> f64 {
return %0: f64
}
+// Truncation where source and result share one SPIR-V type (the i32 result
+// is cast back to i4): bits above the i4 result are cleared with a bitwise
+// AND against 15 (0b1111).
+// CHECK-LABEL: @trunci4
+// CHECK-SAME: %[[ARG:.*]]: i32
+func.func @trunci4(%arg0 : i32) -> i4 {
+ // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+ // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
+ %0 = arith.trunci %arg0 : i32 to i4
+ // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[AND]] : i32 to i4
+ // CHECK: return %[[RET]] : i4
+ return %0 : i4
+}
+
+// Zero extension from i4 to i32: the i4 input is cast to i32 and every bit
+// above bit 3 is cleared with a bitwise AND against 15.
+// CHECK-LABEL: @zexti4
+func.func @zexti4(%arg0: i4) -> i32 {
+ // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
+ // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+ // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
+ %0 = arith.extui %arg0 : i4 to i32
+ // CHECK: return %[[AND]] : i32
+ return %0 : i32
+}
+
+// Sign extension from i4 to i32: the i4 input is cast to i32, shifted left by
+// 28 (32 - 4) so bit 3 lands in the top bit, then shifted right
+// arithmetically by 28 to replicate that sign bit.
+// CHECK-LABEL: @sexti4
+func.func @sexti4(%arg0: i4) -> i32 {
+ // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : i4 to i32
+ // CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
+ // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
+ // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : i32, i32
+ %0 = arith.extsi %arg0 : i4 to i32
+ // CHECK: return %[[SR]] : i32
+ return %0 : i32
+}
+
} // end module
// -----
More information about the Mlir-commits
mailing list