From a00eacfce74d0d9ddb4b3910e6c408e2b7cef60d Mon Sep 17 00:00:00 2001 From: Tobias Roeser Date: Sun, 19 Nov 2023 12:25:26 +0100 Subject: [PATCH] Put hierarchy checks for test trait behind an overridable def (#2876) This puts all checks/assertions into a `protected def hierarchyChecks`, so users have an option to allow trait mix-ins which don't make sense to us now, by overriding it. I split the logic in a set of matching trait names and the assertion logic. I also added `MavenModule` and `SbtModule` to the mix. To get that working, I moved the logic up into `JavaModule`. Pull request: https://github.com/com-lihaoyi/mill/pull/2876 --- .../scalajslib/ScalaTestsErrorTests.scala | 11 +++++-- scalalib/src/mill/scalalib/JavaModule.scala | 33 ++++++++++++++++++- scalalib/src/mill/scalalib/ScalaModule.scala | 33 +------------------ .../scalanativelib/ScalaTestsErrorTests.scala | 9 ++++- 4 files changed, 49 insertions(+), 37 deletions(-) diff --git a/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala b/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala index 8a9cb541b1b..f5582116b61 100644 --- a/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala +++ b/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala @@ -1,6 +1,5 @@ package mill.scalajslib -import mill._ import mill.define.Discover import mill.scalalib.TestModule import mill.util.TestUtil @@ -12,8 +11,10 @@ object ScalaTestsErrorTests extends TestSuite { def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???) def scalaJSVersion = sys.props.getOrElse("TEST_SCALAJS_VERSION", ???) object test extends ScalaTests with TestModule.Utest + object testDisabledError extends ScalaTests with TestModule.Utest { + override def hierarchyChecks(): Unit = {} + } } - override lazy val millDiscover = Discover[this.type] } @@ -24,8 +25,12 @@ object ScalaTestsErrorTests extends TestSuite { } val message = error.getCause.getMessage assert( - message == s"scalaTestsError is a `ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`." + message == s"scalaTestsError is a `mill.scalajslib.ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`." ) } + test("extends-ScalaTests-disabled-hierarchy-check") { + // expect no throws exception + ScalaTestsError.scalaTestsError.testDisabledError + } } } diff --git a/scalalib/src/mill/scalalib/JavaModule.scala b/scalalib/src/mill/scalalib/JavaModule.scala index 96b073359dc..d6184a6035f 100644 --- a/scalalib/src/mill/scalalib/JavaModule.scala +++ b/scalalib/src/mill/scalalib/JavaModule.scala @@ -9,7 +9,7 @@ import coursier.parse.ModuleParser import coursier.util.ModuleMatcher import mainargs.Flag import mill.Agg -import mill.api.{Ctx, JarManifest, PathRef, Result, internal} +import mill.api.{Ctx, JarManifest, MillException, PathRef, Result, internal} import mill.define.{Command, ModuleRef, Segment, Task, TaskModule} import mill.scalalib.internal.ModuleUtils import mill.scalalib.api.CompilationResult @@ -34,6 +34,9 @@ trait JavaModule def zincWorker: ModuleRef[ZincWorkerModule] = ModuleRef(mill.scalalib.ZincWorkerModule) trait JavaModuleTests extends JavaModule with TestModule { + // Run some consistence checks + hierarchyChecks() + override def moduleDeps: Seq[JavaModule] = Seq(outer) override def repositoriesTask: Task[Seq[Repository]] = outer.repositoriesTask override def resolutionCustomizer: Task[Option[coursier.Resolution => coursier.Resolution]] = @@ -47,6 +50,34 @@ trait JavaModule PathRef(this.millSourcePath / src.path.relativeTo(outer.millSourcePath)) } } + + /** + * JavaModule and its derivates define inner test modules. + * To avoid unexpected misbehavior due to the use of the wrong inner test trait + * we apply some hierarchy consistency checks. + * If for some reasons, those are too restrictive to you, you can override this method. + * @throws MillException + */ + protected def hierarchyChecks(): Unit = { + val outerInnerSets = Seq( + ("mill.scalajslib.ScalaJSModule", "ScalaJSTests"), + ("mill.scalanativelib.ScalaNativeModule", "ScalaNativeTests"), + ("mill.scalalib.SbtModule", "SbtModuleTests"), + ("mill.scalalib.MavenModule", "MavenModuleTests") + ) + for { + (mod, testModShort) <- outerInnerSets + testMod = s"${mod}$$${testModShort}" + } + try { + if (Class.forName(mod).isInstance(outer) && !Class.forName(testMod).isInstance(this)) + throw new MillException( + s"$outer is a `${mod}`. $this needs to extend `${testModShort}`." + ) + } catch { + case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule + } + } } def defaultCommandName(): String = "run" diff --git a/scalalib/src/mill/scalalib/ScalaModule.scala b/scalalib/src/mill/scalalib/ScalaModule.scala index 5a391444cb8..2071117b7ce 100644 --- a/scalalib/src/mill/scalalib/ScalaModule.scala +++ b/scalalib/src/mill/scalalib/ScalaModule.scala @@ -1,15 +1,7 @@ package mill package scalalib -import mill.api.{ - DummyInputStream, - JarManifest, - MillException, - PathRef, - Result, - SystemStreams, - internal -} +import mill.api.{DummyInputStream, JarManifest, PathRef, Result, SystemStreams, internal} import mill.main.BuildInfo import mill.util.{Jvm, Util} import mill.util.Jvm.createJar @@ -28,29 +20,6 @@ trait ScalaModule extends JavaModule with TestModule.ScalaModuleBase { outer => type ScalaModuleTests = ScalaTests trait ScalaTests extends JavaModuleTests with ScalaModule { - try { - if ( - Class.forName("mill.scalajslib.ScalaJSModule").isInstance(outer) && !Class.forName( - "mill.scalajslib.ScalaJSModule$ScalaJSTests" - ).isInstance(this) - ) throw new MillException( - s"$outer is a `ScalaJSModule`. $this needs to extend `ScalaJSTests`." - ) - } catch { - case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule - } - try { - if ( - Class.forName("mill.scalanativelib.ScalaNativeModule").isInstance(outer) && !Class.forName( - "mill.scalanativelib.ScalaNativeModule$ScalaNativeTests" - ).isInstance(this) - ) throw new MillException( - s"$outer is a `ScalaNativeModule`. $this needs to extend `ScalaNativeTests`." - ) - } catch { - case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaNativeModule - } - override def scalaOrganization: Target[String] = outer.scalaOrganization() override def scalaVersion: Target[String] = outer.scalaVersion() override def scalacPluginIvyDeps: Target[Agg[Dep]] = outer.scalacPluginIvyDeps() diff --git a/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala b/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala index af843d2102c..fcc13540fcd 100644 --- a/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala +++ b/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala @@ -12,6 +12,9 @@ object ScalaTestsErrorTests extends TestSuite { def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???) def scalaNativeVersion = sys.props.getOrElse("TEST_SCALANATIVE_VERSION", ???) object test extends ScalaTests with TestModule.Utest + object testDisabledError extends ScalaTests with TestModule.Utest { + override def hierarchyChecks(): Unit = {} + } } override lazy val millDiscover = Discover[this.type] @@ -24,8 +27,12 @@ object ScalaTestsErrorTests extends TestSuite { } val message = error.getCause.getMessage assert( - message == s"scalaTestsError is a `ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`." + message == s"scalaTestsError is a `mill.scalanativelib.ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`." ) } + test("extends-ScalaTests-disabled-hierarchy-check") { + // expect no throws exception + ScalaTestsError.scalaTestsError.testDisabledError + } } }