[Mlir-commits] [mlir] 2b06650 - [mlir][spirv] Fix integer type emulation with extension/truncation

Lei Zhang llvmlistbot at llvm.org
Sat Aug 12 18:42:41 PDT 2023


Author: Lei Zhang
Date: 2023-08-12T18:41:34-07:00
New Revision: 2b066501b1bcb21c408310e6cfca31ba02068736

URL: https://github.com/llvm/llvm-project/commit/2b066501b1bcb21c408310e6cfca31ba02068736
DIFF: https://github.com/llvm/llvm-project/commit/2b066501b1bcb21c408310e6cfca31ba02068736.diff

LOG: [mlir][spirv] Fix integer type emulation with extension/truncation

For integer extension or truncation with type emulation, we need
to make sure we perform masking or shifting to discard unwanted
bits and avoid polluting consumer ops.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D157788

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a8692a281366ba..9a1b7ade788e68 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <memory>
 
@@ -108,6 +109,22 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Creates a scalar or vector integer constant of `type` holding `value`
+/// (splatted across all lanes for vectors).
+///
+/// Returns a null Value if `type` is neither an integer type nor a vector of
+/// integers; callers must check the result before using it.
+static Value getScalarOrVectorConstInt(Type type, uint64_t value,
+                                       OpBuilder &builder, Location loc) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    Type elementType = vectorType.getElementType();
+    // Only integer element types can carry an IntegerAttr splat; report
+    // failure the same way as the non-integer scalar case below.
+    if (!isa<IntegerType>(elementType))
+      return nullptr;
+    Attribute element = IntegerAttr::get(elementType, value);
+    auto attr = SplatElementsAttr::get(vectorType, element);
+    return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+  }
+
+  if (isa<IntegerType>(type))
+    return builder.create<spirv::ConstantOp>(
+        loc, type, builder.getIntegerAttr(type, value));
+
+  return nullptr;
+}
+
 /// Returns true if scalar/vector type `a` and `b` have the same number of
 /// bitwidth.
 static bool hasSameBitwidth(Type a, Type b) {
@@ -525,6 +542,53 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
   }
 };
 
+/// Converts arith.extsi for cases where the type of source is neither i1 nor
+/// vector of i1.
+///
+/// With integer type emulation the adapted source value may be stored in a
+/// wider type than the original source type (e.g. an i4 held in an i32). In
+/// that case the value is first sign-extended in place within its storage
+/// type: a logical left shift moves the original sign bit into the top bit,
+/// and an arithmetic right shift replicates it back down. A spirv.SConvert is
+/// then emitted only if the converted result type differs from the storage
+/// type.
+struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getIn();
+    Type srcType = input.getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // The original source bitwidth vs. the bitwidth it is actually stored in
+    // after conversion; they differ under type emulation.
+    unsigned origSrcBW =
+        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+    unsigned storedSrcBW =
+        getElementTypeOrSelf(srcType).getIntOrFloatBitWidth();
+
+    if (origSrcBW < storedSrcBW) {
+      // The shift amount must be relative to the storage bitwidth, not the
+      // original result bitwidth: for i4 -> i8 with both emulated as i32,
+      // shifting by 8 - 4 would not move bit 3 into the sign position.
+      Value shiftSize = getScalarOrVectorConstInt(
+          srcType, storedSrcBW - origSrcBW, rewriter, op.getLoc());
+      if (!shiftSize)
+        return rewriter.notifyMatchFailure(op, "unsupported type for shift");
+
+      // First shift left to squeeze out all bits beyond the original
+      // bitwidth, leaving the original sign bit as the top bit.
+      Value shifted = rewriter.create<spirv::ShiftLeftLogicalOp>(
+          op.getLoc(), srcType, input, shiftSize);
+
+      // Then arithmetic-shift right so negative values get proper sign bits.
+      input = rewriter.create<spirv::ShiftRightArithmeticOp>(
+          op.getLoc(), srcType, shifted, shiftSize);
+    }
+
+    // Source and destination can share a storage type due to emulation; in
+    // that case the in-place sign extension above is the whole lowering.
+    if (dstType == srcType)
+      rewriter.replaceOp(op, input);
+    else
+      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, input);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
@@ -554,6 +618,42 @@ struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
   }
 };
 
+/// Converts arith.extui for cases where the type of source is neither i1 nor
+/// vector of i1.
+///
+/// With integer type emulation the adapted source value may be stored in a
+/// wider type than the original source type (e.g. an i4 held in an i32). In
+/// that case the bits above the original bitwidth are cleared with a
+/// spirv.BitwiseAnd mask before anything else, so they don't pollute
+/// downstream consumers. A spirv.UConvert is then emitted only if the
+/// converted result type differs from the storage type.
+struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getIn();
+    Type srcType = input.getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // The original source bitwidth vs. the bitwidth it is actually stored in
+    // after conversion; they differ under type emulation.
+    unsigned origSrcBW =
+        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+    unsigned storedSrcBW =
+        getElementTypeOrSelf(srcType).getIntOrFloatBitWidth();
+
+    if (origSrcBW < storedSrcBW) {
+      // Mask in the storage type *before* any UConvert: UConvert zero-extends
+      // from the storage width, so stale high bits would otherwise survive.
+      Value mask = getScalarOrVectorConstInt(
+          srcType, llvm::maskTrailingOnes<uint64_t>(origSrcBW), rewriter,
+          op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported type for mask");
+      input = rewriter.create<spirv::BitwiseAndOp>(op.getLoc(), srcType,
+                                                   input, mask);
+    }
+
+    // Source and destination can share a storage type due to emulation; in
+    // that case the masking above is the whole lowering.
+    if (dstType == srcType)
+      rewriter.replaceOp(op, input);
+    else
+      rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType, input);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -588,6 +688,41 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
   }
 };
 
+/// Converts arith.trunci for cases where the type of result is neither i1
+/// nor vector of i1.
+///
+/// If the converted result type differs from the adapted source type, the
+/// value is narrowed with spirv.SConvert. If the converted result type is
+/// wider than the original result type because of type emulation (e.g. an i4
+/// result stored in an i32), the bits above the original bitwidth are then
+/// cleared with a spirv.BitwiseAnd so they don't pollute downstream consumers.
+struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getIn();
+    Type srcType = input.getType();
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (isBoolScalarOrVector(dstType))
+      return failure();
+
+    // The original result bitwidth vs. the bitwidth it is stored in after
+    // conversion; they differ under type emulation.
+    unsigned origDstBW =
+        getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
+    unsigned storedDstBW =
+        getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+
+    // Build the mask up front so an unsupported type fails before any other
+    // op is created.
+    Value mask;
+    if (origDstBW < storedDstBW) {
+      mask = getScalarOrVectorConstInt(
+          dstType, llvm::maskTrailingOnes<uint64_t>(origDstBW), rewriter,
+          op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported type for mask");
+    }
+
+    Value result = input;
+    if (dstType != srcType) {
+      // Given this is truncation, either SConvertOp or UConvertOp works.
+      result = rewriter.create<spirv::SConvertOp>(op.getLoc(), dstType, input);
+    }
+
+    // Masking is needed even after an SConvert (e.g. i64 -> emulated i4),
+    // since the conversion only narrows to the storage width.
+    if (mask)
+      rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, result,
+                                                       mask);
+    else
+      rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // TypeCastingOp
 //===----------------------------------------------------------------------===//
@@ -981,10 +1116,10 @@ void mlir::arith::populateArithToSPIRVPatterns(
     spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
-    TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
-    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, ExtSII1Pattern,
+    ExtUIPattern, ExtUII1Pattern,
+    ExtSIPattern, ExtSII1Pattern,
     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
-    TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
+    TruncIPattern, TruncII1Pattern,
     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index d70df982c366ad..604f85757537de 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -990,6 +990,38 @@ func.func @fpext2(%arg0 : f32) -> f64 {
   return %0: f64
 }
 
+// CHECK-LABEL: @trunci4
+//  CHECK-SAME: %[[IN:.*]]: i32
+func.func @trunci4(%arg0 : i32) -> i4 {
+  // Truncating to an emulated i4 keeps the i32 storage type and clears the
+  // bits above bit 3 with a mask of 0xF.
+  // CHECK: %[[LOW_BITS:.+]] = spirv.Constant 15 : i32
+  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[IN]], %[[LOW_BITS]] : i32
+  %0 = arith.trunci %arg0 : i32 to i4
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[MASKED]] : i32 to i4
+  // CHECK: return %[[CAST]] : i4
+  return %0 : i4
+}
+
+// CHECK-LABEL: @zexti4
+//  CHECK-SAME: %[[ARG:.*]]: i4
+func.func @zexti4(%arg0: i4) -> i32 {
+  // The emulated i4 input is stored in an i32; zero extension clears every
+  // bit above bit 3 with a mask of 0xF.
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : i4 to i32
+  // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
+  %0 = arith.extui %arg0 : i4 to i32
+  // CHECK: return %[[AND]] : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @sexti4
+//  CHECK-SAME: %[[ARG:.*]]: i4
+func.func @sexti4(%arg0: i4) -> i32 {
+  // The emulated i4 input is stored in an i32; sign extension shifts bit 3
+  // into the top bit and arithmetic-shifts it back down by 32 - 4 = 28.
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : i4 to i32
+  // CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : i32, i32
+  %0 = arith.extsi %arg0 : i4 to i32
+  // CHECK: return %[[SR]] : i32
+  return %0 : i32
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list