[Mlir-commits] [mlir] f426be1 - Revert "[mlir][linalg] Add runtime verification for linalg ops" (#89780)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 23 08:56:03 PDT 2024


Author: Ryan Holt
Date: 2024-04-23T11:55:59-04:00
New Revision: f426be195a08874686d01783bbc490295bf4afb2

URL: https://github.com/llvm/llvm-project/commit/f426be195a08874686d01783bbc490295bf4afb2
DIFF: https://github.com/llvm/llvm-project/commit/f426be195a08874686d01783bbc490295bf4afb2.diff

LOG: Revert "[mlir][linalg] Add runtime verification for linalg ops" (#89780)

Reverts llvm/llvm-project#89342 due to build failure

Added: 
    

Modified: 
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
    mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp

Removed: 
    mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
    mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
    mlir/test/Dialect/Linalg/runtime-verification.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
deleted file mode 100644
index 6c3643f7835cbe8..000000000000000
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
+++ /dev/null
@@ -1,21 +0,0 @@
-//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
-#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
-
-namespace mlir {
-class DialectRegistry;
-
-namespace linalg {
-void registerRuntimeVerifiableOpInterfaceExternalModels(
-    DialectRegistry &registry);
-} // namespace linalg
-} // namespace mlir
-
-#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index d9db21073e15c7a..c4d788cf8ed3166 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,7 +45,6 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
-#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
@@ -162,7 +161,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
   linalg::registerAllDialectInterfaceImplementations(registry);
-  linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);

diff  --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index 6fd0df59d9d2e08..d5f11d00cc3d2ab 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -35,12 +35,6 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
                     "::mlir::Location":$loc)
     >,
   ];
-
-  let extraClassDeclaration = [{
-    /// Generate the error message that will be printed to the user when 
-    /// verification fails.
-    static std::string generateErrorMessage(Operation *op, const std::string &msg);
-  }];
 }
 
 #endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 44d95bbc02d4ed8..ee6e391d0cc6826 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -27,7 +27,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
-  RuntimeOpVerification.cpp
   Specialize.cpp
   Split.cpp
   SplitReduction.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
deleted file mode 100644
index b30182dc84079fe..000000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ /dev/null
@@ -1,135 +0,0 @@
-//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Index/IR/IndexAttrs.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
-
-namespace mlir {
-namespace linalg {
-namespace {
-/// Verify that the runtime sizes of the operands to linalg structured ops are
-/// compatible with the runtime sizes inferred by composing the loop ranges with
-/// the linalg op's indexing maps. This is similar to the verifier except that
-/// here we insert IR to perform the verification at runtime.
-template <typename T>
-struct StructuredOpInterface
-    : public RuntimeVerifiableOpInterface::ExternalModel<
-          StructuredOpInterface<T>, T> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
-    auto linalgOp = llvm::cast<LinalgOp>(op);
-
-    SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
-    auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
-
-    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
-
-    // Subtract one from the loop ends before composing with the indexing map
-    transform(ends, ends.begin(), [&](OpFoldResult end) {
-      auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
-      return builder.createOrFold<index::SubOp>(loc, endValue, one);
-    });
-
-    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
-      auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
-          builder, loc, indexingMap, starts);
-      auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
-          builder, loc, indexingMap, ends);
-
-      for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
-        auto startIndex =
-            getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
-        auto endIndex =
-            getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
-
-        // Generate:
-        //   minIndex = min(startIndex, endIndex)
-        //   assert(minIndex >= 0)
-        // To ensure we do not generate a negative index. We take the minimum of
-        // the start and end indices in order to handle reverse loops such as
-        // `affine_map<(i) -> (3 - i)>`
-        auto min =
-            builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
-        auto cmpOp = builder.createOrFold<index::CmpOp>(
-            loc, index::IndexCmpPredicate::SGE, min, zero);
-        auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
-            linalgOp, "unexpected negative result on dimension #" +
-                          std::to_string(dim) + " of input/output operand #" +
-                          std::to_string(opOperand.getOperandNumber()));
-        builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
-
-        // Generate:
-        //   inferredDimSize = max(startIndex, endIndex) + 1
-        //   actualDimSize = dim(operand)
-        //   assert(inferredDimSize <= actualDimSize)
-        // To ensure that we do not index past the bounds of the operands.
-        auto max =
-            builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
-
-        auto inferredDimSize =
-            builder.createOrFold<index::AddOp>(loc, max, one);
-
-        auto actualDimSize =
-            createOrFoldDimOp(builder, loc, opOperand.get(), dim);
-
-        // Similar to the verifier, when the affine expression in the indexing
-        // map is complicated, we just check that the inferred dimension sizes
-        // are in the boundary of the operands' size. Being more precise than
-        // that is difficult.
-        auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
-                             ? index::IndexCmpPredicate::EQ
-                             : index::IndexCmpPredicate::SLE;
-
-        cmpOp = builder.createOrFold<index::CmpOp>(
-            loc, predicate, inferredDimSize, actualDimSize);
-        msg = RuntimeVerifiableOpInterface::generateErrorMessage(
-            linalgOp, "dimension #" + std::to_string(dim) +
-                          " of input/output operand #" +
-                          std::to_string(opOperand.getOperandNumber()) +
-                          " is incompatible with inferred dimension size");
-        builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
-      }
-    }
-  }
-};
-
-template <typename... OpTs>
-void attachInterface(MLIRContext *ctx) {
-  (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
-}
-} // namespace
-} // namespace linalg
-} // namespace mlir
-
-void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
-    attachInterface<
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-        >(ctx);
-
-    // Load additional dialects of which ops may get created.
-    ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
-                     cf::ControlFlowDialect, index::IndexDialect,
-                     tensor::TensorDialect>();
-  });
-}

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7ff..05b813a3b1e9084 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -20,6 +20,25 @@
 
 using namespace mlir;
 
+/// Generate an error message string for the given op and the specified error.
+static std::string generateErrorMessage(Operation *op, const std::string &msg) {
+  std::string buffer;
+  llvm::raw_string_ostream stream(buffer);
+  OpPrintingFlags flags;
+  // We may generate a lot of error messages and so we need to ensure the
+  // printing is fast.
+  flags.elideLargeElementsAttrs();
+  flags.printGenericOpForm();
+  flags.skipRegions();
+  flags.useLocalScope();
+  stream << "ERROR: Runtime op verification failed\n";
+  op->print(stream, flags);
+  stream << "\n^ " << msg;
+  stream << "\nLocation: ";
+  op->getLoc().print(stream);
+  return stream.str();
+}
+
 namespace mlir {
 namespace memref {
 namespace {
@@ -43,10 +62,8 @@ struct CastOpInterface
           builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
       Value isSameRank = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcRank, resultRank);
-      builder.create<cf::AssertOp>(
-          loc, isSameRank,
-          RuntimeVerifiableOpInterface::generateErrorMessage(op,
-                                                             "rank mismatch"));
+      builder.create<cf::AssertOp>(loc, isSameRank,
+                                   generateErrorMessage(op, "rank mismatch"));
     }
 
     // Get source offset and strides. We do not have an op to get offsets and
@@ -84,8 +101,8 @@ struct CastOpInterface
           loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
       builder.create<cf::AssertOp>(
           loc, isSameSz,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "size mismatch of dim " + std::to_string(it.index())));
+          generateErrorMessage(op, "size mismatch of dim " +
+                                       std::to_string(it.index())));
     }
 
     // Get result offset and strides.
@@ -102,10 +119,8 @@ struct CastOpInterface
           builder.create<arith::ConstantIndexOp>(loc, resultOffset);
       Value isSameOffset = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
-      builder.create<cf::AssertOp>(
-          loc, isSameOffset,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "offset mismatch"));
+      builder.create<cf::AssertOp>(loc, isSameOffset,
+                                   generateErrorMessage(op, "offset mismatch"));
     }
 
     // Check strides.
@@ -122,8 +137,8 @@ struct CastOpInterface
           loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
       builder.create<cf::AssertOp>(
           loc, isSameStride,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "stride mismatch of dim " + std::to_string(it.index())));
+          generateErrorMessage(op, "stride mismatch of dim " +
+                                       std::to_string(it.index())));
     }
   }
 };
@@ -163,9 +178,7 @@ struct LoadStoreOpInterface
                 : andOp;
     }
     builder.create<cf::AssertOp>(
-        loc, assertCond,
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op, "out-of-bounds access"));
+        loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
   }
 };
 
@@ -235,7 +248,7 @@ struct ReinterpretCastOpInterface
 
     builder.create<cf::AssertOp>(
         loc, assertCond,
-        RuntimeVerifiableOpInterface::generateErrorMessage(
+        generateErrorMessage(
             op,
             "result of reinterpret_cast is out-of-bounds of the base memref"));
   }
@@ -280,8 +293,8 @@ struct SubViewOpInterface
 
     builder.create<cf::AssertOp>(
         loc, assertCond,
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op, "subview is out-of-bounds of the base memref"));
+        generateErrorMessage(op,
+                             "subview is out-of-bounds of the base memref"));
   }
 };
 
@@ -321,9 +334,8 @@ struct ExpandShapeOpInterface
           builder.create<arith::ConstantIndexOp>(loc, 0));
       builder.create<cf::AssertOp>(
           loc, isModZero,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "static result dims in reassoc group do not "
-                  "divide src dim evenly"));
+          generateErrorMessage(op, "static result dims in reassoc group do not "
+                                   "divide src dim evenly"));
     }
   }
 };

diff  --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index e823b5df179c50d..9205d8d8c34a291 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -11,28 +11,6 @@
 namespace mlir {
 class Location;
 class OpBuilder;
-
-/// Generate an error message string for the given op and the specified error.
-std::string
-RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
-                                                   const std::string &msg) {
-  std::string buffer;
-  llvm::raw_string_ostream stream(buffer);
-  OpPrintingFlags flags;
-  // We may generate a lot of error messages and so we need to ensure the
-  // printing is fast.
-  flags.elideLargeElementsAttrs();
-  flags.printGenericOpForm();
-  flags.skipRegions();
-  flags.useLocalScope();
-  stream << "ERROR: Runtime op verification failed\n";
-  op->print(stream, flags);
-  stream << "\n^ " << msg;
-  stream << "\nLocation: ";
-  op->getLoc().print(stream);
-  return stream.str();
-}
-
 } // namespace mlir
 
 /// Include the definitions of the interface.

diff  --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
deleted file mode 100644
index a4f29d8457e5895..000000000000000
--- a/mlir/test/Dialect/Linalg/runtime-verification.mlir
+++ /dev/null
@@ -1,43 +0,0 @@
-// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
-
-// Most of the tests for linalg runtime-verification are implemented as integration tests.
-
-#identity = affine_map<(d0) -> (d0)>
-
-// CHECK-LABEL: @static_dims
-func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
-    // CHECK: %[[TRUE:.*]] = index.bool.constant true
-    // CHECK: cf.assert %[[TRUE]]
-    %result = tensor.empty() : tensor<5xf32> 
-    %0 = linalg.generic {
-      indexing_maps = [#identity, #identity, #identity],
-      iterator_types = ["parallel"]
-    } ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>)
-      outs(%result : tensor<5xf32>) {
-      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
-        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
-        linalg.yield %tmp1 : f32
-    } -> tensor<5xf32>
-    return %0 : tensor<5xf32>
-}
-
-// -----
-
-#map = affine_map<() -> ()>
-
-// CHECK-LABEL: @scalars
-func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
-    // No runtime checks are required if the operands are all scalars
-    // CHECK-NOT: cf.assert
-    %result = tensor.empty() : tensor<f32> 
-    %0 = linalg.generic {
-      indexing_maps = [#map, #map, #map],
-      iterator_types = []
-    } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
-      outs(%result : tensor<f32>) {
-      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
-        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
-        linalg.yield %tmp1 : f32
-    } -> tensor<f32>
-    return %0 : tensor<f32>
-}

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
deleted file mode 100644
index b05ef9422e59674..000000000000000
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ /dev/null
@@ -1,298 +0,0 @@
-// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN: -convert-linalg-to-loops \
-// RUN: -expand-strided-metadata \
-// RUN: -lower-affine \
-// RUN: -convert-scf-to-cf \
-// RUN: -test-cf-assert \
-// RUN: -convert-index-to-llvm \
-// RUN: -finalize-memref-to-llvm \
-// RUN: -convert-func-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
-// RUN: mlir-cpu-runner -e main -entry-point-result=void \
-// RUN:     -shared-libs=%mlir_runner_utils \
-// RUN:     -shared-libs=%mlir_c_runner_utils 2>&1 | \
-// RUN: FileCheck %s
-
-func.func @main() {
-  %c5x = arith.constant dense<0.0> : tensor<5xf32>
-  %c4x = arith.constant dense<0.0> : tensor<4xf32>
-  %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32>
-  %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32>
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
-  func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
-  func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
-
-  %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32>
-  %c1x4 = arith.constant dense<0.0> : tensor<1x4xf32>
-  %c4x4 = arith.constant dense<0.0> : tensor<4x4xf32>
-  %c4x5 = arith.constant dense<0.0> : tensor<4x5xf32>
-  %c5x4 = arith.constant dense<0.0> : tensor<5x4xf32>
-  %d1x1 = tensor.cast %c1x1 : tensor<1x1xf32> to tensor<?x?xf32>
-  %d1x4 = tensor.cast %c1x4 : tensor<1x4xf32> to tensor<?x?xf32>
-  %d4x4 = tensor.cast %c4x4 : tensor<4x4xf32> to tensor<?x?xf32>
-  %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32>
-  %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32>
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
-  func.call @broadcast_add(%d1x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size
-  func.call @broadcast_add(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
-  func.call @matmul_generic(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.matmul
-  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
-  func.call @matmul_named(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
-  %c64x57 = arith.constant dense<0.0> : tensor<16x29xf32>
-  %c3x4 = arith.constant dense<0.0> : tensor<3x4xf32>
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>)
-
-  // CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>)
-
-  // CHECK: ERROR: Runtime op verification failed
-  // CHECK: linalg.generic
-  // CHECK: unexpected negative result on dimension #0 of input/output operand #0
-  func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
-
-  return
-}
-
-
-#identity1D = affine_map<(d0) -> (d0)>
-
-func.func @simple_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>) {
-    %c0 = arith.constant 0 : index
-    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
-    %result = tensor.empty(%dim) : tensor<?xf32> 
-    %0 = linalg.generic {
-      indexing_maps = [#identity1D, #identity1D, #identity1D],
-      iterator_types = ["parallel"]
-    } ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
-      outs(%result : tensor<?xf32>) {
-      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
-        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
-        linalg.yield %tmp1 : f32
-    } -> tensor<?xf32>
-    return %0 : tensor<?xf32>
-}
-
-#broadcastD0 = affine_map<(d0, d1) -> (0, d1)>
-#broadcastD1 = affine_map<(d0, d1) -> (d0, 0)>
-#identity2D = affine_map<(d0, d1) -> (d0, d1)>
-
-func.func @broadcast_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // Calculate maximum dimension 0
-  %c0 = arith.constant 0 : index
-  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
-  %0 = arith.maxui %dim, %dim_0 : index
-
-  // Calculate maximum dimension 1
-  %c1 = arith.constant 1 : index
-  %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %dim_2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
-  %1 = arith.maxui %dim_1, %dim_2 : index
-
-  // Broadcast dimension 0 of %arg0
-  %dim_3 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %2 = arith.cmpi eq, %dim_3, %c1 : index
-  %3 = scf.if %2 -> (tensor<?x?xf32>) {
-    %dim_7 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-    %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
-    %13 = linalg.generic {
-      indexing_maps = [#broadcastD0, #identity2D],
-      iterator_types = ["parallel", "parallel"]
-    } ins(%arg0 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-    } -> tensor<?x?xf32>
-    scf.yield %13 : tensor<?x?xf32>
-  } else {
-    scf.yield %arg0 : tensor<?x?xf32>
-  }
-
-  // Broadcast dimension 1 of %arg0
-  %dim_4 = tensor.dim %3, %c1 : tensor<?x?xf32>
-  %4 = arith.cmpi eq, %dim_4, %c1 : index
-  %5 = scf.if %4 -> (tensor<?x?xf32>) {
-    %dim_7 = tensor.dim %3, %c0 : tensor<?x?xf32>
-    %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
-    %13 = linalg.generic {
-      indexing_maps = [#broadcastD1, #identity2D],
-      iterator_types = ["parallel", "parallel"]
-    } ins(%3 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-    } -> tensor<?x?xf32>
-    scf.yield %13 : tensor<?x?xf32>
-  } else {
-    scf.yield %3 : tensor<?x?xf32>
-  }
-
-  // Broadcast dimension 0 of %arg1
-  %dim_5 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
-  %6 = arith.cmpi eq, %dim_5, %c1 : index
-  %7 = scf.if %6 -> (tensor<?x?xf32>) {
-    %dim_7 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
-    %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32>
-    %13 = linalg.generic {
-      indexing_maps = [#broadcastD0, #identity2D],
-      iterator_types = ["parallel", "parallel"]
-    } ins(%arg1 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-    } -> tensor<?x?xf32>
-    scf.yield %13 : tensor<?x?xf32>
-  } else {
-    scf.yield %arg1 : tensor<?x?xf32>
-  }
-
-  // Broadcast dimension 1 of %arg1
-  %dim_6 = tensor.dim %7, %c1 : tensor<?x?xf32>
-  %8 = arith.cmpi eq, %dim_6, %c1 : index
-  %9 = scf.if %8 -> (tensor<?x?xf32>) {
-    %dim_7 = tensor.dim %7, %c0 : tensor<?x?xf32>
-    %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32>
-    %13 = linalg.generic {
-      indexing_maps = [#broadcastD1, #identity2D],
-      iterator_types = ["parallel", "parallel"]
-    } ins(%7 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-    } -> tensor<?x?xf32>
-    scf.yield %13 : tensor<?x?xf32>
-  } else {
-    scf.yield %7 : tensor<?x?xf32>
-  }
-
-  // Perform element-wise computation
-  %10 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %11 = linalg.generic {
-    indexing_maps = [#identity2D, #identity2D, #identity2D],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%5, %9 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%10 : tensor<?x?xf32>) {
-  ^bb0(%in: f32, %in_7: f32, %out: f32):
-    %12 = arith.addf %in, %in_7 : f32
-    linalg.yield %12 : f32
-  } -> tensor<?x?xf32>
-  return %11 : tensor<?x?xf32>
-}
-
-#matmul_accesses = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
-  iterator_types = ["parallel", "parallel", "reduction"],
-  indexing_maps = #matmul_accesses
-}
-
-func.func @matmul_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %cf0 = arith.constant 0.0 : f32 
-  %ci0 = arith.constant 0 : index 
-  %ci1 = arith.constant 1 : index 
-  %d0 = tensor.dim %arg0, %ci0 : tensor<?x?xf32>
-  %d1 = tensor.dim %arg1, %ci1 : tensor<?x?xf32>
-  %splat = tensor.splat %cf0[%d0, %d1] : tensor<?x?xf32>
-  %0 = linalg.generic #matmul_trait ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%splat : tensor<?x?xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %1 = arith.mulf %in, %in_0 : f32
-    %2 = arith.addf %out, %1 : f32
-    linalg.yield %2 : f32
-  } -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
-}
-
-func.func @matmul_named(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %cf0 = arith.constant 0.0 : f32 
-  %ci0 = arith.constant 0 : index 
-  %ci1 = arith.constant 1 : index 
-  %d0 = tensor.dim %arg0, %ci0 : tensor<?x?xf32>
-  %d1 = tensor.dim %arg1, %ci1 : tensor<?x?xf32>
-  %splat = tensor.splat %cf0[%d0, %d1] : tensor<?x?xf32>
-  %0 = linalg.matmul  ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%splat : tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
-}
-
-#conv_trait = {
-  indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d2, d1 * 4 + d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
-  iterator_types = ["parallel", "parallel", "reduction", "reduction"]
-}
-
-func.func @conv(%arg0: tensor<16x29xf32>, %arg1: tensor<3x4xf32>) -> (tensor<5x7xf32>) {
-  %c0 = arith.constant 0.0 : f32 
-  %splat = tensor.splat %c0 : tensor<5x7xf32>
-  %result = linalg.generic #conv_trait ins(%arg0, %arg1 : tensor<16x29xf32>, tensor<3x4xf32>) outs(%splat : tensor<5x7xf32>) {
-  ^bb0(%in: f32, %in_64: f32, %out: f32):
-    %5 = arith.mulf %in, %in_64 : f32
-    %6 = arith.addf %out, %5 : f32
-    linalg.yield %6 : f32
-  } -> tensor<5x7xf32>
-  return %result : tensor<5x7xf32>
-}
-
-#reverse_trait = {
-  indexing_maps = [
-          affine_map<(i) -> (3 - i)>,
-          affine_map<(i) -> (i)>
-  ],
-  iterator_types = ["parallel"]
-}
-
-func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
-  %cf0 = arith.constant 0.0 : f32 
-  %ci0 = arith.constant 0 : index 
-  %d0 = tensor.dim %arg0, %ci0 : tensor<?xf32>
-  %splat = tensor.splat %cf0[%d0] : tensor<?xf32>
-  %result = linalg.generic #reverse_trait ins(%arg0: tensor<?xf32>) outs(%splat: tensor<?xf32>) {
-    ^bb0(%a: f32, %b: f32):
-    linalg.yield %a : f32
-  } -> tensor<?xf32>
-  return %result : tensor<?xf32>
-}


        


More information about the Mlir-commits mailing list