[Mlir-commits] [mlir] fae258e - [mlir][memref] Add initial Wide Int Emulation pass and patterns
Jakub Kuderski
llvmlistbot at llvm.org
Fri Oct 14 08:38:20 PDT 2022
Author: Jakub Kuderski
Date: 2022-10-14T11:37:52-04:00
New Revision: fae258e6c6194ddd88894b79053b0fd31df5990d
URL: https://github.com/llvm/llvm-project/commit/fae258e6c6194ddd88894b79053b0fd31df5990d
DIFF: https://github.com/llvm/llvm-project/commit/fae258e6c6194ddd88894b79053b0fd31df5990d.diff
LOG: [mlir][memref] Add initial Wide Int Emulation pass and patterns
Add a new pass and conversions to emulate wide integer operations over memrefs.
The emulation is implemented on top of the existing pass to emulate wide integer arith ops.
Improve naming in the arith pass to avoid potential name clashes.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D135722
Added:
mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/MemRef/emulate-wide-int.mlir
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 5e441a7020708..d087ac69828a9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -28,8 +28,8 @@ std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
/// Adds patterns to emulate wide Arith and Function ops over integer
/// types into supported ones. This is done by splitting original power-of-two
/// i2N integer types into two iN halves.
-void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateArithWideIntEmulationPatterns(
+ WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
/// Add patterns to expand Arith ceil/floor division ops.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index e642dc572a0af..16ef294a90d28 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -52,9 +52,9 @@ def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
let summary = "Emulate 2*N-bit integer operations using N-bit operations";
let description = [{
- Emulate integer operations that use too wide integer types with equivalent
- operations on supported narrow integer types. This is done by splitting
- original integer values into two halves.
+ Emulate arith integer operations that use too wide integer types with
+ equivalent operations on supported narrow integer types. This is done by
+ splitting original integer values into two halves.
This pass is intended preserve semantics but not necessarily provide the
most efficient implementation.
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 2a7b5d82a4cdb..ee30e6e252dff 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -20,6 +20,10 @@ namespace mlir {
class AffineDialect;
class ModuleOp;
+namespace arith {
+class WideIntEmulationConverter;
+} // namespace arith
+
namespace func {
class FuncDialect;
} // namespace func
@@ -60,6 +64,17 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
void populateSimplifyExtractStridedMetadataOpPatterns(
RewritePatternSet &patterns);
+/// Appends patterns for emulating wide integer memref operations with ops over
+/// narrower integer types.
+void populateMemRefWideIntEmulationPatterns(
+ arith::WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating wide integer memref operations with
+/// ops over narrower integer types.
+void populateMemRefWideIntEmulationConversions(
+ arith::WideIntEmulationConverter &typeConverter);
+
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 64045033cabef..b41676482a889 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -28,6 +28,22 @@ def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
];
}
+def MemRefEmulateWideInt : Pass<"memref-emulate-wide-int"> {
+ let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+ let description = [{
+ Emulate memref integer operations that use too wide integer types with
+ equivalent operations on supported narrow integer types. This is done by
+ splitting original integer values into two halves.
+
+ Currently, only power-of-two integer bitwidths are supported.
+ }];
+ let options = [
+ Option<"widestIntSupported", "widest-int-supported", "unsigned",
+ /*default=*/"32", "Widest integer type supported by the target">,
+ ];
+ let dependentDialects = ["vector::VectorDialect"];
+}
+
def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
let summary = "Normalize memrefs";
let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 9784f0d4f92e2..826c8ee96d419 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -745,7 +745,7 @@ struct EmulateWideIntPass final
opLegalCallback);
RewritePatternSet patterns(ctx);
- arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+ arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -817,7 +817,7 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
});
}
-void arith::populateWideIntEmulationPatterns(
+void arith::populateArithWideIntEmulationPatterns(
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
// Populate `func.*` conversion patterns.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 967a85d053bfe..2e2ffb491fb75 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMemRefTransforms
ComposeSubView.cpp
ExpandOps.cpp
+ EmulateWideInt.cpp
FoldMemRefAliasOps.cpp
MultiBuffer.cpp
NormalizeMemRefs.cpp
@@ -17,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
+ MLIRArithTransforms
MLIRFuncDialect
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
new file mode 100644
index 0000000000000..02c6e5873bea6
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -0,0 +1,163 @@
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::memref {
+#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace mlir::memref
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+
+ rewriter.replaceOpWithNewOp<memref::AllocOp>(
+ op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newResTy = getTypeConverter()->convertType(op.getType());
+ if (!newResTy)
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemRefType()));
+
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ op, newResTy, adaptor.getMemref(), adaptor.getIndices());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemRefType()));
+
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct EmulateWideIntPass final
+ : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
+ using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
+
+ void runOnOperation() override {
+ if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
+ signalPassFailure();
+ return;
+ }
+
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+
+ arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+ memref::populateMemRefWideIntEmulationConversions(typeConverter);
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalDialect<
+ arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
+ [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+ RewritePatternSet patterns(ctx);
+ // Add common patterns to support constants, functions, etc.
+ arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
+
+ memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
+
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefWideIntEmulationPatterns(
+ arith::WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ // Populate `memref.*` conversion patterns.
+ patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
+ typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefWideIntEmulationConversions(
+ arith::WideIntEmulationConverter &typeConverter) {
+ typeConverter.addConversion(
+ [&typeConverter](MemRefType ty) -> Optional<Type> {
+ auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ if (!intTy)
+ return ty;
+
+ if (intTy.getIntOrFloatBitWidth() <=
+ typeConverter.getMaxTargetIntBitWidth())
+ return ty;
+
+ Type newElemTy = typeConverter.convertType(intTy);
+ if (!newElemTy)
+ return None;
+
+ return ty.cloneWith(None, newElemTy);
+ });
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
new file mode 100644
index 0000000000000..de1cba5c0477f
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: return
+func.func @memref_i32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i32
+ %m = memref.alloc() : memref<4xi32, 1>
+ %v = memref.load %m[%c0] : memref<4xi32, 1>
+ memref.store %c1, %m[%c0] : memref<4xi32, 1>
+ return
+}
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1.0 : f32
+ %m = memref.alloc() : memref<4xf32, 1>
+ %v = memref.load %m[%c0] : memref<4xf32, 1>
+ memref.store %c1, %m[%c0] : memref<4xf32, 1>
+ return
+}
+
+// CHECK-LABEL: func @alloc_load_store_i64
+// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
+// CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: memref.store [[C1]], [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: return
+func.func @alloc_load_store_i64() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i64
+ %m = memref.alloc() : memref<4xi64, 1>
+ %v = memref.load %m[%c0] : memref<4xi64, 1>
+ memref.store %c1, %m[%c0] : memref<4xi64, 1>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
index ee84eadcbcdf4..c1ae321711fcd 100644
--- a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
+++ b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
@@ -74,7 +74,7 @@ struct TestEmulateWideIntPass
});
RewritePatternSet patterns(ctx);
- arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+ arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
More information about the Mlir-commits
mailing list