[Mlir-commits] [mlir] [mlir] Add sub-byte type emulation support for `memref.collapse_shape` (PR #89962)

Diego Caballero llvmlistbot at llvm.org
Wed Apr 24 10:53:33 PDT 2024


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/89962

This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` becomes a no-op given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).

>From c15f9d4ad6cd7626e8cba071c6cad1c935c5c875 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 24 Apr 2024 16:59:52 +0000
Subject: [PATCH] [mlir] Add sub-byte type emulation support for
 `memref.collapse_shape`

This PR adds support for `memref.collapse_shape` to sub-byte type
emulation. The `memref.collapse_shape` becomes a no-op given that we
are flattening the memref as part of the emulation (i.e., we are
collapsing all the dimensions).
---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 32 +++++++++++++++++--
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 20 ++++++++++++
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <type_traits>
 
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulating a `memref.collapse_shape` is a no-op: memrefs are flattened to a
+/// single dimension as part of the emulation, so there is no dimension left
+/// to collapse and the converted source can be forwarded unchanged.
+struct ConvertMemRefCollapseShape final
+    : OpConversionPattern<memref::CollapseShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    // Forwarding the source is only valid if it is already 1-D and the result
+    // converts to the same type (e.g., not for a rank-0 collapse).
+    Type resTy = getTypeConverter()->convertType(collapseShapeOp.getType());
+    if (!newTy || newTy.getRank() != 1 || resTy != newTy)
+      return failure();
+
+    rewriter.replaceOp(collapseShapeOp, srcVal);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAllocation<memref::AllocaOp>,
+               ConvertMemRefCollapseShape, ConvertMemRefLoad,
                ConvertMemrefStore, ConvertMemRefAssumeAlignment,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 //       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
 //       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
 //       CHECK32:   return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+  %arr = memref.alloc() : memref<32x8x128xi4>
+  %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+  %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+  return %1 : i4
+}
+
+// CHECK-LABEL:   func.func @memref_collapse_shape_i4(
+//       CHECK:     %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+//   CHECK-NOT:     memref.collapse_shape
+//       CHECK:     memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL:   func.func @memref_collapse_shape_i4(
+//       CHECK32:     %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+//   CHECK32-NOT:     memref.collapse_shape
+//       CHECK32:     memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+



More information about the Mlir-commits mailing list