Skip to content

Commit

Permalink
Put hierarchy checks for test trait behind an overridable def (#2876)
Browse files Browse the repository at this point in the history
This puts all checks/assertions into a `protected def hierarchyChecks`,
so users have an option to allow trait mix-ins which don't make sense to
us now, by overriding it.

I split the logic in a set of matching trait names and the assertion
logic. I also added `MavenModule` and `SbtModule` to the mix. To get
that working, I moved the logic up into `JavaModule`.

Pull request: #2876
  • Loading branch information
lefou authored Nov 19, 2023
1 parent 9e0ec9a commit a00eacf
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 37 deletions.
11 changes: 8 additions & 3 deletions scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package mill.scalajslib

import mill._
import mill.define.Discover
import mill.scalalib.TestModule
import mill.util.TestUtil
Expand All @@ -12,8 +11,10 @@ object ScalaTestsErrorTests extends TestSuite {
def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???)
def scalaJSVersion = sys.props.getOrElse("TEST_SCALAJS_VERSION", ???)
object test extends ScalaTests with TestModule.Utest
object testDisabledError extends ScalaTests with TestModule.Utest {
override def hierarchyChecks(): Unit = {}
}
}

override lazy val millDiscover = Discover[this.type]
}

Expand All @@ -24,8 +25,12 @@ object ScalaTestsErrorTests extends TestSuite {
}
val message = error.getCause.getMessage
assert(
message == s"scalaTestsError is a `ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`."
message == s"scalaTestsError is a `mill.scalajslib.ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`."
)
}
test("extends-ScalaTests-disabled-hierarchy-check") {
// expect no throws exception
ScalaTestsError.scalaTestsError.testDisabledError
}
}
}
33 changes: 32 additions & 1 deletion scalalib/src/mill/scalalib/JavaModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import coursier.parse.ModuleParser
import coursier.util.ModuleMatcher
import mainargs.Flag
import mill.Agg
import mill.api.{Ctx, JarManifest, PathRef, Result, internal}
import mill.api.{Ctx, JarManifest, MillException, PathRef, Result, internal}
import mill.define.{Command, ModuleRef, Segment, Task, TaskModule}
import mill.scalalib.internal.ModuleUtils
import mill.scalalib.api.CompilationResult
Expand All @@ -34,6 +34,9 @@ trait JavaModule
def zincWorker: ModuleRef[ZincWorkerModule] = ModuleRef(mill.scalalib.ZincWorkerModule)

trait JavaModuleTests extends JavaModule with TestModule {
// Run some consistence checks
hierarchyChecks()

override def moduleDeps: Seq[JavaModule] = Seq(outer)
override def repositoriesTask: Task[Seq[Repository]] = outer.repositoriesTask
override def resolutionCustomizer: Task[Option[coursier.Resolution => coursier.Resolution]] =
Expand All @@ -47,6 +50,34 @@ trait JavaModule
PathRef(this.millSourcePath / src.path.relativeTo(outer.millSourcePath))
}
}

/**
* JavaModule and its derivates define inner test modules.
* To avoid unexpected misbehavior due to the use of the wrong inner test trait
* we apply some hierarchy consistency checks.
* If for some reasons, those are too restrictive to you, you can override this method.
* @throws MillException
*/
protected def hierarchyChecks(): Unit = {
val outerInnerSets = Seq(
("mill.scalajslib.ScalaJSModule", "ScalaJSTests"),
("mill.scalanativelib.ScalaNativeModule", "ScalaNativeTests"),
("mill.scalalib.SbtModule", "SbtModuleTests"),
("mill.scalalib.MavenModule", "MavenModuleTests")
)
for {
(mod, testModShort) <- outerInnerSets
testMod = s"${mod}$$${testModShort}"
}
try {
if (Class.forName(mod).isInstance(outer) && !Class.forName(testMod).isInstance(this))
throw new MillException(
s"$outer is a `${mod}`. $this needs to extend `${testModShort}`."
)
} catch {
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule
}
}
}

def defaultCommandName(): String = "run"
Expand Down
33 changes: 1 addition & 32 deletions scalalib/src/mill/scalalib/ScalaModule.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
package mill
package scalalib

import mill.api.{
DummyInputStream,
JarManifest,
MillException,
PathRef,
Result,
SystemStreams,
internal
}
import mill.api.{DummyInputStream, JarManifest, PathRef, Result, SystemStreams, internal}
import mill.main.BuildInfo
import mill.util.{Jvm, Util}
import mill.util.Jvm.createJar
Expand All @@ -28,29 +20,6 @@ trait ScalaModule extends JavaModule with TestModule.ScalaModuleBase { outer =>
type ScalaModuleTests = ScalaTests

trait ScalaTests extends JavaModuleTests with ScalaModule {
try {
if (
Class.forName("mill.scalajslib.ScalaJSModule").isInstance(outer) && !Class.forName(
"mill.scalajslib.ScalaJSModule$ScalaJSTests"
).isInstance(this)
) throw new MillException(
s"$outer is a `ScalaJSModule`. $this needs to extend `ScalaJSTests`."
)
} catch {
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule
}
try {
if (
Class.forName("mill.scalanativelib.ScalaNativeModule").isInstance(outer) && !Class.forName(
"mill.scalanativelib.ScalaNativeModule$ScalaNativeTests"
).isInstance(this)
) throw new MillException(
s"$outer is a `ScalaNativeModule`. $this needs to extend `ScalaNativeTests`."
)
} catch {
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaNativeModule
}

override def scalaOrganization: Target[String] = outer.scalaOrganization()
override def scalaVersion: Target[String] = outer.scalaVersion()
override def scalacPluginIvyDeps: Target[Agg[Dep]] = outer.scalacPluginIvyDeps()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ object ScalaTestsErrorTests extends TestSuite {
def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???)
def scalaNativeVersion = sys.props.getOrElse("TEST_SCALANATIVE_VERSION", ???)
object test extends ScalaTests with TestModule.Utest
object testDisabledError extends ScalaTests with TestModule.Utest {
override def hierarchyChecks(): Unit = {}
}
}

override lazy val millDiscover = Discover[this.type]
Expand All @@ -24,8 +27,12 @@ object ScalaTestsErrorTests extends TestSuite {
}
val message = error.getCause.getMessage
assert(
message == s"scalaTestsError is a `ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`."
message == s"scalaTestsError is a `mill.scalanativelib.ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`."
)
}
test("extends-ScalaTests-disabled-hierarchy-check") {
// expect no throws exception
ScalaTestsError.scalaTestsError.testDisabledError
}
}
}

0 comments on commit a00eacf

Please sign in to comment.