[Mlir-commits] [mlir] [mlir][linalg] Add runtime verification for linalg ops (PR #89917)
Ryan Holt
llvmlistbot at llvm.org
Thu Apr 25 05:03:56 PDT 2024
https://github.com/ryan-holt-1 updated https://github.com/llvm/llvm-project/pull/89917
From ade076e1f85f51621b193dedfdb65b161c5c9ecb Mon Sep 17 00:00:00 2001
From: Ryan Holt <ryanholt at mathworks.com>
Date: Tue, 23 Apr 2024 11:18:04 -0400
Subject: [PATCH] [mlir][linalg] Add runtime verification for linalg ops
(#89342)
This commit implements runtime verification for LinalgStructuredOps
using the existing `RuntimeVerifiableOpInterface`. The verification
checks that the runtime sizes of the operands match the runtime sizes
inferred by composing the loop ranges with the op's indexing maps.
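
For example, given a rank-1 elementwise op such as

  %0 = linalg.generic {
         indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
         iterator_types = ["parallel"]
       } ins(%in : tensor<?xf32>) outs(%out : tensor<?xf32>) { ... }

the pass emits checks roughly of the following shape (an illustrative
sketch only; the SSA names are invented and the actual output may be
folded differently):

  %c0 = arith.constant 0 : index
  %in_dim = tensor.dim %in, %c0 : tensor<?xf32>
  %out_dim = tensor.dim %out, %c0 : tensor<?xf32>
  %same = index.cmp eq(%out_dim, %in_dim)
  cf.assert %same, "ERROR: Runtime op verification failed ..."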
---
.../Linalg/Transforms/RuntimeOpVerification.h | 21 ++
mlir/include/mlir/InitAllDialects.h | 2 +
.../RuntimeVerifiableOpInterface.td | 6 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 2 +
.../Transforms/RuntimeOpVerification.cpp | 135 ++++++++
.../Transforms/RuntimeOpVerification.cpp | 54 ++--
.../RuntimeVerifiableOpInterface.cpp | 21 ++
.../Dialect/Linalg/runtime-verification.mlir | 43 +++
.../Linalg/CPU/runtime-verification.mlir | 298 ++++++++++++++++++
9 files changed, 549 insertions(+), 33 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
create mode 100644 mlir/test/Dialect/Linalg/runtime-verification.mlir
create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
new file mode 100644
index 00000000000000..6c3643f7835cbe
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
@@ -0,0 +1,21 @@
+//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerRuntimeVerifiableOpInterfaceExternalModels(
+ DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c4d788cf8ed316..d9db21073e15c7 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,6 +45,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
@@ -161,6 +162,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
+ linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index d5f11d00cc3d2a..6fd0df59d9d2e0 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -35,6 +35,12 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
"::mlir::Location":$loc)
>,
];
+
+ let extraClassDeclaration = [{
+ /// Generate the error message that will be printed to the user when
+ /// verification fails.
+ static std::string generateErrorMessage(Operation *op, const std::string &msg);
+ }];
}
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..3b5282a09569d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
+ RuntimeOpVerification.cpp
Specialize.cpp
Split.cpp
SplitReduction.cpp
@@ -60,6 +61,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRFuncDialect
MLIRFuncToLLVM
MLIRFuncTransforms
+ MLIRIndexDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
new file mode 100644
index 00000000000000..b30182dc84079f
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -0,0 +1,135 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexAttrs.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+namespace linalg {
+namespace {
+/// Verify that the runtime sizes of the operands to linalg structured ops are
+/// compatible with the runtime sizes inferred by composing the loop ranges with
+/// the linalg op's indexing maps. This is similar to the verifier except that
+/// here we insert IR to perform the verification at runtime.
+template <typename T>
+struct StructuredOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<
+ StructuredOpInterface<T>, T> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto linalgOp = llvm::cast<LinalgOp>(op);
+
+ SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
+ auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
+
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+ // Subtract one from the loop ends before composing with the indexing map.
+ transform(ends, ends.begin(), [&](OpFoldResult end) {
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+ return builder.createOrFold<index::SubOp>(loc, endValue, one);
+ });
+
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+ auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
+ builder, loc, indexingMap, starts);
+ auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
+ builder, loc, indexingMap, ends);
+
+ for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
+ auto startIndex =
+ getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
+ auto endIndex =
+ getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
+
+ // Generate:
+ // minIndex = min(startIndex, endIndex)
+ // assert(minIndex >= 0)
+ // This ensures that we never generate a negative index. We take the
+ // minimum of the start and end indices to handle reverse loops such as
+ // `affine_map<(i) -> (3 - i)>`.
+ auto min =
+ builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
+ auto cmpOp = builder.createOrFold<index::CmpOp>(
+ loc, index::IndexCmpPredicate::SGE, min, zero);
+ auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+ linalgOp, "unexpected negative result on dimension #" +
+ std::to_string(dim) + " of input/output operand #" +
+ std::to_string(opOperand.getOperandNumber()));
+ builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
+
+ // Generate:
+ // inferredDimSize = max(startIndex, endIndex) + 1
+ // actualDimSize = dim(operand)
+ // assert(inferredDimSize <= actualDimSize)
+ // This ensures that we do not index past the bounds of the operands.
+ auto max =
+ builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
+
+ auto inferredDimSize =
+ builder.createOrFold<index::AddOp>(loc, max, one);
+
+ auto actualDimSize =
+ createOrFoldDimOp(builder, loc, opOperand.get(), dim);
+
+ // As in the static verifier, when the affine expression in the indexing
+ // map is complicated, we only check that the inferred dimension size
+ // does not exceed the operand's size; being more precise than that is
+ // difficult.
+ auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
+ ? index::IndexCmpPredicate::EQ
+ : index::IndexCmpPredicate::SLE;
+
+ cmpOp = builder.createOrFold<index::CmpOp>(
+ loc, predicate, inferredDimSize, actualDimSize);
+ msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+ linalgOp, "dimension #" + std::to_string(dim) +
+ " of input/output operand #" +
+ std::to_string(opOperand.getOperandNumber()) +
+ " is incompatible with inferred dimension size");
+ builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
+ }
+ }
+ }
+};
+
+template <typename... OpTs>
+void attachInterface(MLIRContext *ctx) {
+ (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
+}
+} // namespace
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
+ attachInterface<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(ctx);
+
+ // Load additional dialects whose ops may be created.
+ ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
+ cf::ControlFlowDialect, index::IndexDialect,
+ tensor::TensorDialect>();
+ });
+}
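For the reverse-loop case mentioned in the comments above,
`affine_map<(i) -> (3 - i)>` on a rank-1 operand with loop upper bound
%ub, the generated checks correspond roughly to the following (an
illustrative sketch; %start, %end, %ub, and %operand are invented names,
and constants may be folded away):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Indexing map composed with the loop range endpoints:
  // start = 3 - 0, end = 3 - (%ub - 1).
  %min = index.mins %start, %end
  %nonneg = index.cmp sge(%min, %c0)
  cf.assert %nonneg, "...unexpected negative result on dimension #0..."
  %max = index.maxs %start, %end
  %inferred = index.add %max, %c1
  %actual = tensor.dim %operand, %c0 : tensor<?xf32>
  %ok = index.cmp sle(%inferred, %actual)
  cf.assert %ok, "...is incompatible with inferred dimension size"
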
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 05b813a3b1e908..450bfa0cec0c7f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -20,25 +20,6 @@
using namespace mlir;
-/// Generate an error message string for the given op and the specified error.
-static std::string generateErrorMessage(Operation *op, const std::string &msg) {
- std::string buffer;
- llvm::raw_string_ostream stream(buffer);
- OpPrintingFlags flags;
- // We may generate a lot of error messages and so we need to ensure the
- // printing is fast.
- flags.elideLargeElementsAttrs();
- flags.printGenericOpForm();
- flags.skipRegions();
- flags.useLocalScope();
- stream << "ERROR: Runtime op verification failed\n";
- op->print(stream, flags);
- stream << "\n^ " << msg;
- stream << "\nLocation: ";
- op->getLoc().print(stream);
- return stream.str();
-}
-
namespace mlir {
namespace memref {
namespace {
@@ -62,8 +43,10 @@ struct CastOpInterface
builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
Value isSameRank = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
- builder.create<cf::AssertOp>(loc, isSameRank,
- generateErrorMessage(op, "rank mismatch"));
+ builder.create<cf::AssertOp>(
+ loc, isSameRank,
+ RuntimeVerifiableOpInterface::generateErrorMessage(op,
+ "rank mismatch"));
}
// Get source offset and strides. We do not have an op to get offsets and
@@ -101,8 +84,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
builder.create<cf::AssertOp>(
loc, isSameSz,
- generateErrorMessage(op, "size mismatch of dim " +
- std::to_string(it.index())));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "size mismatch of dim " + std::to_string(it.index())));
}
// Get result offset and strides.
@@ -119,8 +102,10 @@ struct CastOpInterface
builder.create<arith::ConstantIndexOp>(loc, resultOffset);
Value isSameOffset = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
- builder.create<cf::AssertOp>(loc, isSameOffset,
- generateErrorMessage(op, "offset mismatch"));
+ builder.create<cf::AssertOp>(
+ loc, isSameOffset,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "offset mismatch"));
}
// Check strides.
@@ -137,8 +122,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
builder.create<cf::AssertOp>(
loc, isSameStride,
- generateErrorMessage(op, "stride mismatch of dim " +
- std::to_string(it.index())));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "stride mismatch of dim " + std::to_string(it.index())));
}
}
};
@@ -178,7 +163,9 @@ struct LoadStoreOpInterface
: andOp;
}
builder.create<cf::AssertOp>(
- loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
+ loc, assertCond,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "out-of-bounds access"));
}
};
@@ -248,7 +235,7 @@ struct ReinterpretCastOpInterface
builder.create<cf::AssertOp>(
loc, assertCond,
- generateErrorMessage(
+ RuntimeVerifiableOpInterface::generateErrorMessage(
op,
"result of reinterpret_cast is out-of-bounds of the base memref"));
}
@@ -293,8 +280,8 @@ struct SubViewOpInterface
builder.create<cf::AssertOp>(
loc, assertCond,
- generateErrorMessage(op,
- "subview is out-of-bounds of the base memref"));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "subview is out-of-bounds of the base memref"));
}
};
@@ -334,8 +321,9 @@ struct ExpandShapeOpInterface
builder.create<arith::ConstantIndexOp>(loc, 0));
builder.create<cf::AssertOp>(
loc, isModZero,
- generateErrorMessage(op, "static result dims in reassoc group do not "
- "divide src dim evenly"));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "static result dims in reassoc group do not "
+ "divide src dim evenly"));
}
}
};
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index 9205d8d8c34a29..561e8d33868748 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -11,6 +11,27 @@
namespace mlir {
class Location;
class OpBuilder;
+
+/// Generate an error message string for the given op and the specified error.
+std::string
+RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
+ const std::string &msg) {
+ std::string buffer;
+ llvm::raw_string_ostream stream(buffer);
+ OpPrintingFlags flags;
+ // We may generate a lot of error messages and so we need to ensure the
+ // printing is fast.
+ flags.elideLargeElementsAttrs();
+ flags.printGenericOpForm();
+ flags.skipRegions();
+ flags.useLocalScope();
+ stream << "ERROR: Runtime op verification failed\n";
+ op->print(stream, flags);
+ stream << "\n^ " << msg;
+ stream << "\nLocation: ";
+ op->getLoc().print(stream);
+ return stream.str();
+}
} // namespace mlir
/// Include the definitions of the interface.
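When one of the generated cf.assert ops fails at runtime, this produces
output along the following lines (a sample only; the op is printed in
generic form with regions skipped, and the location is whatever the op
carries):

  ERROR: Runtime op verification failed
  "linalg.generic"(%arg0, %arg1, %0) <{...}> : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
  Location: loc("example.mlir":12:10)
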
diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
new file mode 100644
index 00000000000000..a4f29d8457e589
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
+
+// Most of the tests for linalg runtime-verification are implemented as integration tests.
+
+#identity = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @static_dims
+func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
+ // CHECK: %[[TRUE:.*]] = index.bool.constant true
+ // CHECK: cf.assert %[[TRUE]]
+ %result = tensor.empty() : tensor<5xf32>
+ %0 = linalg.generic {
+ indexing_maps = [#identity, #identity, #identity],
+ iterator_types = ["parallel"]
+ } ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>)
+ outs(%result : tensor<5xf32>) {
+ ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+ %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+ linalg.yield %tmp1 : f32
+ } -> tensor<5xf32>
+ return %0 : tensor<5xf32>
+}
+
+// -----
+
+#map = affine_map<() -> ()>
+
+// CHECK-LABEL: @scalars
+func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
+ // No runtime checks are required if the operands are all scalars
+ // CHECK-NOT: cf.assert
+ %result = tensor.empty() : tensor<f32>
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = []
+ } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
+ outs(%result : tensor<f32>) {
+ ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+ %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+ linalg.yield %tmp1 : f32
+ } -> tensor<f32>
+ return %0 : tensor<f32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
new file mode 100644
index 00000000000000..b05ef9422e5967
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -0,0 +1,298 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -convert-linalg-to-loops \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-scf-to-cf \
+// RUN: -test-cf-assert \
+// RUN: -convert-index-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils \
+// RUN: -shared-libs=%mlir_c_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+ %c5x = arith.constant dense<0.0> : tensor<5xf32>
+ %c4x = arith.constant dense<0.0> : tensor<4xf32>
+ %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32>
+ %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32>
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+ %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32>
+ %c1x4 = arith.constant dense<0.0> : tensor<1x4xf32>
+ %c4x4 = arith.constant dense<0.0> : tensor<4x4xf32>
+ %c4x5 = arith.constant dense<0.0> : tensor<4x5xf32>
+ %c5x4 = arith.constant dense<0.0> : tensor<5x4xf32>
+ %d1x1 = tensor.cast %c1x1 : tensor<1x1xf32> to tensor<?x?xf32>
+ %d1x4 = tensor.cast %c1x4 : tensor<1x4xf32> to tensor<?x?xf32>
+ %d4x4 = tensor.cast %c4x4 : tensor<4x4xf32> to tensor<?x?xf32>
+ %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32>
+ %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32>
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
+ func.call @broadcast_add(%d1x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size
+ func.call @broadcast_add(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ func.call @matmul_generic(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.matmul
+ // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ func.call @matmul_named(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ %c16x29 = arith.constant dense<0.0> : tensor<16x29xf32>
+ %c3x4 = arith.constant dense<0.0> : tensor<3x4xf32>
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @conv(%c16x29, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK: linalg.generic
+ // CHECK: unexpected negative result on dimension #0 of input/output operand #0
+ func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
+ return
+}
+
+
+#identity1D = affine_map<(d0) -> (d0)>
+
+func.func @simple_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %result = tensor.empty(%dim) : tensor<?xf32>
+ %0 = linalg.generic {
+ indexing_maps = [#identity1D, #identity1D, #identity1D],
+ iterator_types = ["parallel"]
+ } ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%result : tensor<?xf32>) {
+ ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+ %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+ linalg.yield %tmp1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+#broadcastD0 = affine_map<(d0, d1) -> (0, d1)>
+#broadcastD1 = affine_map<(d0, d1) -> (d0, 0)>
+#identity2D = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @broadcast_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // Calculate maximum dimension 0
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+ %0 = arith.maxui %dim, %dim_0 : index
+
+ // Calculate maximum dimension 1
+ %c1 = arith.constant 1 : index
+ %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %dim_2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %1 = arith.maxui %dim_1, %dim_2 : index
+
+ // Broadcast dimension 0 of %arg0
+ %dim_3 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %2 = arith.cmpi eq, %dim_3, %c1 : index
+ %3 = scf.if %2 -> (tensor<?x?xf32>) {
+ %dim_7 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
+ %13 = linalg.generic {
+ indexing_maps = [#broadcastD0, #identity2D],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?x?xf32>
+ scf.yield %13 : tensor<?x?xf32>
+ } else {
+ scf.yield %arg0 : tensor<?x?xf32>
+ }
+
+ // Broadcast dimension 1 of %arg0
+ %dim_4 = tensor.dim %3, %c1 : tensor<?x?xf32>
+ %4 = arith.cmpi eq, %dim_4, %c1 : index
+ %5 = scf.if %4 -> (tensor<?x?xf32>) {
+ %dim_7 = tensor.dim %3, %c0 : tensor<?x?xf32>
+ %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
+ %13 = linalg.generic {
+ indexing_maps = [#broadcastD1, #identity2D],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%3 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?x?xf32>
+ scf.yield %13 : tensor<?x?xf32>
+ } else {
+ scf.yield %3 : tensor<?x?xf32>
+ }
+
+ // Broadcast dimension 0 of %arg1
+ %dim_5 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+ %6 = arith.cmpi eq, %dim_5, %c1 : index
+ %7 = scf.if %6 -> (tensor<?x?xf32>) {
+ %dim_7 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
+ %13 = linalg.generic {
+ indexing_maps = [#broadcastD0, #identity2D],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg1 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?x?xf32>
+ scf.yield %13 : tensor<?x?xf32>
+ } else {
+ scf.yield %arg1 : tensor<?x?xf32>
+ }
+
+ // Broadcast dimension 1 of %arg1
+ %dim_6 = tensor.dim %7, %c1 : tensor<?x?xf32>
+ %8 = arith.cmpi eq, %dim_6, %c1 : index
+ %9 = scf.if %8 -> (tensor<?x?xf32>) {
+ %dim_7 = tensor.dim %7, %c0 : tensor<?x?xf32>
+ %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
+ %13 = linalg.generic {
+ indexing_maps = [#broadcastD1, #identity2D],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%7 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?x?xf32>
+ scf.yield %13 : tensor<?x?xf32>
+ } else {
+ scf.yield %7 : tensor<?x?xf32>
+ }
+
+ // Perform element-wise computation
+ %10 = tensor.empty(%0, %1) : tensor<?x?xf32>
+ %11 = linalg.generic {
+ indexing_maps = [#identity2D, #identity2D, #identity2D],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%5, %9 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%10 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_7: f32, %out: f32):
+ %12 = arith.addf %in, %in_7 : f32
+ linalg.yield %12 : f32
+ } -> tensor<?x?xf32>
+ return %11 : tensor<?x?xf32>
+}
+
+#matmul_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+ iterator_types = ["parallel", "parallel", "reduction"],
+ indexing_maps = #matmul_accesses
+}
+
+func.func @matmul_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %ci0 = arith.constant 0 : index
+ %ci1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %ci0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg1, %ci1 : tensor<?x?xf32>
+ %splat = tensor.splat %cf0[%d0, %d1] : tensor<?x?xf32>
+ %0 = linalg.generic #matmul_trait ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%splat : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+func.func @matmul_named(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %ci0 = arith.constant 0 : index
+ %ci1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %ci0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg1, %ci1 : tensor<?x?xf32>
+ %splat = tensor.splat %cf0[%d0, %d1] : tensor<?x?xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%splat : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+#conv_trait = {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d2, d1 * 4 + d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+}
+
+func.func @conv(%arg0: tensor<16x29xf32>, %arg1: tensor<3x4xf32>) -> (tensor<5x7xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %splat = tensor.splat %c0 : tensor<5x7xf32>
+ %result = linalg.generic #conv_trait ins(%arg0, %arg1 : tensor<16x29xf32>, tensor<3x4xf32>) outs(%splat : tensor<5x7xf32>) {
+ ^bb0(%in: f32, %in_64: f32, %out: f32):
+ %5 = arith.mulf %in, %in_64 : f32
+ %6 = arith.addf %out, %5 : f32
+ linalg.yield %6 : f32
+ } -> tensor<5x7xf32>
+ return %result : tensor<5x7xf32>
+}
+
+#reverse_trait = {
+ indexing_maps = [
+ affine_map<(i) -> (3 - i)>,
+ affine_map<(i) -> (i)>
+ ],
+ iterator_types = ["parallel"]
+}
+
+func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %cf0 = arith.constant 0.0 : f32
+ %ci0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %ci0 : tensor<?xf32>
+ %splat = tensor.splat %cf0[%d0] : tensor<?xf32>
+ %result = linalg.generic #reverse_trait ins(%arg0: tensor<?xf32>) outs(%splat: tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ linalg.yield %a : f32
+ } -> tensor<?xf32>
+ return %result : tensor<?xf32>
+}