[Mlir-commits] [mlir] 0c7f3d6 - [mlir] Allow to specify target type in `convertType`

Jakub Kuderski llvmlistbot at llvm.org
Wed Apr 19 10:54:40 PDT 2023


Author: Jakub Kuderski
Date: 2023-04-19T13:51:13-04:00
New Revision: 0c7f3d6c39d7dd7d240854ac347a801d62dc938b

URL: https://github.com/llvm/llvm-project/commit/0c7f3d6c39d7dd7d240854ac347a801d62dc938b
DIFF: https://github.com/llvm/llvm-project/commit/0c7f3d6c39d7dd7d240854ac347a801d62dc938b.diff

LOG: [mlir] Allow to specify target type in `convertType`

Add a new helper function for the type converter that takes care of
casting to the target type.

This is to avoid bugs where an incorrect cast function is used after
type conversion, e.g., `dyn_cast` or `cast`. These are not guaranteed to
work when type conversion fails, or when type conversion succeeds but
the provided type converter returned a type that a conversion pattern
did not expect.

I saw this being an issue in some SPIR-V passes and in mlir-hlo.

Exercise the new function in a couple of passes. As a side-effect, this
also made the code more concise.

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D148725

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b5127c099a366..2821686655c2f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -223,6 +223,14 @@ class TypeConverter {
   /// the type to convert to on success, and a null type on failure.
   Type convertType(Type t);
 
+  /// Attempts a 1-1 type conversion, expecting the result type to be
+  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
+  /// and a null type on conversion or cast failure.
+  template <typename TargetType>
+  TargetType convertType(Type t) {
+    return dyn_cast_or_null<TargetType>(convertType(t));
+  }
+
   /// Convert the given set of types, filling 'results' as necessary. This
   /// returns failure if the conversion of any of the types fails, success
   /// otherwise.

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 35aa87e11a751..9c74feb597c67 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
@@ -410,9 +411,12 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
-  Type pointeeType = typeConverter.convertType(memrefType)
-                         .cast<spirv::PointerType>()
-                         .getPointeeType();
+
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
     if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
@@ -541,9 +545,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
 
-  Type pointeeType = typeConverter.convertType(memrefType)
-                         .cast<spirv::PointerType>()
-                         .getPointeeType();
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
     if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())

diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 96a58459a37b9..15e9b7f430806 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -206,7 +207,11 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
   matchAndRewrite(arith::ConstantOp op, OpAdaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type oldType = op.getType();
-    auto newType = getTypeConverter()->convertType(oldType).cast<VectorType>();
+    auto newType = getTypeConverter()->convertType<VectorType>(oldType);
+    if (!newType)
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("unsupported type: {0}", op.getType()));
+
     unsigned newBitWidth = newType.getElementTypeBitWidth();
     Attribute oldValue = op.getValueAttr();
 
@@ -264,9 +269,7 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
   matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -307,9 +310,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
   matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = this->getTypeConverter()
-                     ->convertType(op.getType())
-                     .template dyn_cast_or_null<VectorType>();
+    auto newTy = this->getTypeConverter()->template convertType<VectorType>(
+        op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -357,9 +359,8 @@ struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto inputTy = getTypeConverter()
-                       ->convertType(op.getLhs().getType())
-                       .dyn_cast_or_null<VectorType>();
+    auto inputTy =
+        getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
     if (!inputTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -414,9 +415,7 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
   matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -457,9 +456,7 @@ struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -497,9 +494,7 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -577,9 +572,8 @@ struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
 
     Location loc = op.getLoc();
     Type inType = op.getIn().getType();
-    auto newInTy = this->getTypeConverter()
-                       ->convertType(inType)
-                       .template dyn_cast_or_null<VectorType>();
+    auto newInTy =
+        this->getTypeConverter()->template convertType<VectorType>(inType);
     if (!newInTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", inType));
@@ -608,8 +602,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
         this->template getTypeConverter<arith::WideIntEmulationConverter>();
 
     Type resultType = op.getType();
-    auto newTy = typeConverter->convertType(resultType)
-                     .template dyn_cast_or_null<VectorType>();
+    auto newTy = typeConverter->template convertType<VectorType>(resultType);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", resultType));
@@ -640,9 +633,7 @@ struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto newTy = getTypeConverter()
-                     ->convertType(op.getType())
-                     .dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -677,8 +668,7 @@ struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy =
-        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -767,8 +757,7 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy =
-        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -857,8 +846,7 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy =
-        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -922,8 +910,7 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
 
     Value in = op.getIn();
     Type oldTy = in.getType();
-    auto newTy =
-        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", oldTy));
@@ -967,8 +954,7 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
     Location loc = op.getLoc();
 
     Type oldTy = op.getIn().getType();
-    auto newTy =
-        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
     if (!newTy)
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", oldTy));


        


More information about the Mlir-commits mailing list