[Mlir-commits] [mlir] 571831a - [mlir] Add sub-byte type emulation support for `memref.collapse_shape` (#89962)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 26 08:29:10 PDT 2024
Author: Diego Caballero
Date: 2024-04-26T17:29:06+02:00
New Revision: 571831a680faa9615183f855fddf43fd1a9ba192
URL: https://github.com/llvm/llvm-project/commit/571831a680faa9615183f855fddf43fd1a9ba192
DIFF: https://github.com/llvm/llvm-project/commit/571831a680faa9615183f855fddf43fd1a9ba192.diff
LOG: [mlir] Add sub-byte type emulation support for `memref.collapse_shape` (#89962)
This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` becomes a no-op given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// there is no dimension to collapse any further.
+struct ConvertMemRefCollapseShape final
+ : OpConversionPattern<memref::CollapseShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value srcVal = adaptor.getSrc();
+ auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+ if (!newTy)
+ return failure();
+
+ if (newTy.getRank() != 1)
+ return failure();
+
+ rewriter.replaceOp(collapseShapeOp, srcVal);
+ return success();
+ }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAllocation<memref::AllocaOp>,
+ ConvertMemRefCollapseShape, ConvertMemRefLoad,
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
// CHECK32: return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+ %arr = memref.alloc() : memref<32x8x128xi4>
+ %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+ %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+ return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.collapse_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.collapse_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
More information about the Mlir-commits
mailing list