diff --git a/include/circt/Dialect/OM/OMPasses.td b/include/circt/Dialect/OM/OMPasses.td index c454aa0f7e36..abc4af04ba1a 100644 --- a/include/circt/Dialect/OM/OMPasses.td +++ b/include/circt/Dialect/OM/OMPasses.td @@ -31,15 +31,21 @@ def ElaborateObject: Pass<"om-elaborate-object", "mlir::ModuleOp"> { Performs evaluation of a specified OM class by inlining all object instantiations and folding field accesses. - The pass requires the `target-class` option to specify which class to - elaborate. + The pass requires either the `target-class` option to specify which class to + elaborate, or the `all-public-classes` option to elaborate all public + classes. By default, selected classes must elaborate fully. The + `allow-unevaluated` option leaves external objects and other unevaluated + operations in place. }]; let options = [ Option<"targetClass", "target-class", "std::string", /*default=*/"", "The class to elaborate">, - Option<"test", "test", "bool", /*default=*/"false", - "Internal testing mode: elaborate all zero-argument classes", - "llvm::cl::Hidden"> + Option<"allPublicClasses", "all-public-classes", "bool", + /*default=*/"false", + "Elaborate all public classes">, + Option<"allowUnevaluated", "allow-unevaluated", "bool", + /*default=*/"false", + "Allow external objects and unevaluated operations to remain"> ]; } diff --git a/lib/Dialect/OM/Transforms/ElaborateObject.cpp b/lib/Dialect/OM/Transforms/ElaborateObject.cpp index 69e6af7809a7..e8b6dda43bf7 100644 --- a/lib/Dialect/OM/Transforms/ElaborateObject.cpp +++ b/lib/Dialect/OM/Transforms/ElaborateObject.cpp @@ -46,8 +46,10 @@ using FieldIndex = DenseMap, unsigned>; /// Pattern to inline ObjectOp instances by cloning the class body and /// replacing them with ElaboratedObjectOp. struct ObjectOpInliningPattern : public OpRewritePattern { - ObjectOpInliningPattern(MLIRContext *context, SymbolTable &symTable) - : OpRewritePattern(context), symTable(symTable) {} + ObjectOpInliningPattern(MLIRContext *context, SymbolTable &symTable, + bool replaceExternalWithUnknown) + : OpRewritePattern(context), symTable(symTable), + replaceExternalWithUnknown(replaceExternalWithUnknown) {} LogicalResult matchAndRewrite(ObjectOp objOp, PatternRewriter &rewriter) const override { @@ -56,6 +58,8 @@ struct ObjectOpInliningPattern : public OpRewritePattern { // External classes cannot be elaborated; replace with unknown values. if (isa(classLike)) { + if (!replaceExternalWithUnknown) + return failure(); rewriter.replaceOpWithNewOp(objOp, objOp.getType()); return success(); } @@ -88,6 +92,7 @@ struct ObjectOpInliningPattern : public OpRewritePattern { } const SymbolTable &symTable; + bool replaceExternalWithUnknown; }; /// Pattern to fold ObjectFieldOp on ElaboratedObjectOp by directly accessing @@ -176,8 +181,8 @@ bool isFullyEvaluated(Operation *op) { ListCreateOp, ListConcatOp>(op); } -LogicalResult verifyResult(ClassOp module) { - auto isLegal = [](Operation *op) -> LogicalResult { +LogicalResult verifyResult(ClassOp module, bool allowUnevaluated) { + auto isLegal = [allowUnevaluated](Operation *op) -> LogicalResult { // Check assert satisfied. if (auto assertOp = dyn_cast(op)) { // Check if the condition is a constant false, which means the assertion @@ -205,11 +210,16 @@ LogicalResult verifyResult(ClassOp module) { return checkAssert(true); // This means the condition was not fully evaluated. + if (allowUnevaluated) + return success(); return emitError(op->getLoc(), "failed to evaluate assertion condition"); } - if (!isFullyEvaluated(op)) + if (!isFullyEvaluated(op)) { + if (allowUnevaluated) + return success(); return emitError(op->getLoc()) << "failed to evaluate " << op->getName(); + } return success(); }; @@ -224,13 +234,15 @@ struct ElaborateObjectPass using Base::Base; static LogicalResult elaborateClass(ClassOp classOp, SymbolTable &symTable, - FieldIndex &fieldIndexes) { + FieldIndex &fieldIndexes, + bool allowUnevaluated = false) { // Elaborate objects by inlining all ObjectOps and folding field accesses // using a greedy pattern rewriter. NOTE: The conversion framework is not // suitable here because inlining patterns need to be applied recursively to // fully evaluate nested object instantiations. RewritePatternSet patterns(classOp.getContext()); - patterns.add(classOp.getContext(), symTable); + patterns.add(classOp.getContext(), symTable, + !allowUnevaluated); patterns.add(classOp.getContext(), symTable, fieldIndexes); patterns.add(classOp.getContext()); @@ -241,13 +253,16 @@ struct ElaborateObjectPass return failure(); // Check if elaboration succeeded after saturation. - return verifyResult(classOp); + return verifyResult(classOp, allowUnevaluated); } LogicalResult initialize(MLIRContext *context) override { - if (test.getValue() ^ targetClass.getValue().empty()) + unsigned numModes = + allPublicClasses.getValue() + !targetClass.getValue().empty(); + if (numModes != 1) return emitError(UnknownLoc::get(context)) - << "either 'test' or 'target-class' must be specified"; + << "exactly one of 'target-class' or 'all-public-classes' must " + "be specified"; return success(); } @@ -265,12 +280,15 @@ struct ElaborateObjectPass fieldIndexes[{name, fieldName}] = idx; } - // Test mode: elaborate all zero-argument classes. - if (test) { - for (auto classOp : module.getOps()) - if (classOp.getBodyBlock()->getNumArguments() == 0) - if (failed(elaborateClass(classOp, symTable, fieldIndexes))) - return signalPassFailure(); + // Elaborate all public classes. + if (allPublicClasses) { + for (auto classOp : module.getOps()) { + if (!classOp.isPublic()) + continue; + if (failed(elaborateClass(classOp, symTable, fieldIndexes, + allowUnevaluated))) + return signalPassFailure(); + } return; } @@ -282,7 +300,8 @@ struct ElaborateObjectPass return signalPassFailure(); } - if (failed(elaborateClass(classOp, symTable, fieldIndexes))) + if (failed( + elaborateClass(classOp, symTable, fieldIndexes, allowUnevaluated))) return signalPassFailure(); } }; diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index 0e754718520c..8b4d814cc5a3 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -439,6 +439,12 @@ LogicalResult firtool::populateFinalizeIR(mlir::PassManager &pm, const FirtoolOptions &opt) { pm.addPass(firrtl::createFinalizeIR()); pm.addPass(om::createFreezePathsPass()); + om::ElaborateObjectOptions options; + options.allPublicClasses = true; + options.allowUnevaluated = true; + pm.addPass(om::createElaborateObject(options)); + // TODO: Add SymbolDCE to elimiate unused private classes once after we + // stopped using private classes. return success(); } diff --git a/test/Dialect/OM/elaborate-object-errors.mlir b/test/Dialect/OM/elaborate-object-errors.mlir index 3aac2bba6b69..402241d54a40 100644 --- a/test/Dialect/OM/elaborate-object-errors.mlir +++ b/test/Dialect/OM/elaborate-object-errors.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt -om-elaborate-object='test=true' %s -verify-diagnostics -split-input-file +// RUN: circt-opt -om-elaborate-object='all-public-classes=true' %s -verify-diagnostics -split-input-file om.class @AssertFalse() { %false = om.constant false @@ -22,7 +22,7 @@ om.class @MultipleAsserts() { // ----- // Multiple assertions in nested classes -om.class @WrapperWithAssert(%in: i1) -> (out: i1) { +om.class private @WrapperWithAssert(%in: i1) -> (out: i1) { // expected-error @below {{OM property assertion failed: wrapper assertion fails}} om.property_assert %in, "wrapper assertion fails" : i1 om.class.fields %in : i1 @@ -49,14 +49,14 @@ om.class @ComplexExpressionFalse() { om.class.fields } -om.class @BoolWrapper(%in: i1) -> (out: i1) { +om.class private @BoolWrapper(%in: i1) -> (out: i1) { om.class.fields %in : i1 } // ----- // Cycle in dataflow (field access creates a cycle that can't be evaluated) -om.class @WrapperCycle(%val: !om.integer) -> (out: !om.integer) { +om.class private @WrapperCycle(%val: !om.integer) -> (out: !om.integer) { om.class.fields %val : !om.integer } diff --git a/test/Dialect/OM/elaborate-object-option-errors.mlir b/test/Dialect/OM/elaborate-object-option-errors.mlir index 58a7b690c4a0..a8dc5929808e 100644 --- a/test/Dialect/OM/elaborate-object-option-errors.mlir +++ b/test/Dialect/OM/elaborate-object-option-errors.mlir @@ -1,5 +1,5 @@ // RUN: circt-opt -om-elaborate-object %s -verify-diagnostics -// expected-error @unknown {{either 'test' or 'target-class' must be specified}} +// expected-error @unknown {{exactly one of 'target-class' or 'all-public-classes' must be specified}} module { om.class @SomeClass() { om.class.fields diff --git a/test/Dialect/OM/elaborate-object.mlir b/test/Dialect/OM/elaborate-object.mlir index 8ac4faaf5e72..843a0790a10f 100644 --- a/test/Dialect/OM/elaborate-object.mlir +++ b/test/Dialect/OM/elaborate-object.mlir @@ -1,5 +1,6 @@ // RUN: circt-opt -om-elaborate-object='target-class=Top' %s | FileCheck %s --check-prefix=TOP -// RUN: circt-opt -om-elaborate-object='test=true' %s | FileCheck %s --check-prefixes=TOP,CHECK +// RUN: circt-opt -om-elaborate-object='target-class=UseExtern' %s | FileCheck %s --check-prefix=STRICT-EXTERN +// RUN: circt-opt -om-elaborate-object='all-public-classes=true allow-unevaluated=true' %s | FileCheck %s --check-prefixes=CHECK,PUBLIC // CHECK-LABEL: om.class @StringOps() -> (str: !om.string, concat: !om.string) { // CHECK-DAG: %[[HELLO:.+]] = om.constant "hello" : !om.string @@ -133,10 +134,16 @@ om.class @AssertUnknown() { // Test external class instantiation (should be replaced with unknown) om.class.extern @ExternalModule(%param: !om.integer) -> (output: !om.integer) {} -// CHECK-LABEL: om.class @UseExtern() -> (result: !om.integer) { -// CHECK: %[[UNKNOWN:.+]] = om.unknown : !om.integer -// CHECK: om.class.fields %[[UNKNOWN]] -// CHECK: } +// STRICT-EXTERN-LABEL: om.class @UseExtern() -> (result: !om.integer) { +// STRICT-EXTERN: %[[UNKNOWN:.+]] = om.unknown : !om.integer +// STRICT-EXTERN: om.class.fields %[[UNKNOWN]] +// STRICT-EXTERN: } +// PUBLIC-LABEL: om.class @UseExtern() -> (result: !om.integer) { +// PUBLIC: %[[INPUT:.+]] = om.constant #om.integer<42 : si64> : !om.integer +// PUBLIC: %[[EXT:.+]] = om.object @ExternalModule(%[[INPUT]]) : (!om.integer) -> !om.class.type<@ExternalModule> +// PUBLIC: %[[RESULT:.+]] = om.object.field %[[EXT]]["output"] : (!om.class.type<@ExternalModule>) -> !om.integer +// PUBLIC: om.class.fields %[[RESULT]] +// PUBLIC: } om.class @UseExtern() -> (result: !om.integer) { %input = om.constant #om.integer<42 : si64> : !om.integer %ext = om.object @ExternalModule(%input) : (!om.integer) -> !om.class.type<@ExternalModule> @@ -144,3 +151,27 @@ om.class @UseExtern() -> (result: !om.integer) { om.class.fields %result : !om.integer } +// Test all-public-classes mode elaborates public classes through private +// helpers, but does not elaborate private top-level classes. +om.class private @PublicModeHelper() -> (value: !om.integer) { + %value = om.constant #om.integer<7 : si4> : !om.integer + om.class.fields %value : !om.integer +} + +// PUBLIC-LABEL: om.class @PublicModeTop() -> (value: !om.integer) { +// PUBLIC-NEXT: %[[value:.+]] = om.constant +// PUBLIC-NEXT: om.class.fields %[[value]] +om.class @PublicModeTop() -> (value: !om.integer) { + %helper = om.object @PublicModeHelper() : () -> !om.class.type<@PublicModeHelper> + %value = om.object.field %helper["value"] : (!om.class.type<@PublicModeHelper>) -> !om.integer + om.class.fields %value : !om.integer +} + +// PUBLIC-LABEL: om.class private @PrivateModeTop() -> (value: !om.integer) { +// PUBLIC: om.object @PublicModeHelper() +// PUBLIC: om.object.field +om.class private @PrivateModeTop() -> (value: !om.integer) { + %helper = om.object @PublicModeHelper() : () -> !om.class.type<@PublicModeHelper> + %value = om.object.field %helper["value"] : (!om.class.type<@PublicModeHelper>) -> !om.integer + om.class.fields %value : !om.integer +} diff --git a/test/firtool/domains.fir b/test/firtool/domains.fir index 1a3dde737961..4ddac3154bfe 100644 --- a/test/firtool/domains.fir +++ b/test/firtool/domains.fir @@ -20,12 +20,12 @@ circuit Foo : input b : UInt<1> domains [A, B] ; CHECK-LABEL: om.class @Foo_Class -; DOMAIN-NEXT: om.object @ClockDomain_out(%basepath, %A, %[[#association:]]) +; DOMAIN-NEXT: om.elaborated_object @ClockDomain_out(%A, %[[#association:]]) ; DOMAIN-NEXT: %[[#a:]] = om.frozenpath_create reference %basepath "Foo>a" ; DOMAIN-NEXT: %[[#b:]] = om.frozenpath_create reference %basepath "Foo>b" ; DOMAIN-NEXT: %[[#association]] = om.list_create %[[#a]], %[[#b]] ; -; DOMAIN: om.object @PowerDomain_out(%basepath, %B, %[[#association:]]) +; DOMAIN: om.elaborated_object @PowerDomain_out(%B, %[[#association:]]) ; DOMAIN-NEXT: %[[#a:]] = om.frozenpath_create reference %basepath "Foo>a" ; DOMAIN-NEXT: %[[#b:]] = om.frozenpath_create reference %basepath "Foo>b" ; DOMAIN-NEXT: %[[#association]] = om.list_create %[[#a]], %[[#b]] diff --git a/test/firtool/om-elaboration-errors.fir b/test/firtool/om-elaboration-errors.fir new file mode 100644 index 000000000000..465172e63f5c --- /dev/null +++ b/test/firtool/om-elaboration-errors.fir @@ -0,0 +1,12 @@ +; RUN: firtool %s -verify-diagnostics --output-final-mlir - -o /dev/null --strip-fir-debug-info=false + +FIRRTL version 6.0.0 +circuit OMElaborationErrors: + class Child: + output cond : Bool + propassign cond, Bool(false) + + public module OMElaborationErrors: + object child_obj of Child + ; expected-error @below {{OM property assertion failed: must hold}} + propassert child_obj.cond, "must hold" diff --git a/test/om-linker/Inputs/elaborate-def.mlir b/test/om-linker/Inputs/elaborate-def.mlir new file mode 100644 index 000000000000..0f97482c2824 --- /dev/null +++ b/test/om-linker/Inputs/elaborate-def.mlir @@ -0,0 +1,6 @@ +module { + om.class @Child() -> (cond: i1) { + %false = om.constant false + om.class.fields %false : i1 + } +} diff --git a/test/om-linker/Inputs/elaborate-use.mlir b/test/om-linker/Inputs/elaborate-use.mlir new file mode 100644 index 000000000000..43dbcdb90674 --- /dev/null +++ b/test/om-linker/Inputs/elaborate-use.mlir @@ -0,0 +1,10 @@ +module { + om.class.extern @Child() -> (cond: i1) {} + + om.class @Top() -> (cond: i1) { + %child = om.object @Child() : () -> !om.class.type<@Child> + %cond = om.object.field %child["cond"] : (!om.class.type<@Child>) -> i1 + om.property_assert %cond, "linked child condition must hold" : i1 + om.class.fields %cond : i1 + } +} diff --git a/test/om-linker/elaborate.mlir b/test/om-linker/elaborate.mlir new file mode 100644 index 000000000000..ce89e90413a0 --- /dev/null +++ b/test/om-linker/elaborate.mlir @@ -0,0 +1,3 @@ +// RUN: not om-linker %S/Inputs/elaborate-def.mlir %S/Inputs/elaborate-use.mlir 2>&1 | FileCheck %s + +// CHECK: error: OM property assertion failed: linked child condition must hold diff --git a/tools/om-linker/om-linker.cpp b/tools/om-linker/om-linker.cpp index 50a875d6d68f..f379cdd61548 100644 --- a/tools/om-linker/om-linker.cpp +++ b/tools/om-linker/om-linker.cpp @@ -16,6 +16,7 @@ #include "circt/Dialect/HW/HWDialect.h" #include "circt/Dialect/LTL/LTLDialect.h" #include "circt/Dialect/OM/OMDialect.h" +#include "circt/Dialect/OM/OMOps.h" #include "circt/Dialect/OM/OMPasses.h" #include "circt/Dialect/SV/SVDialect.h" #include "circt/Dialect/Verif/VerifDialect.h" @@ -55,6 +56,11 @@ static cl::opt cl::desc("Emit bytecode when generating MLIR output"), cl::init(false), cl::cat(mainCategory)); +static cl::opt disableElaboration( + "disable-elaboration", + cl::desc("Disable elaboration of all public OM classes after linking"), + cl::init(false), cl::cat(mainCategory)); + /// Check output stream before writing bytecode to it. /// Warn and return true if output is known to be displayed. static bool checkBytecodeOutputToConsole(raw_ostream &os) { @@ -158,6 +164,12 @@ static LogicalResult executeOMLinker(MLIRContext &context) { // Construct a linker pipeline. pm.addPass(om::createLinkModules()); + if (!disableElaboration) { + om::ElaborateObjectOptions options; + options.allPublicClasses = true; + options.allowUnevaluated = true; + pm.addPass(om::createElaborateObject(options)); + } if (failed(pm.run(module.get()))) return failure();