[Mlir-commits] [mlir] f5aee1f - [mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type (#112214)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 16 18:08:27 PDT 2024


Author: Longsheng Mou
Date: 2024-10-17T09:08:24+08:00
New Revision: f5aee1f18bdbc5694330a5e86eb46cf60e653d0c

URL: https://github.com/llvm/llvm-project/commit/f5aee1f18bdbc5694330a5e86eb46cf60e653d0c
DIFF: https://github.com/llvm/llvm-project/commit/f5aee1f18bdbc5694330a5e86eb46cf60e653d0c.diff

LOG: [mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type (#112214)

This PR follows up on #112104, returning `nullptr` from the type conversion
callbacks to indicate that type conversion failed and that no fallback
conversion should be attempted.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
    mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
    mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
    mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
    mlir/test/Dialect/MemRef/emulate-wide-int.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
index 4be0e06fe2a5e5..fddd7c51bfbc87 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
   addConversion([this](FunctionType ty) -> std::optional<Type> {
     SmallVector<Type> inputs;
     if (failed(convertTypes(ty.getInputs(), inputs)))
-      return std::nullopt;
+      return nullptr;
 
     SmallVector<Type> results;
     if (failed(convertTypes(ty.getResults(), results)))
-      return std::nullopt;
+      return nullptr;
 
     return FunctionType::get(ty.getContext(), inputs, results);
   });

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9efea066a03c85..28f9061d9873b7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -169,8 +169,9 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
                       std::is_same<OpTy, memref::AllocaOp>(),
                   "expected only memref::AllocOp or memref::AllocaOp");
     auto currentType = cast<MemRefType>(op.getMemref().getType());
-    auto newResultType = dyn_cast<MemRefType>(
-        this->getTypeConverter()->convertType(op.getType()));
+    auto newResultType =
+        this->getTypeConverter()->template convertType<MemRefType>(
+            op.getType());
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -378,7 +379,7 @@ struct ConvertMemRefReinterpretCast final
   matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+        getTypeConverter()->convertType<MemRefType>(op.getType());
     if (!newTy) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -466,8 +467,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   LogicalResult
   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy = dyn_cast<MemRefType>(
-        getTypeConverter()->convertType(subViewOp.getType()));
+    MemRefType newTy =
+        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
     if (!newTy) {
       return rewriter.notifyMatchFailure(
           subViewOp->getLoc(),
@@ -632,14 +633,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         SmallVector<int64_t> strides;
         int64_t offset;
         if (failed(getStridesAndOffset(ty, strides, offset)))
-          return std::nullopt;
+          return nullptr;
         if (!strides.empty() && strides.back() != 1)
-          return std::nullopt;
+          return nullptr;
 
         auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                           intTy.getSignedness());
         if (!newElemTy)
-          return std::nullopt;
+          return nullptr;
 
         StridedLayoutAttr layoutAttr;
         // If the offset is 0, we do not need a strided layout as the stride is

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index bc4535f97acf04..49b71625291db9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
 
         Type newElemTy = typeConverter.convertType(intTy);
         if (!newElemTy)
-          return std::nullopt;
+          return nullptr;
 
         return ty.cloneWith(std::nullopt, newElemTy);
       });

diff  --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 540da239fced08..1d6cbfa343ba5d 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -203,7 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
 
 // -----
 
-
 func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
   %c0 = arith.constant 0 : index
   %arr = memref.alloc() : memref<40x40xi4>
@@ -543,13 +542,15 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
 
 // -----
 
-!colMajor = memref<8x8xi4, strided<[1, 8]>>
-func.func @copy_distinct_layouts(%idx : index) -> i4 {
-  %c0 = arith.constant 0 : index
-  %arr = memref.alloc() : memref<8x8xi4>
-  %arr2 = memref.alloc() : !colMajor
-  // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
-  memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
-  %ld = memref.load %arr2[%c0, %c0] : !colMajor
-  return %ld : i4
+func.func @alloc_non_contiguous() {
+  // expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+  %arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
+  return
+}
+
+// -----
+
+// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
+  return
 }

diff  --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
index 65ac5beed0a1de..994e400bd73c1b 100644
--- a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s \
+// RUN:   --split-input-file --verify-diagnostics | FileCheck %s
 
 // Expect no conversions, i32 is supported.
 // CHECK-LABEL: func @memref_i32
@@ -15,6 +16,8 @@ func.func @memref_i32() {
     return
 }
 
+// -----
+
 // Expect no conversions, f64 is not an integer type.
 // CHECK-LABEL: func @memref_f32
 // CHECK:         [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
@@ -30,6 +33,8 @@ func.func @memref_f32() {
     return
 }
 
+// -----
+
 // CHECK-LABEL: func @alloc_load_store_i64
 // CHECK:         [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
 // CHECK-NEXT:    [[M:%.+]]  = memref.alloc() : memref<4xvector<2xi32>, 1>
@@ -45,6 +50,7 @@ func.func @alloc_load_store_i64() {
     return
 }
 
+// -----
 
 // CHECK-LABEL: func @alloc_load_store_i64_nontemporal
 // CHECK:         [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
@@ -60,3 +66,30 @@ func.func @alloc_load_store_i64_nontemporal() {
     memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1>
     return
 }
+
+// -----
+
+// Make sure we do not crash on unsupported types.
+func.func @alloc_i128() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+  %m = memref.alloc() : memref<4xi128, 1>
+  return
+}
+
+// -----
+
+func.func @load_i128(%m: memref<4xi128, 1>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
+  %v = memref.load %m[%c0] : memref<4xi128, 1>
+  return
+}
+
+// -----
+
+func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
+  memref.store %c1, %m[%c0] : memref<4xi128, 1>
+  return
+}


        


More information about the Mlir-commits mailing list