.xml`.
+
+To add new languages, create a separate module for this language and register its implementation as an extension of
+the `psiHelperProvider` EP. Then follow the template provided above.
+
+### 2. Prompt Generation
+
+When we know how to parse the code, we need to construct the prompt.
+
+For each language, adjust the prompt that goes to the LLM. Ensure that the language, framework platform, and mocking
+framework are defined correctly in:
+
+```kotlin
+data class PromptConfiguration(
+ val desiredLanguage: String,
+ val desiredTestingPlatform: String,
+ val desiredMockingFramework: String,
+)
+```
+
+Additionally, check that all the dependencies (collected by `PsiHelper` for the current strategy) are passed
+properly. `PromptGenerator` and `PromptBuilder` are responsible for this job.
+
+### 3. Parsing LLM Response
+
+When the LLM response to our prompt is received, we have to parse it.
+
+We want to retrieve test case, all the test functions and additional information like imports or supporting functions
+from the response.
+
+The current structure of this part is located in:
+
+- `kotlin/org/jetbrains/research/testspark/core/test`
+- `kotlin/org/jetbrains/research/testspark/tools`
+
+It can be more easily understood with the following diagram:
+![](https://private-user-images.githubusercontent.com/70476032/349256986-dc7e1ff9-a9a5-4bd2-a51f-ecbfabeb6cba.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjIzNTEyOTAsIm5iZiI6MTcyMjM1MDk5MCwicGF0aCI6Ii83MDQ3NjAzMi8zNDkyNTY5ODYtZGM3ZTFmZjktYTlhNS00YmQyLWE1MWYtZWNiZmFiZWI2Y2JhLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA3MzAlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNzMwVDE0NDk1MFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWJjMDg3MWM2ZDA4MDJlZGUwNzliMzNkNzA3YWI4YTcwM2RmYTFjMmE1MGM4MjM5NjJiOGI2ZjgxNTE2OTU2YjQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.8OfRa1wJhDfFq3QT6h5yIjBh1VqB9UrrQfZGp0_SLDo)
+
+- `TestsAssembler`: Assembler class for generating and organizing test cases from the LLM response.
+- `TestSuiteParser`: Extracts test cases from raw text and generates a test suite.
+- `TestBodyPrinter`: Generates the body of a test function as a string.
+
+### 4. Compilation
+
+Before showing the code to the user, it should be checked for compilation.
+
+- `TestCompiler`: Compiles a list of test cases and returns the compilation result.
+
+Here one should specify the appropriate compilation strategy for each language. With all the dependencies and build paths.
+
+### 5. UI Representation
+
+Once the code generated by the LLM is checked for the compilation, it should be presented in the UI.
+
+- `TestCaseDisplayService`: Service responsible for the representation of all the UI components.
+- `TestSuiteView`: Interface specific for working with buttons.
+- `TestClassCodeAnalyzer`: Interface for retrieving information from test class code.
+- `TestClassCodeGenerator`: Interface for generating and formatting test class code.
+
+### 6. Running and saving tests
+
+We should be able to run all the tests in the UI and then save them to the desired folder.
+
+- `TestPersistentStorage`: Interface representing a contract for saving generated tests to a specified file system location.
+
+For Kotlin and Java, the `TestProcessor` implementation also allows saving the JaCoCo report to see the code coverage of
+the test that will be saved.
+
+---
+
## Plugin Configuration File
The plugin configuration file is `plugin.xml` which can be found in `src/main/resources/META-INF` directory. All declarations (such as actions, services, listeners) are present in this file.
@@ -86,30 +229,7 @@ All the listener classes can be found in `listeners` directory.
### Services
All the service classes can be found in `services` directory.
-
-- `CoverageSelectionToggleListener` is a topic interface for showing or hiding highlighting when the coverage is toggled for one test or many tests.
-- `CoverageToolWindowDisplayService` creates the *"Coverage visualisation"* tool window panel and the coverage table to display the test coverage data of the tests generated by EvoSuite.
-- `CoverageVisualisationService` visualises the coverage in the gutter and the editor (by colouring), injects the coverage data into the *"Coverage visualisation"* tool window tab.
-- `ErrorService`service class for handling error occurrences.
-- `QuickAccessParametersService` allows to load and get the state of the parameters in the *"Parameters"* tool window panel.
-- `RunnerService` is used to limit TestSpark to generate tests only once at a time.
-- `SettingsApplicationService` stores the application-level settings persistently. It uses `SettingsApplicationState` class for that.
-- `SettingsProjectService` stores the project-level settings persistently. It uses `SettingsProjectState` class for that.
-- `StaticInvalidationService` invalidates the cache statically.
-- `TestCaseCachingService` contains the data structure for caching the generated test cases and is responsible for adding, retrieving and removing (invalidating) the generated tests.
-- `TestCaseDisplayService` displays the tests generated by EvoSuite, in the *"Generated tests"* tool window panel.
-- `TestSparkTelemetryService` sends usage information to Intelligent Collaboration Tools Lab at JetBrains Research if the user has opted in.
-
-### Settings
-
-All the classes related to TestSpark `Settings/Preferences` page can be found in `settings` directory.
-
-- `SettingsApplicationState` is responsible for storing the values of the EvoSuite Settings entries.
-- `SettingsEvoSuiteComponent` displays and captures the changes to the values of the entries in the EvoSuite page of the Settings dialog.
-- `SettingsEvoSuiteConfigurable` allows to configure some EvoSuite settings via the EvoSuite page in the Settings dialog, observes the changes and manages the UI and state.
-- `SettingsPluginComponent` displays and captures the changes to the values of the entries in the TestSpark main page of the Settings dialog.
-- `SettingsPluginConfigurable` allows to configure some Plugin settings via the Plugin page in the Settings dialog, observes the changes and manages the UI and state.
-- `SettingsProjectState` is responsible for storing the values of the Plugin Settings entries.
+We currently have three services for managing EvoSuite settings (`EvoSuiteSettingsService`), the LLM-based approach/generation (`LLMSettingsService`), and general plugin settings (`PluginSettingsService`).
### Tools
diff --git a/README.md b/README.md
index 3cc9024e0..fe2f7295b 100644
--- a/README.md
+++ b/README.md
@@ -21,14 +21,14 @@ TestSpark is a plugin for generating unit tests. TestSpark natively integrates d
TestSpark currently supports two test generation strategies:
- - LLM-based test generation (using OpenAI and JetBrains internal AI Assistant platform)
+ - LLM-based test generation (using OpenAI, HuggingFace, and JetBrains internal AI Assistant platform)
- Local search-based test generation (using EvoSuite)
LLM-based test generation
For this type of test generation, TestSpark sends request to different Large Language Models. Also, it automatically checks if tests are valid before presenting it to users.
- This feature needs a token from OpenAI platform or the AI Assistant platform.
+ This feature needs a token from OpenAI, HuggingFace, or the AI Assistant platform.
- - Supports Java (any version).
+ - Supports Java (any version) and Kotlin (K2 mode should be disabled, checkout the Settings section on README).
- Generates unit tests for capturing failures.
- Generate tests for Java classes, methods, and single lines.
@@ -70,6 +70,7 @@ If you are running the plugin for the first time, checkout the [Settings](#setti
- [Coverage](#coverage)
- [Integrating tests into the project](#integrating-tests-into-the-project)
- [Settings](#settings)
+- [Disable K2 for Kotlin Test Generation](#disable-K2)
- [Telemetry](#telemetry-opt-in)
### Generating Tests
@@ -229,7 +230,9 @@ Or to a new file:
![Tests adding to a new file](readme-images/gifs/AddingToANewFile.gif#gh-light-mode-only)
![Tests adding to a new file_dark](readme-images/gifs/AddingToANewFile_dark.gif#gh-dark-mode-only)
-
+### Disable K2
+For LLM-based Kotlin test generation, you need to disable the K2 mode for now.
+![Disable K2 mode](readme-images/pngs/k2-mode/disable-k2.png)
### Settings
The plugin is configured mainly through the Settings menu. The plugin settings can be found under Settings > Tools > TestSpark. Here, the user is able to select options for the plugin:
diff --git a/build.gradle.kts b/build.gradle.kts
index ad61ed0b3..c09e328e3 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -1,4 +1,8 @@
import org.jetbrains.changelog.markdownToHTML
+import org.jetbrains.intellij.platform.gradle.IntelliJPlatformType
+import org.jetbrains.intellij.platform.gradle.TestFrameworkType
+import org.jetbrains.intellij.platform.gradle.models.ProductRelease
+import org.jetbrains.intellij.platform.gradle.tasks.RunIdeTask
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import java.io.FileOutputStream
import java.net.URL
@@ -9,20 +13,24 @@ import java.util.zip.ZipInputStream
fun properties(key: String) = project.findProperty(key).toString()
-val thunderdomeVersion = "1.0.5"
-
+// Space credentials
val spaceUsername =
System.getProperty("space.username")?.toString() ?: project.properties["spaceUsername"]?.toString() ?: ""
val spacePassword =
System.getProperty("space.pass")?.toString() ?: project.properties["spacePassword"]?.toString() ?: ""
+// the test generation module for interacting with Grazie (used when the space credentials are provided)
+val grazieTestGenerationVersion = "1.0.5"
+
plugins {
// Java support
id("java")
// Kotlin support
id("org.jetbrains.kotlin.jvm") version "1.9.0"
// Gradle IntelliJ Plugin
- id("org.jetbrains.intellij") version "1.15.0"
+ id("org.jetbrains.intellij.platform") version "2.1.0"
+ // Gradle IntelliJ Plugin Migration Help (uncomment it for migration tips)
+// id("org.jetbrains.intellij.platform.migration") version "2.1.0"
// Gradle Changelog Plugin
id("org.jetbrains.changelog") version "2.1.2"
// Gradle Qodana Plugin
@@ -34,6 +42,11 @@ version = properties("pluginVersion")
// Configure project's dependencies
repositories {
mavenCentral()
+ // this part is mandatory for all modules for platform version 2:
+ // See https://plugins.jetbrains.com/docs/intellij/tools-intellij-platform-gradle-plugin-repositories-extension.html#default-repositories
+ intellijPlatform {
+ defaultRepositories()
+ }
maven("https://packages.jetbrains.team/maven/p/ij/intellij-dependencies")
maven("https://www.jetbrains.com/intellij-repository/snapshots")
@@ -63,9 +76,16 @@ if (spaceCredentialsProvided()) {
usingSourceSet(hasGrazieAccess)
}
+ // Add the dependencies for the new source set
+ dependencies {
+ add(hasGrazieAccess.implementationConfigurationName, kotlin("stdlib"))
+ add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
+ add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion")
+ }
+
tasks.register("checkCredentials") {
configurations.detachedConfiguration(
- dependencies.create("org.jetbrains.research:grazie-test-generation:1.0.1"),
+ dependencies.create("org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion"),
).files()
}
@@ -73,8 +93,12 @@ if (spaceCredentialsProvided()) {
dependsOn("checkCredentials")
}
+ tasks.named(hasGrazieAccess.jarTaskName) {
+ exclude("**/plugin.xml")
+ }
+
// add build of new source set as the part of UI testing
- tasks.prepareUiTestingSandbox.configure {
+ tasks.prepareTestSandbox.configure {
dependsOn(hasGrazieAccess.jarTaskName)
from(tasks.getByName(hasGrazieAccess.jarTaskName).outputs.files.asPath) { into("TestSpark/lib") }
@@ -96,16 +120,34 @@ if (spaceCredentialsProvided()) {
}
dependencies {
- implementation(files("lib/evosuite-$thunderdomeVersion.jar"))
+ // Check platform V2 documentation for more details: https://plugins.jetbrains.com/docs/intellij/tools-intellij-platform-gradle-plugin-dependencies-extension.html
+ intellijPlatform {
+ // make a custom version of IDEA
+ create(properties("platformType"), properties("platformVersion"))
+ // Plugin Dependencies. Uses `platformPlugins` property from the gradle.properties file.
+ bundledPlugins(providers.gradleProperty("platformPlugins").map { it.split(',') })
+
+ pluginVerifier()
+ zipSigner()
+ instrumentationTools()
+
+ testFramework(TestFrameworkType.Bundled)
+ }
+
+ implementation(files("lib/evosuite-${properties("evosuiteVersion")}.jar"))
implementation(files("lib/standalone-runtime.jar"))
implementation(files("lib/jacocoagent.jar"))
implementation(files("lib/jacococli.jar"))
+ implementation(files("lib/opentest4j-1.1.1.jar"))
implementation(files("lib/mockito-core-5.0.0.jar"))
implementation(files("lib/byte-buddy-1.14.6.jar"))
implementation(files("lib/byte-buddy-agent-1.14.6.jar"))
implementation(files("lib/JUnitRunner.jar"))
implementation(project(":core"))
+ implementation(project(":langwrappers")) // Needed to use Psi related interfaces and load proper implementation
+ implementation(project(":kotlin")) // Needed to load the testspark-kotlin.xml
+ implementation(project(":java")) // Needed to load the testspark-java.xml
if (spaceCredentialsProvided()) {
"hasGrazieAccessCompileOnly"(project(":core"))
}
@@ -123,7 +165,7 @@ dependencies {
implementation("org.junit.jupiter:junit-jupiter-engine:5.10.0")
// https://mvnrepository.com/artifact/org.jacoco/org.jacoco.core
- implementation("org.jacoco:org.jacoco.core:0.8.8")
+ implementation("org.jacoco:org.jacoco.core:0.8.12")
// https://mvnrepository.com/artifact/com.github.javaparser/javaparser-core
implementation("com.github.javaparser:javaparser-symbol-solver-core:3.24.2")
@@ -151,6 +193,7 @@ dependencies {
// https://mvnrepository.com/artifact/org.mockito/mockito-all
testImplementation("org.mockito:mockito-all:1.10.19")
+ testImplementation("org.mockito.kotlin:mockito-kotlin:5.1.0")
// https://mvnrepository.com/artifact/net.jqwik/jqwik
testImplementation("net.jqwik:jqwik:1.6.5")
@@ -161,23 +204,41 @@ dependencies {
implementation("org.jetbrains.kotlin:kotlin-test:1.8.0")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3")
-
- if (spaceCredentialsProvided()) {
- // Dependencies for hasGrazieAccess variant
- "hasGrazieAccessImplementation"(kotlin("stdlib"))
- "hasGrazieAccessImplementation"("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
- "hasGrazieAccessImplementation"("org.jetbrains.research:grazie-test-generation:1.0.4")
- }
}
-// Configure Gradle IntelliJ Plugin - read more: https://github.com/JetBrains/gradle-intellij-plugin
-intellij {
- pluginName.set(properties("pluginName"))
- version.set(properties("platformVersion"))
- type.set(properties("platformType"))
+// Configure Gradle IntelliJ Plugin - read more: // Configure Gradle IntelliJ Plugin - read more: https://github.com/JetBrains/gradle-intellij-plugin
+intellijPlatform {
+ pluginConfiguration {
+ name = properties("pluginName")
+ version = properties("pluginVersion")
- // Plugin Dependencies. Uses `platformPlugins` property from the gradle.properties file.
- plugins.set(properties("platformPlugins").split(',').map(String::trim).filter(String::isNotEmpty))
+ ideaVersion {
+ sinceBuild = properties("pluginSinceBuild")
+ untilBuild = properties("pluginUntilBuild")
+ }
+ }
+
+ publishing {
+ token = System.getenv("PUBLISH_TOKEN")
+ channels = listOf(properties("pluginVersion").split('-').getOrElse(1) { "default" }.split('.').first())
+ }
+ // Set the ides on which the plugin verification is executed.
+ // See https://plugins.jetbrains.com/docs/intellij/tools-intellij-platform-gradle-plugin-extension.html#intellijPlatform-pluginVerification-ides
+ pluginVerification {
+ ides {
+ recommended()
+ select {
+ types = listOf(IntelliJPlatformType.IntellijIdeaUltimate)
+ channels = listOf(ProductRelease.Channel.RELEASE)
+ sinceBuild = properties("pluginSinceBuild")
+ untilBuild = properties("pluginUntilBuild")
+ }
+ }
+ freeArgs = listOf(
+ "-mute",
+ "TemplateWordInPluginId,ForbiddenPluginIdPrefix"
+ )
+ }
}
// Configure Gradle Changelog Plugin - read more: https://github.com/JetBrains/gradle-changelog-plugin
@@ -201,6 +262,7 @@ tasks {
dependsOn("copyJUnitRunnerLib")
dependsOn(":core:compileKotlin")
}
+
// Set the JVM compatibility versions
properties("javaVersion").let {
withType {
@@ -223,11 +285,13 @@ tasks {
}
}
- patchPluginXml {
- version.set(properties("pluginVersion"))
- sinceBuild.set(properties("pluginSinceBuild"))
- untilBuild.set(properties("pluginUntilBuild"))
+ signPlugin {
+ certificateChain.set(providers.environmentVariable("CERTIFICATE_CHAIN"))
+ privateKey.set(providers.environmentVariable("PRIVATE_KEY"))
+ password.set(providers.environmentVariable("PRIVATE_KEY_PASSWORD"))
+ }
+ patchPluginXml {
// Extract the section from README.md and provide for the plugin's manifest
pluginDescription.set(
projectDir.resolve("README.md").readText().lines().run {
@@ -251,33 +315,8 @@ tasks {
)
}
- // Configure UI tests plugin
- // Read more: https://github.com/JetBrains/intellij-ui-test-robot
- runIdeForUiTests {
- systemProperty("robot-server.port", "8082")
- systemProperty("ide.mac.message.dialogs.as.sheets", "false")
- systemProperty("jb.privacy.policy.text", "")
- systemProperty("jb.consents.confirmation.enabled", "false")
- systemProperty("idea.trust.all.projects", "true")
- systemProperty("ide.show.tips.on.startup.default.value", "false")
- systemProperty("jb.consents.confirmation.enabled", "false")
- systemProperty("ide.mac.file.chooser.native", "false")
- systemProperty("apple.laf.useScreenMenuBar", "false")
- }
-
- signPlugin {
- certificateChain.set(System.getenv("CERTIFICATE_CHAIN").trimIndent())
- privateKey.set(System.getenv("PRIVATE_KEY").trimIndent())
- password.set(System.getenv("PRIVATE_KEY_PASSWORD"))
- }
-
publishPlugin {
dependsOn("patchChangelog")
- token.set(System.getenv("PUBLISH_TOKEN"))
- // pluginVersion is based on the SemVer (https://semver.org) and supports pre-release labels, like 2.1.7-alpha.3
- // Specify pre-release label to publish the plugin in a custom Release Channel automatically. Read more:
- // https://plugins.jetbrains.com/docs/intellij/deployment.html#specifying-a-release-channel
- channels.set(listOf(properties("pluginVersion").split('-').getOrElse(1) { "default" }.split('.').first()))
}
}
@@ -313,7 +352,7 @@ abstract class CopyJUnitRunnerLib : DefaultTask() {
*/
abstract class UpdateEvoSuite : DefaultTask() {
@Input
- var version: String = ""
+ var evoSuiteVersion: String = ""
@TaskAction
fun execute() {
@@ -322,7 +361,7 @@ abstract class UpdateEvoSuite : DefaultTask() {
libDir.mkdirs()
}
- val jarName = "evosuite-$version.jar"
+ val jarName = "evosuite-$evoSuiteVersion.jar"
if (libDir.listFiles()?.any { it.name.matches(Regex(jarName)) } == true) {
logger.info("Specified evosuite jar found, skipping update")
@@ -331,7 +370,7 @@ abstract class UpdateEvoSuite : DefaultTask() {
logger.info("Specified evosuite jar not found, downloading release $jarName")
val downloadUrl =
- "https://github.com/ciselab/evosuite/releases/download/thunderdome/release/$version/release.zip"
+ "https://github.com/ciselab/evosuite/releases/download/thunderdome/release/$evoSuiteVersion/release.zip"
val stream =
try {
URL(downloadUrl).openStream()
@@ -359,12 +398,15 @@ abstract class UpdateEvoSuite : DefaultTask() {
}
tasks.register("updateEvosuite") {
- version = thunderdomeVersion
+ evoSuiteVersion = properties("evosuiteVersion")
}
-
+/**
+ * Copies the JUnitRunner.jar file to the lib directory of the project.
+ * This task depends on the "JUnitRunner" module being built beforehand.
+ * JUnitRunner.jar is required for running tests with coverage in the main plugin
+ */
tasks.register("copyJUnitRunnerLib") {
dependsOn(":JUnitRunner:jar")
-
val libName = "JUnitRunner.jar"
val libSrcDir =
"${project.projectDir}${File.separator}JUnitRunner${File.separator}build${File.separator}libs${File.separator}"
@@ -375,4 +417,48 @@ tasks.register("copyJUnitRunnerLib") {
into(libDestDir)
}
+/**
+ * Returns the original string if it is not null, or the default string if the original string is null.
+ *
+ * @param default the default string to return if the original string is null
+ * @return the original string if it is not null, or the default string if the original string is null
+ */
+fun String?.orDefault(default: String): String = this ?: default
+
+/**
+ * This code sets up a Gradle task for running the plugin in headless mode
+ *
+ * @param root The root directory of the project under test.
+ * @param file The file containing unit under test.
+ * @param cut The class under test.
+ * @param cp The classpath of the project.
+ * @param llm The model used for the test generation task.
+ * @param token The token for using LLM.
+ * @param prompt a txt file containing the LLM's prompt template
+ * @param out The output directory for the project.
+ * @param enableCoverage flag to enable/disable coverage computation
+ */
+tasks.create("headless") {
+ val root: String? by project
+ val file: String? by project
+ val cut: String? by project
+ val cp: String? by project
+ val junitv: String? by project
+ val llm: String? by project
+ val token: String? by project
+ val prompt: String? by project
+ val out: String? by project
+ val enableCoverage: String? by project
+
+ args = listOfNotNull("testspark", root, file, cut, cp, junitv, llm, token, prompt, out, enableCoverage.orDefault("false"))
+
+ jvmArgs(
+ "-Xmx16G",
+ "-Djava.awt.headless=true",
+ "--add-exports",
+ "java.base/jdk.internal.vm=ALL-UNNAMED",
+ "-Didea.system.path",
+ )
+}
+
fun spaceCredentialsProvided() = spaceUsername.isNotEmpty() && spacePassword.isNotEmpty()
diff --git a/core/build.gradle.kts b/core/build.gradle.kts
index f49c21613..f9ae862d2 100644
--- a/core/build.gradle.kts
+++ b/core/build.gradle.kts
@@ -24,7 +24,17 @@ tasks.test {
useJUnitPlatform()
}
kotlin {
- jvmToolchain(17)
+ jvmToolchain(rootProject.properties["jvmToolchainVersion"].toString().toInt())
+}
+
+tasks.register("sourcesJar") {
+ from(sourceSets.main.get().allSource)
+ archiveClassifier.set("sources")
+}
+
+tasks.register("javadocJar") {
+ from(tasks.named("javadoc"))
+ archiveClassifier.set("javadoc")
}
publishing {
@@ -32,10 +42,28 @@ publishing {
create("maven") {
groupId = group as String
artifactId = "testspark-core"
- version = "2.0.4"
+ version = "5.0.1"
from(components["java"])
+
+ artifact(tasks["sourcesJar"])
+ artifact(tasks["javadocJar"])
+
+ pom {
+ inceptionYear.set("2024")
+ name.set(project.name)
+ description.set(project.description)
+ packaging = "jar"
+
+ licenses {
+ license {
+ name.set("Apache-2.0")
+ url.set("https://www.apache.org/licenses/LICENSE-2.0")
+ }
+ }
+ }
}
}
+
repositories {
maven {
url = uri("https://packages.jetbrains.team/maven/p/automatically-generating-unit-tests/public")
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt
index e91a36f2f..fb83654fc 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt
@@ -1,6 +1,14 @@
package org.jetbrains.research.testspark.core.data
-data class ChatMessage(
- val role: String,
+open class ChatMessage protected constructor(
+ val role: ChatRole,
val content: String,
-)
+) {
+ enum class ChatRole {
+ User,
+ Assistant,
+ }
+}
+
+class ChatUserMessage(content: String) : ChatMessage(ChatRole.User, content)
+class ChatAssistantMessage(content: String) : ChatMessage(ChatRole.Assistant, content)
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ClassType.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ClassType.kt
new file mode 100644
index 000000000..8b6114e0f
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ClassType.kt
@@ -0,0 +1,15 @@
+package org.jetbrains.research.testspark.core.data
+
+/**
+ * Enumeration representing different types of classes.
+ *
+ * @param representation The string representation of the class type.
+ */
+enum class ClassType(val representation: String) {
+ INTERFACE("interface"),
+ ABSTRACT_CLASS("abstract class"),
+ CLASS("class"),
+ DATA_CLASS("data class"),
+ INLINE_VALUE_CLASS("inline value class"),
+ OBJECT("object"),
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/Report.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/Report.kt
index cbb0d731d..a91faa799 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/Report.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/Report.kt
@@ -1,17 +1,26 @@
package org.jetbrains.research.testspark.core.data
/**
- * Storage of generated tests. Implemented on the basis of org.evosuite.utils.CompactReport structure.
+ * Stores generated test cases and their coverage.
+ * Implemented on the basis of `org.evosuite.utils.CompactReport` structure.
+ *
+ * `Report`'s member fields were created based on the fields in
+ * `org.evosuite.utils.CompactReport` for easier transformation.
*/
open class Report {
// Fields were created based on the fields in org.evosuite.utils.CompactReport for easier transformation
- var UUT: String = "" // Unit Under Test
+ /**
+ * Unit Under Test. This variable stores the name of the class or component that is being tested.
+ */
+ var UUT: String = ""
var allCoveredLines: Set = setOf()
var allUncoveredLines: Set = setOf()
var testCaseList: HashMap = hashMapOf()
/**
- * AllCoveredLines update
+ * Calculates the normalized report by updating the set of all covered lines.
+ *
+ * @return The normalized report.
*/
fun normalized(): Report {
allCoveredLines = testCaseList.values.map { it.coveredLines }.flatten().toSet()
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt
index d11f346d5..64c215e93 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt
@@ -1,7 +1,5 @@
package org.jetbrains.research.testspark.core.data
-import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM
-
data class TestGenerationData(
// Result processing
// Report object for each test case
@@ -16,32 +14,27 @@ data class TestGenerationData(
// Code required of imports and package for generated tests
var importsCode: MutableSet = mutableSetOf(),
- var packageLine: String = "",
+ var packageName: String = "",
var runWith: String = "",
var otherInfo: String = "",
// changing parameters with a large prompt
var polyDepthReducing: Int = 0,
var inputParamsDepthReducing: Int = 0,
-
- // list of correct test cases during the incorrect compilation
- val compilableTestCases: MutableSet = mutableSetOf(),
-
) {
/**
- * Cleaning all old data before new test generation.
+ * Cleaning all old data before a new test generation.
*/
fun clear() {
testGenerationResultList.clear()
resultName = ""
fileUrl = ""
importsCode = mutableSetOf()
- packageLine = ""
+ packageName = ""
runWith = ""
otherInfo = ""
polyDepthReducing = 0
inputParamsDepthReducing = 0
- compilableTestCases.clear()
}
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/Exceptions.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/Exceptions.kt
new file mode 100644
index 000000000..c1b1f337f
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/Exceptions.kt
@@ -0,0 +1,38 @@
+package org.jetbrains.research.testspark.core.exception
+
+/**
+ * Represents custom exceptions within TestSpark.
+ *
+ * This class serves as a base class for specific exceptions.
+ *
+ * @param message A descriptive message explaining the error that led to the exception.
+ */
+sealed class TestSparkException(message: String) : RuntimeException(message)
+
+/**
+ * Custom exception to indicate that the Kotlin compiler was not found.
+ *
+ * @param message A descriptive message explaining the error.
+ */
+class KotlinCompilerNotFoundException(message: String) : TestSparkException(message)
+
+/**
+ * Custom exception to indicate that the Java compiler was not found.
+ *
+ * @param message A descriptive message explaining the error.
+ */
+class JavaCompilerNotFoundException(message: String) : TestSparkException(message)
+
+/**
+ * Represents an exception thrown when a required Java SDK is missing in the system.
+ *
+ * @param message A descriptive message explaining the specific error that led to this exception.
+ */
+class JavaSDKMissingException(message: String) : TestSparkException(message)
+
+/**
+ * Represents an exception thrown when a class file could not be found in the same path after the code compilation.
+ *
+ * @param message A descriptive message explaining the error
+ */
+class ClassFileNotFoundException(message: String) : TestSparkException(message)
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt
index cdd26f994..31fcde547 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt
@@ -7,7 +7,10 @@ import org.jetbrains.research.testspark.core.generation.llm.network.LLMResponse
import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager
import org.jetbrains.research.testspark.core.generation.llm.network.ResponseErrorCode
import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy
+import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
+import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestCompiler
import org.jetbrains.research.testspark.core.test.TestsAssembler
import org.jetbrains.research.testspark.core.test.TestsPersistentStorage
@@ -24,27 +27,57 @@ enum class FeedbackCycleExecutionResult {
SAVING_TEST_FILES_ISSUE,
}
+/**
+ * Represents a response (result) of a feedback cycle.
+ *
+ * @param executionResult The result of executing the feedback cycle.
+ * @param generatedTestSuite The test suite generated by LLM.
+ * @param compilableTestCases The set of compilable test cases generated by LLM.
+ *
+ * @throws IllegalArgumentException if `executionResult` is [FeedbackCycleExecutionResult.OK] and `generatedTestSuite` is null.
+ */
data class FeedbackResponse(
val executionResult: FeedbackCycleExecutionResult,
val generatedTestSuite: TestSuiteGeneratedByLLM?,
val compilableTestCases: MutableSet,
) {
init {
- if (executionResult == FeedbackCycleExecutionResult.OK && generatedTestSuite == null) {
- throw IllegalArgumentException("Test suite must be provided when FeedbackCycleExecutionResult is OK, got null")
- } else if (executionResult != FeedbackCycleExecutionResult.OK && generatedTestSuite != null) {
+ if ((executionResult == FeedbackCycleExecutionResult.OK || executionResult == FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED) &&
+ (generatedTestSuite == null)
+ ) {
throw IllegalArgumentException(
- "Test suite must not be provided when FeedbackCycleExecutionResult is not OK, got $generatedTestSuite",
+ "Test suite must be provided when FeedbackCycleExecutionResult is OK or NO_COMPILABLE_TEST_CASES_GENERATED (currently, ${executionResult.name}), got null",
)
}
}
}
+/**
+ * LLMWithFeedbackCycle class represents a feedback cycle for an LLM.
+ *
+ * @property report The `Report` instance used for storing generated tests.
+ * @property language The `SupportedLanguage` enum value representing the programming language used.
+ * @property initialPromptMessage The initial prompt message to start the feedback cycle.
+ * @property promptSizeReductionStrategy The `PromptSizeReductionStrategy` instance used for reducing the prompt size.
+ * @property testSuiteFilename The name of the file in which the test suite is saved in the result path.
+ * @property packageName The package name for the generated tests.
+ * @property resultPath The temporary path where all the generated tests and their Jacoco report are saved.
+ * @property buildPath All the directories where the compiled code of the project under test is saved.
+ * @property requestManager The `RequestManager` instance used for making LLM requests.
+ * @property testsAssembler The `TestsAssembler` instance used for assembling generated tests.
+ * @property testCompiler The `TestCompiler` instance used for compiling tests.
+ * @property testStorage The `TestsPersistentStorage` instance used for storing generated tests.
+ * @property testsPresenter The `TestsPresenter` instance used for presenting generated tests.
+ * @property indicator The `CustomProgressIndicator` instance used for tracking progress.
+ * @property requestsCountThreshold The threshold for the maximum number of requests in the feedback cycle.
+ * @property errorMonitor The `ErrorMonitor` instance used for monitoring errors.
+ */
class LLMWithFeedbackCycle(
private val report: Report,
+ private val language: SupportedLanguage,
private val initialPromptMessage: String,
private val promptSizeReductionStrategy: PromptSizeReductionStrategy,
- // filename in which the test suite is saved in result path
+ // filename in which the test suite is saved in the result path
private val testSuiteFilename: String,
private val packageName: String,
// temp path where all the generated tests and their jacoco report are saved
@@ -58,6 +91,7 @@ class LLMWithFeedbackCycle(
private val testsPresenter: TestsPresenter,
private val indicator: CustomProgressIndicator,
private val requestsCountThreshold: Int,
+ private val errorMonitor: ErrorMonitor = DefaultErrorMonitor(),
) {
enum class WarningType {
TEST_SUITE_PARSING_FAILED,
@@ -75,6 +109,9 @@ class LLMWithFeedbackCycle(
var executionResult = FeedbackCycleExecutionResult.OK
val compilableTestCases: MutableSet = mutableSetOf()
+ // collect imports from all responses
+ val imports: MutableSet = mutableSetOf()
+
var generatedTestSuite: TestSuiteGeneratedByLLM? = null
while (!generatedTestsArePassing) {
@@ -90,19 +127,32 @@ class LLMWithFeedbackCycle(
if (isLastIteration(requestsCount) && compilableTestCases.isEmpty()) {
executionResult = FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED
+ // record a report with parsable yet potentially
+ // non-compilable test cases stored in
+ // the generated test suite
+ // TODO: ensure generatedTestSuite is always non-null here
+ generatedTestSuite?.let { recordReport(report, it.testCases) }
break
}
// clearing test assembler's collected text on the previous attempts
testsAssembler.clear()
val response: LLMResponse = requestManager.request(
+ language = language,
prompt = nextPromptMessage,
indicator = indicator,
packageName = packageName,
testsAssembler = testsAssembler,
isUserFeedback = false,
+ errorMonitor,
)
+ // Process stopped checking
+ if (indicator.isCanceled()) {
+ executionResult = FeedbackCycleExecutionResult.CANCELED
+ break
+ }
+
when (response.errorCode) {
ResponseErrorCode.OK -> {
log.info { "Test suite generated successfully: ${response.testSuite!!}" }
@@ -115,11 +165,14 @@ class LLMWithFeedbackCycle(
continue
}
}
+
ResponseErrorCode.PROMPT_TOO_LONG -> {
if (promptSizeReductionStrategy.isReductionPossible()) {
nextPromptMessage = promptSizeReductionStrategy.reduceSizeAndGeneratePrompt()
/**
- * Current attempt does not count as a failure since it was rejected due to the prompt size exceeding the threshold
+ * The current attempt does not count as a failure
+ * since it was rejected due to the prompt size
+ * exceeding the threshold
*/
requestsCount--
continue
@@ -128,11 +181,13 @@ class LLMWithFeedbackCycle(
break
}
}
+
ResponseErrorCode.EMPTY_LLM_RESPONSE -> {
nextPromptMessage =
"You have provided an empty answer! Please, answer my previous question with the same formats"
continue
}
+
ResponseErrorCode.TEST_SUITE_PARSING_FAILURE -> {
onWarningCallback?.invoke(WarningType.TEST_SUITE_PARSING_FAILED)
log.info { "Cannot parse a test suite from the LLM response. LLM response: '$response'" }
@@ -144,6 +199,9 @@ class LLMWithFeedbackCycle(
generatedTestSuite = response.testSuite
+ // update imports list
+ imports.addAll(generatedTestSuite.imports)
+
// Process stopped checking
if (indicator.isCanceled()) {
executionResult = FeedbackCycleExecutionResult.CANCELED
@@ -157,12 +215,15 @@ class LLMWithFeedbackCycle(
generatedTestSuite.updateTestCases(compilableTestCases.toMutableList())
} else {
for (testCaseIndex in generatedTestSuite.testCases.indices) {
- val testCaseFilename = "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java"
+ val testCaseFilename = when (language) {
+ SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java"
+ SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt"
+ }
val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex)
val saveFilepath = testStorage.saveGeneratedTest(
- generatedTestSuite.packageString,
+ generatedTestSuite.packageName,
testCaseRepresentation,
resultPath,
testCaseFilename,
@@ -173,7 +234,7 @@ class LLMWithFeedbackCycle(
}
val generatedTestSuitePath: String = testStorage.saveGeneratedTest(
- generatedTestSuite.packageString,
+ generatedTestSuite.packageName,
testsPresenter.representTestSuite(generatedTestSuite),
resultPath,
testSuiteFilename,
@@ -201,33 +262,49 @@ class LLMWithFeedbackCycle(
// Compile the test file
indicator.setText("Compilation tests checking")
- val testCasesCompilationResult = testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
- val testSuiteCompilationResult = testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath)
+ val testCasesCompilationResult =
+ testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases, resultPath)
+ val testSuiteCompilationResult =
+ testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath, resultPath)
// saving the compilable test cases
compilableTestCases.addAll(testCasesCompilationResult.compilableTestCases)
+ // Process stopped checking
+ if (indicator.isCanceled()) {
+ executionResult = FeedbackCycleExecutionResult.CANCELED
+ break
+ }
+
if (!testCasesCompilationResult.allTestCasesCompilable && !isLastIteration(requestsCount)) {
log.info { "Non-compilable test suite: \n${testsPresenter.representTestSuite(generatedTestSuite!!)}" }
onWarningCallback?.invoke(WarningType.COMPILATION_ERROR_OCCURRED)
- nextPromptMessage = "I cannot compile the tests that you provided. The error is:\n${testSuiteCompilationResult.second}\n Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text."
+ nextPromptMessage = """
+ I cannot compile the tests that you provided. The error is:
+ ```
+ ${testSuiteCompilationResult.executionMessage}
+ ```
+ Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text.
+ """.trimIndent()
log.info { nextPromptMessage }
continue
}
log.info { "Result is compilable" }
+ generatedTestSuite.imports.addAll(imports)
+
generatedTestsArePassing = true
- for (index in testCases.indices) {
- report.testCaseList[index] = TestCase(index, testCases[index].name, testCases[index].toString(), setOf())
- }
+ recordReport(report, testCases)
}
// test suite must not be provided upon failed execution
- if (executionResult != FeedbackCycleExecutionResult.OK) {
+ if (executionResult != FeedbackCycleExecutionResult.OK &&
+ executionResult != FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED
+ ) {
generatedTestSuite = null
}
@@ -238,5 +315,17 @@ class LLMWithFeedbackCycle(
)
}
+ /**
+ * Records the generated test cases in the given report.
+ *
+ * @param report The report object to store the test cases in.
+ * @param testCases The list of test cases generated by LLM.
+ */
+ private fun recordReport(report: Report, testCases: MutableList) {
+ for ((index, test) in testCases.withIndex()) {
+ report.testCaseList[index] = TestCase(index, test.name, test.toString(), setOf())
+ }
+ }
+
private fun isLastIteration(requestsCount: Int): Boolean = requestsCount > requestsCountThreshold
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt
index c6aa2b220..4cf5956bf 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt
@@ -1,13 +1,50 @@
package org.jetbrains.research.testspark.core.generation.llm
import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager
+import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
+import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestsAssembler
import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+import org.jetbrains.research.testspark.core.utils.javaPackagePattern
+import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern
import java.util.Locale
// TODO: find a better place for the below functions
+/**
+ * Retrieves the package declaration from the given test suite code for any language.
+ *
+ * @param testSuiteCode The generated code of the test suite.
+ * @return The package name extracted from the test suite code, or an empty string if no package declaration was found.
+ */
+fun getPackageFromTestSuiteCode(testSuiteCode: String?, language: SupportedLanguage): String {
+ testSuiteCode ?: return ""
+ return when (language) {
+ SupportedLanguage.Kotlin -> kotlinPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty()
+ SupportedLanguage.Java -> javaPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty()
+ }
+}
+
+/**
+ * Retrieves the imports code from a given test suite code.
+ *
+ * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned.
+ * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result.
+ * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned.
+ */
+fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String?): MutableSet {
+ testSuiteCode ?: return mutableSetOf()
+ return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence()
+ .filter { it.contains("^import".toRegex()) }
+ .filterNot { it.contains("evosuite".toRegex()) }
+ .filterNot { it.contains("RunWith".toRegex()) }
+ // classFQN will be null for the top level function
+ .filterNot { classFQN != null && it.contains(classFQN.toRegex()) }
+ .toMutableSet()
+}
+
/**
* Returns the generated class name for a given test case.
*
@@ -36,31 +73,33 @@ fun getClassWithTestCaseName(testCaseName: String): String {
* @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null.
*/
fun executeTestCaseModificationRequest(
+ language: SupportedLanguage,
testCase: String,
task: String,
indicator: CustomProgressIndicator,
requestManager: RequestManager,
testsAssembler: TestsAssembler,
+ errorMonitor: ErrorMonitor = DefaultErrorMonitor(),
): TestSuiteGeneratedByLLM? {
// Update Token information
- val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task"
-
- var packageName = ""
- testCase.split("\n")[0].let {
- if (it.startsWith("package")) {
- packageName = it
- .removePrefix("package ")
- .removeSuffix(";")
- .trim()
- }
+ val prompt = buildString {
+ append("For this test:\n ```\n ")
+ append(testCase)
+ append("\n```\nGenerate a SINGLE test method. Do not change class and method names.")
+ append("\nPerform the following task:\n")
+ append(task)
}
+ val packageName = getPackageFromTestSuiteCode(testCase, language)
+
val response = requestManager.request(
+ language,
prompt,
indicator,
packageName,
testsAssembler,
isUserFeedback = true,
+ errorMonitor,
)
return response.testSuite
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt
index 104d79fb4..441e51231 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt
@@ -1,8 +1,13 @@
package org.jetbrains.research.testspark.core.generation.llm.network
import io.github.oshai.kotlinlogging.KotlinLogging
+import org.jetbrains.research.testspark.core.data.ChatAssistantMessage
import org.jetbrains.research.testspark.core.data.ChatMessage
+import org.jetbrains.research.testspark.core.data.ChatUserMessage
+import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
+import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestsAssembler
abstract class RequestManager(var token: String) {
@@ -26,20 +31,21 @@ abstract class RequestManager(var token: String) {
* @return the generated TestSuite, or null and prompt message
*/
open fun request(
+ language: SupportedLanguage,
prompt: String,
indicator: CustomProgressIndicator,
packageName: String,
testsAssembler: TestsAssembler,
isUserFeedback: Boolean = false,
+ errorMonitor: ErrorMonitor = DefaultErrorMonitor(), // The plugin for other IDEs can send LLM requests without passing an errorMonitor
): LLMResponse {
// save the prompt in chat history
- // TODO: make role to be an enum class
- chatHistory.add(ChatMessage("user", prompt))
+ chatHistory.add(ChatUserMessage(prompt))
// Send Request to LLM
log.info { "Sending Request..." }
- val sendResult = send(prompt, indicator, testsAssembler)
+ val sendResult = send(prompt, indicator, testsAssembler, errorMonitor)
if (sendResult == SendResult.PROMPT_TOO_LONG) {
return LLMResponse(ResponseErrorCode.PROMPT_TOO_LONG, null)
@@ -51,27 +57,28 @@ abstract class RequestManager(var token: String) {
}
return when (isUserFeedback) {
- true -> processUserFeedbackResponse(testsAssembler, packageName)
- false -> processResponse(testsAssembler, packageName)
+ true -> processUserFeedbackResponse(testsAssembler, packageName, language)
+ false -> processResponse(testsAssembler, packageName, language)
}
}
open fun processResponse(
testsAssembler: TestsAssembler,
packageName: String,
+ language: SupportedLanguage,
): LLMResponse {
// save the full response in the chat history
val response = testsAssembler.getContent()
log.info { "The full response: \n $response" }
- chatHistory.add(ChatMessage("assistant", response))
+ chatHistory.add(ChatAssistantMessage(response))
- // check if response is empty
+ // check if the response is empty
if (response.isEmpty() || response.isBlank()) {
return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null)
}
- val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName)
+ val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite()
return if (testSuiteGeneratedByLLM == null) {
LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null)
@@ -84,22 +91,24 @@ abstract class RequestManager(var token: String) {
prompt: String,
indicator: CustomProgressIndicator,
testsAssembler: TestsAssembler,
+ errorMonitor: ErrorMonitor = DefaultErrorMonitor(),
): SendResult
open fun processUserFeedbackResponse(
testsAssembler: TestsAssembler,
packageName: String,
+ language: SupportedLanguage,
): LLMResponse {
val response = testsAssembler.getContent()
log.info { "The full response: \n $response" }
- // check if response is empty
+ // check if the response is empty
if (response.isEmpty() || response.isBlank()) {
return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null)
}
- val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName)
+ val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite()
return if (testSuiteGeneratedByLLM == null) {
LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null)
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt
index 9691a5321..7040f3e30 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt
@@ -1,136 +1,169 @@
package org.jetbrains.research.testspark.core.generation.llm.prompt
+import org.jetbrains.research.testspark.core.data.ClassType
import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation
+import java.util.EnumMap
+
+/**
+ * Builds prompts by populating a template with keyword values
+ * and validates that all mandatory keywords are provided.
+ *
+ * @property promptTemplate The template string for the prompt.
+ */
+class PromptBuilder(private val promptTemplate: String) {
+ private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java)
+
+ // collect all the keywords present in the prompt template
+ private val templateKeywords: List = buildList {
+ for (keyword in PromptKeyword.entries) {
+ if (promptTemplate.contains(keyword.variable)) {
+ add(keyword)
+ }
+ }
+ }
-internal class PromptBuilder(private var prompt: String) {
- private fun isPromptValid(
- keyword: PromptKeyword,
- prompt: String,
- ): Boolean {
- val keywordText = keyword.text
- val isMandatory = keyword.mandatory
+ /**
+ * Builds the prompt by populating the template with the inserted values
+ * and validating that all mandatory keywords were provided.
+ *
+ * @return The built prompt.
+ * @throws IllegalStateException if a mandatory keyword is not present in the template.
+ */
+ fun build(): String {
+ var populatedPrompt = promptTemplate
+
+ // populate the template with the inserted values
+ for ((keyword, value) in insertedKeywordValues.entries) {
+ populatedPrompt = populatedPrompt.replace(keyword.variable, value, ignoreCase = false)
+ }
- return (prompt.contains(keywordText) || !isMandatory)
+ // validate that all mandatory keywords were provided
+ for (keyword in templateKeywords) {
+ if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) {
+ throw IllegalStateException("The prompt must contain ${keyword.name} keyword")
+ }
+ }
+
+ return populatedPrompt
}
- fun insertLanguage(language: String) = apply {
- if (isPromptValid(PromptKeyword.LANGUAGE, prompt)) {
- val keyword = "\$${PromptKeyword.LANGUAGE.text}"
- prompt = prompt.replace(keyword, language, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}")
+ /**
+ * Inserts a keyword and its corresponding value into the prompt template.
+ * If the keyword is marked as mandatory and not present in the template, an IllegalArgumentException is thrown.
+ *
+ * @param keyword The keyword to be inserted.
+ * @param value The value corresponding to the keyword.
+ * @throws IllegalArgumentException if a mandatory keyword is not present in the template.
+ */
+ private fun insert(keyword: PromptKeyword, value: String) {
+ if (!templateKeywords.contains(keyword) && keyword.mandatory) {
+ throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.name}")
}
+ insertedKeywordValues[keyword] = value
+ }
+
+ fun insertLanguage(language: String) = apply {
+ insert(PromptKeyword.LANGUAGE, language)
}
fun insertName(classDisplayName: String) = apply {
- if (isPromptValid(PromptKeyword.NAME, prompt)) {
- val keyword = "\$${PromptKeyword.NAME.text}"
- prompt = prompt.replace(keyword, classDisplayName, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}")
- }
+ insert(PromptKeyword.NAME, classDisplayName)
}
fun insertTestingPlatform(testingPlatformName: String) = apply {
- if (isPromptValid(PromptKeyword.TESTING_PLATFORM, prompt)) {
- val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}"
- prompt = prompt.replace(keyword, testingPlatformName, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}")
- }
+ insert(PromptKeyword.TESTING_PLATFORM, testingPlatformName)
}
fun insertMockingFramework(mockingFrameworkName: String) = apply {
- if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, prompt)) {
- val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}"
- prompt = prompt.replace(keyword, mockingFrameworkName, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}")
- }
+ insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName)
}
- fun insertCodeUnderTest(classFullText: String, classesToTest: List) = apply {
- if (isPromptValid(PromptKeyword.CODE, prompt)) {
- val keyword = "\$${PromptKeyword.CODE.text}"
- var fullText = "```\n${classFullText}\n```\n"
-
- for (i in 2..classesToTest.size) {
- val subClass = classesToTest[i - 2]
- val superClass = classesToTest[i - 1]
-
- fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " +
- "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" +
- "```\n"
- }
- prompt = prompt.replace(keyword, fullText, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}")
+ /**
+ * Inserts the code under test and its related superclass code into the prompt template.
+ *
+ * @param codeFullText The full text of the code under test.
+ * @param classesToTest The list of ClassRepresentation objects representing the classes involved in the code under test.
+ * @return The modified prompt builder.
+ */
+ fun insertCodeUnderTest(codeFullText: String, classesToTest: List) = apply {
+ val fullText = StringBuilder("```\n${codeFullText}\n```\n")
+
+ for (i in 2..classesToTest.size) {
+ val subClass = classesToTest[i - 2]
+ val superClass = classesToTest[i - 1]
+
+ fullText.append("${subClass.qualifiedName} extends ${superClass.qualifiedName}. ")
+ .append("The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n")
+ .append("```\n")
}
+
+ insert(PromptKeyword.CODE, fullText.toString())
}
fun insertMethodsSignatures(interestingClasses: List) = apply {
- val keyword = "\$${PromptKeyword.METHODS.text}"
+ val fullText = StringBuilder()
- if (isPromptValid(PromptKeyword.METHODS, prompt)) {
- var fullText = ""
- if (interestingClasses.isNotEmpty()) {
- fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n"
- }
- for (interestingClass in interestingClasses) {
- if (interestingClass.qualifiedName.startsWith("java")) {
- continue
- }
+ if (interestingClasses.isNotEmpty()) {
+ fullText.append("Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n")
+ }
- fullText += "=== methods in ${interestingClass.qualifiedName}:\n"
+ for (interestingClass in interestingClasses) {
+ if (interestingClass.qualifiedName.startsWith("java") ||
+ interestingClass.qualifiedName.startsWith("kotlin")
+ ) {
+ continue
+ }
- for (method in interestingClass.allMethods) {
- // Skip java methods
- // TODO: checks for java methods should be done by a caller to make
- // this class as abstract and language agnostic as possible.
- if (method.containingClassQualifiedName.startsWith("java")) {
- continue
- }
+ fullText.append("=== methods in ${interestingClass.qualifiedName}:\n")
- fullText += " - ${method.signature}\n"
+ for (method in interestingClass.allMethods) {
+ // TODO: checks for java methods should be done by a caller to make
+ // this class as abstract and language agnostic as possible.
+ if (method.containingClassQualifiedName.startsWith("java") ||
+ method.containingClassQualifiedName.startsWith("kotlin")
+ ) {
+ continue
}
+
+ fullText.append(" - ${method.signature}\n")
}
- prompt = prompt.replace(keyword, fullText, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}")
}
+
+ insert(PromptKeyword.METHODS, fullText.toString())
}
fun insertPolymorphismRelations(
polymorphismRelations: Map>,
) = apply {
- val keyword = "\$${PromptKeyword.POLYMORPHISM.text}"
- if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) {
- var fullText = ""
+ val fullText = StringBuilder()
+
+ if (polymorphismRelations.isNotEmpty()) {
+ fullText.append("Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n")
+ }
- polymorphismRelations.forEach { entry ->
- for (currentSubClass in entry.value) {
- fullText += "${currentSubClass.qualifiedName} is a sub-class of ${entry.key.qualifiedName}.\n"
+ for (entry in polymorphismRelations) {
+ for (currentSubClass in entry.value) {
+ val subClassTypeName = when (currentSubClass.classType) {
+ ClassType.INTERFACE -> "an interface implementing"
+ ClassType.ABSTRACT_CLASS -> "an abstract sub-class of"
+ ClassType.CLASS -> "a sub-class of"
+ ClassType.DATA_CLASS -> "a sub data class of"
+ ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of"
+ ClassType.OBJECT -> "a sub object of"
}
+ fullText.append("${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n")
}
- prompt = prompt.replace(keyword, fullText, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}")
}
+
+ insert(PromptKeyword.POLYMORPHISM, fullText.toString())
}
fun insertTestSample(testSamplesCode: String) = apply {
- val keyword = "\$${PromptKeyword.TEST_SAMPLE.text}"
-
- if (isPromptValid(PromptKeyword.TEST_SAMPLE, prompt)) {
- var fullText = testSamplesCode
- if (fullText.isNotBlank()) {
- fullText = "Use this test samples:\n$fullText\n"
- }
- prompt = prompt.replace(keyword, fullText, ignoreCase = false)
- } else {
- throw IllegalStateException("The prompt must contain ${PromptKeyword.TEST_SAMPLE.text}")
+ var fullText = testSamplesCode
+ if (fullText.isNotBlank()) {
+ fullText = "Use this test samples:\n$fullText\n"
}
- }
- fun build(): String = prompt
+ insert(PromptKeyword.TEST_SAMPLE, fullText)
+ }
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt
index 384510161..eb40c9ea9 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt
@@ -19,7 +19,7 @@ class PromptGenerator(
fun generatePromptForClass(interestingClasses: List, testSamplesCode: String): String {
val prompt = PromptBuilder(promptTemplates.classPrompt)
.insertLanguage(context.promptConfiguration.desiredLanguage)
- .insertName(context.cut.qualifiedName)
+ .insertName(context.cut!!.qualifiedName)
.insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
.insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
.insertCodeUnderTest(context.cut.fullText, context.classesToTest)
@@ -28,6 +28,7 @@ class PromptGenerator(
.insertTestSample(testSamplesCode)
.build()
+ println("Prompt: $prompt")
return prompt
}
@@ -43,10 +44,13 @@ class PromptGenerator(
method: MethodRepresentation,
interestingClassesFromMethod: List,
testSamplesCode: String,
+ packageName: String,
): String {
+ val methodQualifiedName = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}"
+
val prompt = PromptBuilder(promptTemplates.methodPrompt)
.insertLanguage(context.promptConfiguration.desiredLanguage)
- .insertName("${context.cut.qualifiedName}.${method.name}")
+ .insertName(methodQualifiedName)
.insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
.insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
.insertCodeUnderTest(method.text, context.classesToTest)
@@ -59,8 +63,7 @@ class PromptGenerator(
}
/**
- * Generates a prompt for a given line under test.
- * It accepts the code of a line under test, a representation of the method that contains the line, and the set of interesting classes (e.g., the containing class of the method, classes listed in parameters of the method and constructors of the containing class).
+ * Generates a prompt for a given line under test using a surrounding method/function as a context.
*
* @param lineUnderTest The source code of the line to be tested.
* @param method The representation of the method that contains the line.
@@ -73,13 +76,25 @@ class PromptGenerator(
method: MethodRepresentation,
interestingClassesFromMethod: List,
testSamplesCode: String,
+ packageName: String,
): String {
+ val codeUnderTest = if (context.cut != null) {
+ // `method` is a method within a class
+ buildCutDeclaration(context.cut, method)
+ } else {
+ // `method` is a top-level function
+ method.text
+ }
+
+ val methodQualifiedName = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}"
+ val lineReference = "`${lineUnderTest.trim()}` within `$methodQualifiedName`"
+
val prompt = PromptBuilder(promptTemplates.linePrompt)
.insertLanguage(context.promptConfiguration.desiredLanguage)
- .insertName(lineUnderTest.trim())
+ .insertName(lineReference)
.insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
.insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
- .insertCodeUnderTest(method.text, context.classesToTest)
+ .insertCodeUnderTest(codeUnderTest, context.classesToTest)
.insertMethodsSignatures(interestingClassesFromMethod)
.insertPolymorphismRelations(context.polymorphismRelations)
.insertTestSample(testSamplesCode)
@@ -87,4 +102,104 @@ class PromptGenerator(
return prompt
}
+
+ /**
+ * Generates a prompt for a given line under test using CUT as a context.
+ *
+ * **Contract: `context.cut` is not `null`.**
+ *
+ * @param lineUnderTest The source code of the line to be tested.
+ * @param interestingClasses The list of `ClassRepresentation` objects related to the line under test.
+ * @param testSamplesCode The code snippet that serves as test samples.
+ * @return The generated prompt as `String`.
+ * @throws IllegalStateException If any of the required keywords are missing in the prompt template.
+ */
+ fun generatePromptForLine(
+ lineUnderTest: String,
+ interestingClasses: List,
+ testSamplesCode: String,
+ ): String {
+ val lineReference = "`${lineUnderTest.trim()}` within `${context.cut!!.qualifiedName}`"
+
+ val prompt = PromptBuilder(promptTemplates.linePrompt)
+ .insertLanguage(context.promptConfiguration.desiredLanguage)
+ .insertName(lineReference)
+ .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
+ .insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
+ .insertCodeUnderTest(context.cut.fullText, context.classesToTest)
+ .insertMethodsSignatures(interestingClasses)
+ .insertPolymorphismRelations(context.polymorphismRelations)
+ .insertTestSample(testSamplesCode)
+ .build()
+
+ return prompt
+ }
+}
+
+/**
+ * Builds a cut declaration with constructor declarations and a method under test.
+ *
+ * Example when there exist non-default constructors:
+ * ```
+ * [Instruction]: Use the following constructor declarations to instantiate `org.example.CalcKotlin` and call the method under test `add`:
+ *
+ * Constructors of the class org.example.CalcKotlin:
+ * === (val value: Int)
+ * === constructor(c: Int, d: Int) : this(c+d)
+ *
+ * Method:
+ * fun add(a: Int, b: Int): Int {
+ * return a + b
+ * }
+ * ```
+ *
+ * Example when only a default constructor exists:
+ * ```
+ * [Instruction]: Use a default constructor with zero arguments to instantiate `Calc` and call the method under test `sum`:
+ *
+ * Constructors of the class Calc:
+ * === Default constructor
+ *
+ * Method:
+ * public int sum(int a, int b) {
+ * return a + b;
+ * }
+ * ```
+ *
+ * @param cut The `ClassRepresentation` object representing the class to be instantiated.
+ * @param method The `MethodRepresentation` object representing the method under test.
+ * @return A formatted `String` representing the cut declaration, containing constructor declarations and method text.
+ */
+private fun buildCutDeclaration(cut: ClassRepresentation, method: MethodRepresentation): String {
+ val instruction = buildString {
+ val constructorToUse = if (cut.constructorSignatures.isEmpty()) {
+ "a default constructor with zero arguments"
+ } else {
+ "the following constructor declarations"
+ }
+ append("Use $constructorToUse to instantiate `${cut.qualifiedName}` and call the method under test `${method.name}`")
+ }
+
+ val classType = cut.classType.representation
+
+ val constructorDeclarations = buildString {
+ appendLine("Constructors of the $classType ${cut.qualifiedName}:")
+ if (cut.constructorSignatures.isEmpty()) {
+ appendLine("=== Default constructor")
+ }
+ for (constructor in cut.constructorSignatures) {
+ appendLine("\t=== $constructor")
+ }
+ }.trim()
+
+ val cutDeclaration = buildString {
+ appendLine("[Instruction]: $instruction:")
+ appendLine()
+ appendLine(constructorDeclarations)
+ appendLine()
+ appendLine("Method:")
+ appendLine(method.text)
+ }.trim()
+
+ return cutDeclaration
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt
index f1d46eece..98a968d23 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt
@@ -1,26 +1,24 @@
package org.jetbrains.research.testspark.core.generation.llm.prompt
-enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) {
- NAME("NAME", "The name of the code under test (Class name, method name, line number)", true),
- CODE("CODE", "The code under test (Class, method, or line)", true),
- LANGUAGE("LANGUAGE", "Programming language of the project under test (only Java supported at this point)", true),
+enum class PromptKeyword(val description: String, val mandatory: Boolean) {
+ NAME("The name of the code under test (Class name, method name, line number)", true),
+ CODE("The code under test (Class, method, or line)", true),
+ LANGUAGE("Programming language of the project under test (only Java supported at this point)", true),
TESTING_PLATFORM(
- "TESTING_PLATFORM",
"Testing platform used in the project (Only JUnit 4 is supported at this point)",
true,
),
MOCKING_FRAMEWORK(
- "MOCKING_FRAMEWORK",
"Mock framework that can be used in generated test (Only Mockito is supported at this point)",
false,
),
- METHODS("METHODS", "Signature of methods used in the code under tests", false),
- POLYMORPHISM("POLYMORPHISM", "Polymorphism relations between classes involved in the code under test", false),
- TEST_SAMPLE("TEST_SAMPLE", "Test samples for LLM for test generation", false),
+ METHODS("Signature of methods used in the code under tests", false),
+ POLYMORPHISM("Polymorphism relations between classes involved in the code under test", false),
+ TEST_SAMPLE("Test samples for LLM for test generation", false),
;
fun getOffsets(prompt: String): Pair? {
- val textToHighlight = "\$$text"
+ val textToHighlight = variable
if (!prompt.contains(textToHighlight)) {
return null
}
@@ -29,4 +27,13 @@ enum class PromptKeyword(val text: String, val description: String, val mandator
val endOffset = startOffset + textToHighlight.length
return Pair(startOffset, endOffset)
}
+
+ /**
+ * Returns a keyword's text (i.e., its name) with a `$` attached at the start.
+ *
+ * Inside a prompt template every keyword is used as `$KEYWORD_NAME`.
+ * Therefore, this property encapsulates the keyword's representation in a prompt.
+ */
+ val variable: String
+ get() = "\$${this.name}"
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt
index eb5f0d825..a98abe4ad 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt
@@ -1,5 +1,7 @@
package org.jetbrains.research.testspark.core.generation.llm.prompt.configuration
+import org.jetbrains.research.testspark.core.data.ClassType
+
/**
* Represents the context for generating prompts for generating unit tests.
*
@@ -8,7 +10,10 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt.configuratio
* @property polymorphismRelations A map where the key represents a ClassRepresentation object and the value is a list of its detected subclasses.
*/
data class PromptGenerationContext(
- val cut: ClassRepresentation,
+ /**
+ * The cut is null when we want to generate tests for top-level function
+ */
+ val cut: ClassRepresentation?,
val classesToTest: List,
val polymorphismRelations: Map>,
val promptConfiguration: PromptConfiguration,
@@ -38,7 +43,9 @@ data class PromptConfiguration(
data class ClassRepresentation(
val qualifiedName: String,
val fullText: String,
+ val constructorSignatures: List,
val allMethods: List,
+ val classType: ClassType,
)
/**
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/monitor/ErrorMonitor.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/monitor/ErrorMonitor.kt
new file mode 100644
index 000000000..dc7891f5c
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/monitor/ErrorMonitor.kt
@@ -0,0 +1,72 @@
+package org.jetbrains.research.testspark.core.monitor
+
+/**
+ * This interface is used for contract adherence in error monitoring.
+ */
+interface ErrorMonitor {
+
+ /**
+ * Notifies when an error has occurred. If this function is called,
+ * it means an error has occurred and should return true.
+ *
+ * @return Boolean
+ */
+ fun notifyErrorOccurrence(): Boolean
+
+ /**
+ * Checks if an error has occurred. If an error has occurred, this
+ * function should return true; otherwise, it should return false.
+ *
+ * @return Boolean
+ */
+ fun hasErrorOccurred(): Boolean
+
+ /**
+ * Clears the status of the error occurrence. Calling this function should
+ * reset the error state to a state as if no error had occurred.
+ */
+ fun clear()
+}
+
+/**
+ * An abstract ErrorMonitor that provides basic implementation of ErrorMonitor.
+ * This class includes default behavior for hasErrorOccurred() and clear() methods.
+ * It also consists of 'errorOccurred' variable to keep track of error occurrence.
+ */
+abstract class AbstractErrorMonitor : ErrorMonitor {
+ protected var errorOccurred: Boolean = false
+
+ /**
+ * Returns the current state of error occurrence.
+ * @return errorOccurred - a Boolean value representing whether an error has occurred or not.
+ */
+ override fun hasErrorOccurred(): Boolean {
+ return errorOccurred
+ }
+
+ /**
+ * Resets the state of error occurrence by setting 'errorOccurred' to false.
+ */
+ override fun clear() {
+ errorOccurred = false
+ }
+}
+
+/**
+ * A specific implementation of AbstractErrorMonitor that includes behavior for notifyErrorOccurrence() method.
+ * If an error has already occurred, it returns false, otherwise it sets 'errorOccurred' to true and returns true.
+ */
+class DefaultErrorMonitor : AbstractErrorMonitor() {
+
+ /**
+ * Handles the case when an error occurrence has been notified.
+ * If an error has already occurred, it ignores the notification and returns false.
+ * If an error has not occurred yet, it sets 'errorOccurred' to true and return true.
+ * @return Boolean value indicating whether the error notification has been handled or not.
+ */
+ override fun notifyErrorOccurrence(): Boolean {
+ if (errorOccurred) return false
+ errorOccurred = true
+ return true
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/progress/CustomProgressIndicator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/progress/CustomProgressIndicator.kt
index 37d531e41..c451a4d74 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/progress/CustomProgressIndicator.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/progress/CustomProgressIndicator.kt
@@ -11,4 +11,5 @@ interface CustomProgressIndicator {
fun isCanceled(): Boolean
fun start()
fun stop()
+ fun isRunning(): Boolean
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt
new file mode 100644
index 000000000..4b4de90c8
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.research.testspark.core.test
+
+/**
+ * Language ID string should be the same as the language name in com.intellij.lang.Language
+ */
+enum class SupportedLanguage(val languageId: String) {
+ Java("JAVA"), Kotlin("kotlin")
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt
new file mode 100644
index 000000000..450400ac3
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt
@@ -0,0 +1,21 @@
+package org.jetbrains.research.testspark.core.test
+
+import org.jetbrains.research.testspark.core.test.data.TestLine
+
+interface TestBodyPrinter {
+ /**
+ * Generates a test body as a string based on the provided parameters.
+ *
+ * @param testInitiatedText A string containing the upper part of the test case.
+ * @param lines A mutable list of `TestLine` objects representing the lines of the test body.
+ * @param throwsException The exception type that the test function throws, if any.
+ * @param name The name of the test function.
+ * @return A string representing the complete test body.
+ */
+ fun printTestBody(
+ testInitiatedText: String,
+ lines: MutableList,
+ throwsException: String,
+ name: String,
+ ): String
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt
index bc4d40617..3d85f15c1 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt
@@ -1,43 +1,46 @@
package org.jetbrains.research.testspark.core.test
-import io.github.oshai.kotlinlogging.KotlinLogging
import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM
-import org.jetbrains.research.testspark.core.utils.CommandLineRunner
import org.jetbrains.research.testspark.core.utils.DataFilesUtil
-import java.io.File
data class TestCasesCompilationResult(
val allTestCasesCompilable: Boolean,
val compilableTestCases: MutableSet,
)
-/**
- * TestCompiler is a class that is responsible for compiling generated test cases using the proper javac.
- * It provides methods for compiling test cases and code files.
- */
-open class TestCompiler(
- private val javaHomeDirectoryPath: String,
- private val libPaths: List,
- private val junitLibPaths: List,
+data class ExecutionResult(
+ val exitCode: Int,
+ val executionMessage: String,
) {
- private val log = KotlinLogging.logger { this::class.java }
+ fun isSuccessful(): Boolean = exitCode == 0
+}
+
+abstract class TestCompiler(libPaths: List, junitLibPaths: List) {
+ val separator = DataFilesUtil.classpathSeparator
+ val dependencyLibPath = libPaths.joinToString(separator.toString())
+ val junitPath = junitLibPaths.joinToString(separator.toString())
+ val commonPath = "$junitPath${separator}$dependencyLibPath$separator"
/**
- * Compiles the generated files with test cases using the proper javac.
+ * Compiles a list of test cases and returns the compilation result.
*
- * @return true if all the provided test cases are successfully compiled,
- * otherwise returns false.
+ * @param generatedTestCasesPaths A list of file paths where the generated test cases are located.
+ * @param buildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case.
+ * @param testCases A mutable list of `TestCaseGeneratedByLLM` objects representing the test cases to be compiled.
+ * @param workingDir The path of the directory that contains package directories of the code to compile
+ * @return A `TestCasesCompilationResult` object containing the overall compilation success status and a set of compilable test cases.
*/
fun compileTestCases(
generatedTestCasesPaths: List,
buildPath: String,
testCases: MutableList,
+ workingDir: String,
): TestCasesCompilationResult {
var allTestCasesCompilable = true
val compilableTestCases: MutableSet = mutableSetOf()
for (index in generatedTestCasesPaths.indices) {
- val compilable = compileCode(generatedTestCasesPaths[index], buildPath).first
+ val compilable = compileCode(generatedTestCasesPaths[index], buildPath, workingDir).isSuccessful()
allTestCasesCompilable = allTestCasesCompilable && compilable
if (compilable) {
compilableTestCases.add(testCases[index])
@@ -51,45 +54,12 @@ open class TestCompiler(
* Compiles the code at the specified path using the provided project build path.
*
* @param path The path of the code file to compile.
- * @param projectBuildPath The project build path to use during compilation.
+ * @param projectBuildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case.
+ * @param workingDir The path of the directory that contains package directories of the code to compile
* @return A pair containing a boolean value indicating whether the compilation was successful (true) or not (false),
* and a string message describing any error encountered during compilation.
*/
- fun compileCode(path: String, projectBuildPath: String): Pair {
- // find the proper javac
- val javaCompile = File(javaHomeDirectoryPath).walk()
- .filter {
- val isCompilerName = if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac")
- isCompilerName && it.isFile
- }
- .firstOrNull()
-
- if (javaCompile == null) {
- val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'"
- log.error { msg }
- throw RuntimeException(msg)
- }
-
- println("javac found at '${javaCompile.absolutePath}'")
-
- // compile file
- val errorMsg = CommandLineRunner.run(
- arrayListOf(
- javaCompile.absolutePath,
- "-cp",
- "\"${getPath(projectBuildPath)}\"",
- path,
- ),
- )
-
- log.info { "Error message: '$errorMsg'" }
-
- // create .class file path
- val classFilePath = path.replace(".java", ".class")
-
- // check is .class file exists
- return Pair(File(classFilePath).exists(), errorMsg)
- }
+ abstract fun compileCode(path: String, projectBuildPath: String, workingDir: String): ExecutionResult
/**
* Generates the path for the command by concatenating the necessary paths.
@@ -97,15 +67,5 @@ open class TestCompiler(
* @param buildPath The path of the build file.
* @return The generated path as a string.
*/
- fun getPath(buildPath: String): String {
- // create the path for the command
- val separator = DataFilesUtil.classpathSeparator
- val dependencyLibPath = libPaths.joinToString(separator.toString())
- val junitPath = junitLibPaths.joinToString(separator.toString())
-
- val path = "$junitPath${separator}$dependencyLibPath${separator}$buildPath"
- println("[TestCompiler]: the path is: $path")
-
- return path
- }
+ abstract fun getClassPaths(buildPath: String): String
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt
similarity index 61%
rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt
rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt
index 9ce724888..60c4016d4 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt
@@ -1,10 +1,17 @@
-package org.jetbrains.research.testspark.core.test.parsers
+package org.jetbrains.research.testspark.core.test
+import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM
import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+data class TestCaseParseResult(
+ val testCase: TestCaseGeneratedByLLM?,
+ val errorMessage: String,
+ val errorOccurred: Boolean,
+)
+
interface TestSuiteParser {
/**
- * Extracts test cases from raw text and generates a test suite using the given package name.
+ * Extracts test cases from raw text and generates a test suite.
*
* @param rawText The raw text provided by the LLM that contains the generated test cases.
* @return A GeneratedTestSuite instance containing the extracted test cases.
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt
index a761e53b2..0d9c672de 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt
@@ -32,10 +32,9 @@ abstract class TestsAssembler {
}
/**
- * Extracts test cases from raw text and generates a TestSuite using the given package name.
+ * Extracts test cases from raw text and generates a TestSuite.
*
- * @param packageName The package name to be set in the generated TestSuite.
- * @return A TestSuiteGeneratedByLLM object containing the extracted test cases and package name.
+ * @return A TestSuiteGeneratedByLLM object containing information about the extracted test cases.
*/
- abstract fun assembleTestSuite(packageName: String): TestSuiteGeneratedByLLM?
+ abstract fun assembleTestSuite(): TestSuiteGeneratedByLLM?
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt
index 1673fea4a..b9d50132c 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt
@@ -4,6 +4,7 @@ package org.jetbrains.research.testspark.core.test
* The TestPersistentStorage interface represents a contract for saving generated tests to a specified file system location.
*/
interface TestsPersistentStorage {
+
/**
* Save the generated tests to a specified directory.
*
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt
similarity index 73%
rename from src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt
rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt
index 8e91aded4..12f18eb54 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt
@@ -1,4 +1,4 @@
-package org.jetbrains.research.testspark.data
+package org.jetbrains.research.testspark.core.test.data
/**
* Enum class, which contains all code elements for which it is possible to request test generation.
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt
index 6ef9f6907..2a565e82e 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt
@@ -1,5 +1,7 @@
package org.jetbrains.research.testspark.core.test.data
+import org.jetbrains.research.testspark.core.test.TestBodyPrinter
+
/**
*
* Represents a test case generated by LLM.
@@ -11,6 +13,7 @@ data class TestCaseGeneratedByLLM(
var expectedException: String = "",
var throwsException: String = "",
var lines: MutableList = mutableListOf(),
+ val printTestBodyStrategy: TestBodyPrinter,
) {
/**
@@ -104,31 +107,7 @@ data class TestCaseGeneratedByLLM(
* @return a string containing the body of test case
*/
private fun printTestBody(testInitiatedText: String): String {
- var testFullText = testInitiatedText
-
- // start writing the test signature
- testFullText += "\n\tpublic void $name() "
-
- // add throws exception if exists
- if (throwsException.isNotBlank()) {
- testFullText += "throws $throwsException"
- }
-
- // start writing the test lines
- testFullText += "{\n"
-
- // write each line
- lines.forEach { line ->
- testFullText += when (line.type) {
- TestLineType.BREAK -> "\t\t\n"
- else -> "\t\t${line.text}\n"
- }
- }
-
- // close test case
- testFullText += "\t}\n"
-
- return testFullText
+ return printTestBodyStrategy.printTestBody(testInitiatedText, lines, throwsException, name)
}
/**
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt
index 211063bb7..525fb9afc 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt
@@ -4,12 +4,12 @@ package org.jetbrains.research.testspark.core.test.data
* Represents a test suite generated by LLM.
*
* @property imports The set of import statements in the test suite.
- * @property packageString The package string of the test suite.
+ * @property packageName The package name of the test suite.
* @property testCases The list of test cases in the test suite.
*/
data class TestSuiteGeneratedByLLM(
- var imports: Set = emptySet(),
- var packageString: String = "",
+ var imports: MutableSet = mutableSetOf(),
+ var packageName: String = "",
var runWith: String = "",
var otherInfo: String = "",
var testCases: MutableList = mutableListOf(),
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt
similarity index 85%
rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt
rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt
index 2e78b0b50..72baf689a 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt
@@ -6,7 +6,7 @@ import org.jetbrains.research.testspark.core.data.JarLibraryDescriptor
* The class represents a list of dependencies required for java test compilation.
* The libraries listed are used during test suite/test case compilation.
*/
-class JavaTestCompilationDependencies {
+class TestCompilationDependencies {
companion object {
fun getJarDescriptors() = listOf(
JarLibraryDescriptor(
@@ -25,6 +25,10 @@ class JavaTestCompilationDependencies {
"byte-buddy-agent-1.14.6.jar",
"https://repo1.maven.org/maven2/net/bytebuddy/byte-buddy-agent/1.14.6/byte-buddy-agent-1.14.6.jar",
),
+ JarLibraryDescriptor(
+ "opentest4j-1.1.1.jar",
+ "https://repo1.maven.org/maven2/org/opentest4j/opentest4j/1.1.1/",
+ ),
)
}
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt
new file mode 100644
index 000000000..279badc57
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt
@@ -0,0 +1,32 @@
+package org.jetbrains.research.testspark.core.test.java
+
+import org.jetbrains.research.testspark.core.data.JUnitVersion
+import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.core.test.TestBodyPrinter
+import org.jetbrains.research.testspark.core.test.TestSuiteParser
+import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy
+import org.jetbrains.research.testspark.core.utils.javaImportPattern
+
+class JavaJUnitTestSuiteParser(
+ private var packageName: String,
+ private val junitVersion: JUnitVersion,
+ private val testBodyPrinter: TestBodyPrinter,
+) : TestSuiteParser {
+ override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? {
+ val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Java)
+ if (packageInsideTestText.isNotBlank()) {
+ packageName = packageInsideTestText
+ }
+
+ return JUnitTestSuiteParserStrategy.parseJUnitTestSuite(
+ rawText,
+ junitVersion,
+ javaImportPattern,
+ packageName,
+ testNamePattern = "void",
+ testBodyPrinter,
+ )
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt
new file mode 100644
index 000000000..bafbcaf13
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt
@@ -0,0 +1,40 @@
+package org.jetbrains.research.testspark.core.test.java
+
+import org.jetbrains.research.testspark.core.test.TestBodyPrinter
+import org.jetbrains.research.testspark.core.test.data.TestLine
+import org.jetbrains.research.testspark.core.test.data.TestLineType
+
+class JavaTestBodyPrinter : TestBodyPrinter {
+ override fun printTestBody(
+ testInitiatedText: String,
+ lines: MutableList,
+ throwsException: String,
+ name: String,
+ ): String {
+ var testFullText = testInitiatedText
+
+ // start writing the test signature
+ testFullText += "\n\tpublic void $name() "
+
+ // add throws exception if exists
+ if (throwsException.isNotBlank()) {
+ testFullText += "throws $throwsException"
+ }
+
+ // start writing the test lines
+ testFullText += "{\n"
+
+ // write each line
+ lines.forEach { line ->
+ testFullText += when (line.type) {
+ TestLineType.BREAK -> "\t\t\n"
+ else -> "\t\t${line.text}\n"
+ }
+ }
+
+ // close test case
+ testFullText += "\t}\n"
+
+ return testFullText
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt
new file mode 100644
index 000000000..4486eac52
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt
@@ -0,0 +1,75 @@
+package org.jetbrains.research.testspark.core.test.java
+
+import io.github.oshai.kotlinlogging.KotlinLogging
+import org.jetbrains.research.testspark.core.exception.ClassFileNotFoundException
+import org.jetbrains.research.testspark.core.exception.JavaCompilerNotFoundException
+import org.jetbrains.research.testspark.core.test.ExecutionResult
+import org.jetbrains.research.testspark.core.test.TestCompiler
+import org.jetbrains.research.testspark.core.utils.CommandLineRunner
+import org.jetbrains.research.testspark.core.utils.DataFilesUtil
+import java.io.File
+
+class JavaTestCompiler(
+ libPaths: List,
+ junitLibPaths: List,
+ javaHomeDirectoryPath: String,
+) : TestCompiler(libPaths, junitLibPaths) {
+ private val logger = KotlinLogging.logger { this::class.java }
+ private val javac: String
+
+ // init block to find the javac compiler
+ init {
+ // find the proper javac
+ val javaCompiler = File(javaHomeDirectoryPath).walk()
+ .filter {
+ val isCompilerName = if (DataFilesUtil.isWindows()) {
+ it.name.equals("javac.exe")
+ } else {
+ it.name.equals("javac")
+ }
+ isCompilerName && it.isFile
+ }
+ .firstOrNull()
+
+ if (javaCompiler == null) {
+ val msg = "Cannot find Java compiler 'javac' at $javaHomeDirectoryPath"
+ logger.error { msg }
+ throw JavaCompilerNotFoundException("Ensure Java SDK is configured for the project. $msg.")
+ }
+ javac = javaCompiler.absolutePath
+ }
+
+ override fun compileCode(path: String, projectBuildPath: String, workingDir: String): ExecutionResult {
+ val classPaths = "\"${getClassPaths(projectBuildPath)}\""
+ // compile file
+ val executionResult = CommandLineRunner.run(
+ arrayListOf(
+ /**
+ * Filepath may contain spaces, so we need to wrap it in quotes.
+ */
+ "'$javac'",
+ "-cp",
+ classPaths,
+ path,
+ /**
+ * We don't have to provide -d option, since javac saves class files in the same place by default
+ */
+ ),
+ )
+ logger.info { "Exit code: '${executionResult.exitCode}'; Execution message: '${executionResult.executionMessage}'" }
+
+ val classFilePath = path.replace(".java", ".class")
+ if (!File(classFilePath).exists()) {
+ throw ClassFileNotFoundException("Expected class file at $classFilePath after the compilation of file $path, but it does not exist.")
+ }
+ return executionResult
+ }
+
+ override fun getClassPaths(buildPath: String): String {
+ var path = commonPath.plus(buildPath)
+
+ if (path.endsWith(separator)) path = path.removeSuffix(separator.toString())
+
+ return path
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt
new file mode 100644
index 000000000..18b164810
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt
@@ -0,0 +1,32 @@
+package org.jetbrains.research.testspark.core.test.kotlin
+
+import org.jetbrains.research.testspark.core.data.JUnitVersion
+import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.core.test.TestBodyPrinter
+import org.jetbrains.research.testspark.core.test.TestSuiteParser
+import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy
+import org.jetbrains.research.testspark.core.utils.kotlinImportPattern
+
+class KotlinJUnitTestSuiteParser(
+ private var packageName: String,
+ private val junitVersion: JUnitVersion,
+ private val testBodyPrinter: TestBodyPrinter,
+) : TestSuiteParser {
+ override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? {
+ val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Kotlin)
+ if (packageInsideTestText.isNotBlank()) {
+ packageName = packageInsideTestText
+ }
+
+ return JUnitTestSuiteParserStrategy.parseJUnitTestSuite(
+ rawText,
+ junitVersion,
+ kotlinImportPattern,
+ packageName,
+ testNamePattern = "fun",
+ testBodyPrinter,
+ )
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt
new file mode 100644
index 000000000..a1a9dc8df
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt
@@ -0,0 +1,40 @@
+package org.jetbrains.research.testspark.core.test.kotlin
+
+import org.jetbrains.research.testspark.core.test.TestBodyPrinter
+import org.jetbrains.research.testspark.core.test.data.TestLine
+import org.jetbrains.research.testspark.core.test.data.TestLineType
+
+class KotlinTestBodyPrinter : TestBodyPrinter {
+ override fun printTestBody(
+ testInitiatedText: String,
+ lines: MutableList,
+ throwsException: String,
+ name: String,
+ ): String {
+ var testFullText = testInitiatedText
+
+ // start writing the test signature
+ testFullText += "\n\tfun $name() "
+
+ // add throws exception if exists
+ if (throwsException.isNotBlank()) {
+ testFullText += "throws $throwsException"
+ }
+
+ // start writing the test lines
+ testFullText += "{\n"
+
+ // write each line
+ lines.forEach { line ->
+ testFullText += when (line.type) {
+ TestLineType.BREAK -> "\t\t\n"
+ else -> "\t\t${line.text}\n"
+ }
+ }
+
+ // close test case
+ testFullText += "\t}\n"
+
+ return testFullText
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt
new file mode 100644
index 000000000..e1487ebba
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt
@@ -0,0 +1,85 @@
+package org.jetbrains.research.testspark.core.test.kotlin
+
+import io.github.oshai.kotlinlogging.KotlinLogging
+import org.jetbrains.research.testspark.core.exception.ClassFileNotFoundException
+import org.jetbrains.research.testspark.core.exception.KotlinCompilerNotFoundException
+import org.jetbrains.research.testspark.core.test.ExecutionResult
+import org.jetbrains.research.testspark.core.test.TestCompiler
+import org.jetbrains.research.testspark.core.utils.CommandLineRunner
+import org.jetbrains.research.testspark.core.utils.DataFilesUtil
+import java.io.File
+
+class KotlinTestCompiler(
+ libPaths: List,
+ junitLibPaths: List,
+ kotlinSDKHomeDirectory: String,
+) : TestCompiler(libPaths, junitLibPaths) {
+ private val logger = KotlinLogging.logger { this::class.java }
+ private val kotlinc: String
+
+ // init block to find the kotlinc compiler
+ init {
+ // search for a proper kotlinc
+ val kotlinCompiler = File(kotlinSDKHomeDirectory).walk()
+ .filter {
+ /**
+ * Tested on Windows 10, IntelliJ IDEA Community Edition 2023.1.4 (2023.1.4.IC-231.9225.26)
+ *
+ * Windows' kotlinc requires `java` command to be present in ENV (e.g., present in PATH).
+ * Otherwise, it won't be able to execute itself.
+ *
+ * Missing `java` in PATH does not yield runtime error but is considered
+ * as failed compilation because `kotlinc` will complain about
+ * `java` command missing in PATH.
+ *
+ * TODO(vartiukhov): find a way to locate `java` on Windows
+ */
+ val isCompilerName = if (DataFilesUtil.isWindows()) {
+ it.name.equals("kotlinc")
+ } else {
+ it.name.equals("kotlinc")
+ }
+ isCompilerName && it.isFile
+ }.firstOrNull()
+
+ if (kotlinCompiler == null) {
+ val msg = "Cannot find Kotlin compiler 'kotlinc' at $kotlinSDKHomeDirectory"
+ logger.error { msg }
+ throw KotlinCompilerNotFoundException("Please make sure that the Kotlin plugin is installed and enabled. $msg.")
+ }
+
+ kotlinc = kotlinCompiler.absolutePath
+ }
+
+ override fun compileCode(path: String, projectBuildPath: String, workingDir: String): ExecutionResult {
+ logger.info { "[KotlinTestCompiler] Compiling ${path.substringAfterLast('/')}" }
+
+ val classPaths = "\"${getClassPaths(projectBuildPath)}\""
+ // Compile file
+ val executionResult = CommandLineRunner.run(
+ arrayListOf(
+ /**
+ * Filepath may contain spaces, so we need to wrap it in quotes.
+ */
+ "'$kotlinc'",
+ "-cp",
+ classPaths,
+ path,
+ /**
+ * Forcing kotlinc to save a classfile in the same place, as '.kt' file
+ */
+ "-d",
+ workingDir,
+ ),
+ )
+ logger.info { "Exit code: '${executionResult.exitCode}'; Execution message: '${executionResult.executionMessage}'" }
+
+ val classFilePath = path.removeSuffix(".kt") + ".class"
+ if (!File(classFilePath).exists()) {
+ throw ClassFileNotFoundException("Expected class file at $classFilePath after the compilation of file $path, but it does not exist.")
+ }
+ return executionResult
+ }
+
+ override fun getClassPaths(buildPath: String): String = commonPath.plus(buildPath)
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JUnitTestSuiteParser.kt
deleted file mode 100644
index 2186a61c7..000000000
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JUnitTestSuiteParser.kt
+++ /dev/null
@@ -1,174 +0,0 @@
-package org.jetbrains.research.testspark.core.test.parsers.java
-
-import org.jetbrains.research.testspark.core.data.JUnitVersion
-import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM
-import org.jetbrains.research.testspark.core.test.data.TestLine
-import org.jetbrains.research.testspark.core.test.data.TestLineType
-import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
-import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser
-import org.jetbrains.research.testspark.core.utils.importPattern
-
-class JUnitTestSuiteParser(
- private val packageName: String,
- private val junitVersion: JUnitVersion,
-) : TestSuiteParser {
- override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? {
- if (rawText.isBlank()) {
- return null
- }
-
- try {
- var rawCode = rawText
-
- if (rawText.contains("```")) {
- rawCode = rawText.split("```")[1]
- }
-
- // save imports
- val imports = importPattern.findAll(rawCode, 0)
- .map { it.groupValues[0] }
- .toSet()
-
- // save RunWith
- val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: ""
-
- val testSet: MutableList = rawCode.split("@Test").toMutableList()
-
- // save annotations and pre-set methods
- val otherInfo: String = run {
- val otherInfoList = testSet.removeAt(0).split("{").toMutableList()
- otherInfoList.removeFirst()
- val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n"
- otherInfo.ifBlank { "" }
- }
-
- // Save the main test cases
- val testCases: MutableList = mutableListOf()
- val testCaseParser = JUnitTestCaseParser()
-
- testSet.forEach ca@{
- val rawTest = "@Test$it"
-
- val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1)
- val result: TestCaseParseResult = testCaseParser.parse(rawTest, isLastTestCaseInTestSuite)
-
- if (result.errorOccurred) {
- println("WARNING: ${result.errorMessage}")
- return@ca
- }
-
- val currentTest = result.testCase!!
-
- // TODO: make logging work
- // log.info("New test case: $currentTest")
- println("New test case: $currentTest")
-
- testCases.add(currentTest)
- }
-
- val testSuite = TestSuiteGeneratedByLLM(
- imports = imports,
- packageString = packageName,
- runWith = runWith,
- otherInfo = otherInfo,
- testCases = testCases,
- )
-
- return testSuite
- } catch (e: Exception) {
- return null
- }
- }
-}
-
-private data class TestCaseParseResult(
- val testCase: TestCaseGeneratedByLLM?,
- val errorMessage: String,
- val errorOccurred: Boolean,
-)
-
-private class JUnitTestCaseParser {
- fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean): TestCaseParseResult {
- var expectedException = ""
- var throwsException = ""
- val testLines: MutableList = mutableListOf()
-
- // Get expected Exception
- if (rawTest.startsWith("@Test(expected =")) {
- expectedException = rawTest.split(")")[0].trim()
- }
-
- // Get unexpected exceptions
- /* Each test case should follow [public] void {...}
- Tests do not return anything so it is safe to consider that void always appears before test case name
- */
- val voidString = "void"
- if (!rawTest.contains(voidString)) {
- return TestCaseParseResult(
- testCase = null,
- errorMessage = "The raw Test does not contain $voidString:\n $rawTest",
- errorOccurred = true,
- )
- }
- val interestingPartOfSignature = rawTest.split(voidString)[1]
- .split("{")[0]
- .split("()")[1]
- .trim()
-
- if (interestingPartOfSignature.contains("throws")) {
- throwsException = interestingPartOfSignature.split("throws")[1].trim()
- }
-
- // Get test name
- val testName: String = rawTest.split(voidString)[1]
- .split("()")[0]
- .trim()
-
- // Get test body and remove opening bracket
- var testBody = rawTest.split("{").toMutableList().apply { removeFirst() }
- .joinToString("{").trim()
-
- // remove closing bracket
- val tempList = testBody.split("}").toMutableList()
- tempList.removeLast()
-
- if (isLastTestCaseInTestSuite) {
- // it is the last test, thus we should remove another closing bracket
- if (tempList.isNotEmpty()) {
- tempList.removeLast()
- } else {
- println("WARNING: the final test does not have the enclosing bracket:\n $testBody")
- }
- }
-
- testBody = tempList.joinToString("}")
-
- // Save each line
- val rawLines = testBody.split("\n").toMutableList()
- rawLines.forEach { rawLine ->
- val line = rawLine.trim()
-
- val type: TestLineType = when {
- line.startsWith("//") -> TestLineType.COMMENT
- line.isBlank() -> TestLineType.BREAK
- line.lowercase().startsWith("assert") -> TestLineType.ASSERTION
- else -> TestLineType.CODE
- }
-
- testLines.add(TestLine(type, line))
- }
-
- val currentTest = TestCaseGeneratedByLLM(
- name = testName,
- expectedException = expectedException,
- throwsException = throwsException,
- lines = testLines,
- )
-
- return TestCaseParseResult(
- testCase = currentTest,
- errorMessage = "",
- errorOccurred = false,
- )
- }
-}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt
new file mode 100644
index 000000000..3e3f6ac34
--- /dev/null
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt
@@ -0,0 +1,192 @@
+package org.jetbrains.research.testspark.core.test.strategies
+
+import org.jetbrains.research.testspark.core.data.JUnitVersion
+import org.jetbrains.research.testspark.core.test.TestBodyPrinter
+import org.jetbrains.research.testspark.core.test.TestCaseParseResult
+import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM
+import org.jetbrains.research.testspark.core.test.data.TestLine
+import org.jetbrains.research.testspark.core.test.data.TestLineType
+import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+
+class JUnitTestSuiteParserStrategy {
+ companion object {
+ fun parseJUnitTestSuite(
+ rawText: String,
+ junitVersion: JUnitVersion,
+ importPattern: Regex,
+ packageName: String,
+ testNamePattern: String,
+ printTestBodyStrategy: TestBodyPrinter,
+ ): TestSuiteGeneratedByLLM? {
+ if (rawText.isBlank()) {
+ return null
+ }
+
+ try {
+ val rawCode = if (rawText.contains("```")) rawText.split("```")[1] else rawText
+
+ // save imports
+ val imports = importPattern.findAll(rawCode)
+ .map { it.groupValues[0] }
+ .toMutableSet()
+
+ // save RunWith
+ val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: ""
+
+ val testSet: MutableList = rawCode.split("@Test").toMutableList()
+
+ // save annotations and pre-set methods
+ val otherInfo: String = run {
+ val otherInfoList = testSet.removeAt(0).split("{").toMutableList()
+ otherInfoList.removeFirst()
+ val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n"
+ otherInfo.ifBlank { "" }
+ }
+
+ // Save the main test cases
+ val testCases: MutableList = mutableListOf()
+ val testCaseParser = JUnitTestCaseParser()
+
+ testSet.forEach ca@{
+ val rawTest = "@Test$it"
+
+ val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1)
+ val result: TestCaseParseResult =
+ testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern, printTestBodyStrategy)
+
+ if (result.errorOccurred) {
+ println("WARNING: ${result.errorMessage}")
+ return@ca
+ }
+
+ val currentTest = result.testCase!!
+
+ // TODO: make logging work
+ // log.info("New test case: $currentTest")
+
+ testCases.add(currentTest)
+ }
+
+ val testSuite = TestSuiteGeneratedByLLM(
+ imports = imports,
+ packageName = packageName,
+ runWith = runWith,
+ otherInfo = otherInfo,
+ testCases = testCases,
+ )
+
+ return testSuite
+ } catch (e: Exception) {
+ return null
+ }
+ }
+ }
+
+ private class JUnitTestCaseParser {
+ fun parse(
+ rawTest: String,
+ isLastTestCaseInTestSuite: Boolean,
+ testNamePattern: String,
+ printTestBodyStrategy: TestBodyPrinter,
+ ): TestCaseParseResult {
+ var expectedException = ""
+ var throwsException = ""
+ val testLines: MutableList = mutableListOf()
+
+ // Get expected Exception
+ if (rawTest.startsWith("@Test(expected =")) {
+ expectedException = rawTest.split(")")[0].trim()
+ }
+
+ // Get unexpected exceptions
+ /* Each test case should follow fun {...}
+ Tests do not return anything so it is safe to consider that void always appears before test case name
+ */
+ if (!rawTest.contains(testNamePattern)) {
+ return TestCaseParseResult(
+ testCase = null,
+ errorMessage = "The raw Test does not contain $testNamePattern:\n $rawTest",
+ errorOccurred = true,
+ )
+ }
+
+ /**
+ * The test definition is split into two parts:
+ * [void|fun] () [throws ] { ... }
+ * The first part is the test definition prologue,
+ * the second part is the test definition epilogue.
+ *
+ * Therefore, as an epilogue, we have everything after [void|fun].
+ *
+ * `limit = 2` is used to avoid additional splitting in case if the test
+ * case name starts with the same word as the test definition prologue.
+ */
+ val testCaseEpilogue = rawTest.split(testNamePattern, limit = 2)[1]
+
+ /**
+ * Optional [throws ] part is extracted from the test definition epilogue.
+ */
+ val interestingPartOfSignature = testCaseEpilogue
+ .split("{")[0]
+ .split("()")[1]
+ .trim()
+
+ if (interestingPartOfSignature.contains("throws")) {
+ throwsException = interestingPartOfSignature.split("throws")[1].trim()
+ }
+
+ // Get test name
+ val testName: String = testCaseEpilogue
+ .split("()")[0]
+ .trim()
+
+ // Get test body and remove opening bracket
+ var testBody = rawTest.split("{").toMutableList().apply { removeFirst() }
+ .joinToString("{").trim()
+
+ // remove closing bracket
+ val tempList = testBody.split("}").toMutableList()
+ tempList.removeLast()
+
+ if (isLastTestCaseInTestSuite) {
+ // it is the last test, thus we should remove another closing bracket
+ if (tempList.isNotEmpty()) {
+ tempList.removeLast()
+ } else {
+ println("WARNING: the final test does not have the enclosing bracket:\n $testBody")
+ }
+ }
+
+ testBody = tempList.joinToString("}")
+
+ // Save each line
+ val rawLines = testBody.split("\n").toMutableList()
+ rawLines.forEach { rawLine ->
+ val line = rawLine.trim()
+
+ val type: TestLineType = when {
+ line.startsWith("//") -> TestLineType.COMMENT
+ line.isBlank() -> TestLineType.BREAK
+ line.lowercase().startsWith("assert") -> TestLineType.ASSERTION
+ else -> TestLineType.CODE
+ }
+
+ testLines.add(TestLine(type, line))
+ }
+
+ val currentTest = TestCaseGeneratedByLLM(
+ name = testName,
+ expectedException = expectedException,
+ throwsException = throwsException,
+ lines = testLines,
+ printTestBodyStrategy = printTestBodyStrategy,
+ )
+
+ return TestCaseParseResult(
+ testCase = currentTest,
+ errorMessage = "",
+ errorOccurred = false,
+ )
+ }
+ }
+}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/CommandLineRunner.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/CommandLineRunner.kt
index 97e870bae..a9d0343c9 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/CommandLineRunner.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/CommandLineRunner.kt
@@ -1,6 +1,7 @@
package org.jetbrains.research.testspark.core.utils
import io.github.oshai.kotlinlogging.KotlinLogging
+import org.jetbrains.research.testspark.core.test.ExecutionResult
import java.io.BufferedReader
import java.io.InputStreamReader
@@ -9,16 +10,16 @@ class CommandLineRunner {
protected val log = KotlinLogging.logger {}
/**
- * Executes a command line process and returns the output as a string.
+ * Executes a command line process
*
* @param cmd The command line arguments as an ArrayList of strings.
- * @return The output of the command line process as a string.
+ * @return A pair containing exit code and a string message containing stdout and stderr of the executed process.
*/
- fun run(cmd: ArrayList): String {
- var errorMessage = ""
+ fun run(cmd: ArrayList): ExecutionResult {
+ var executionMsg = ""
/**
- * Since Windows does not provide bash, use cmd or similar default command line interpreter
+ * Since Windows does not provide bash, use cmd or simila r default command line interpreter
*/
val process = if (DataFilesUtil.isWindows()) {
ProcessBuilder()
@@ -32,17 +33,16 @@ class CommandLineRunner {
.redirectErrorStream(true)
.start()
}
-
val reader = BufferedReader(InputStreamReader(process.inputStream))
+ val separator = System.lineSeparator()
var line: String?
while (reader.readLine().also { line = it } != null) {
- errorMessage += line
+ executionMsg += "$line$separator"
}
process.waitFor()
-
- return errorMessage
+ return ExecutionResult(process.exitValue(), executionMsg)
}
}
}
diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt
index 123610ac7..fb1da6841 100644
--- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt
+++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt
@@ -1,13 +1,33 @@
package org.jetbrains.research.testspark.core.utils
-val importPattern =
+val javaImportPattern =
Regex(
pattern = "^import\\s+(static\\s)?((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?;",
options = setOf(RegexOption.MULTILINE),
)
-val packagePattern =
+/**
+ * Parse all the possible Kotlin import patterns
+ *
+ * import org.mockito.Mockito.`when`
+ * import kotlin.math.cos
+ * import kotlin.math.*
+ * import kotlin.math.PI as piValue
+ */
+val kotlinImportPattern =
+ Regex(
+ pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?(`\\w*`)?",
+ options = setOf(RegexOption.MULTILINE),
+ )
+
+val javaPackagePattern =
Regex(
pattern = "^package\\s+((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?;",
options = setOf(RegexOption.MULTILINE),
)
+
+val kotlinPackagePattern =
+ Regex(
+ pattern = "^package\\s+((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?",
+ options = setOf(RegexOption.MULTILINE),
+ )
diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt
new file mode 100644
index 000000000..63fbd0abc
--- /dev/null
+++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt
@@ -0,0 +1,260 @@
+package org.jetbrains.research.testspark.core.test.parsers.kotlin
+
+import org.jetbrains.research.testspark.core.data.JUnitVersion
+import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser
+import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertNotNull
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+
+class KotlinJUnitTestSuiteParserTest {
+
+ @Test
+ fun testParseTestSuite() {
+ val text = """
+ ```kotlin
+ import org.junit.jupiter.api.Assertions.*
+ import org.junit.jupiter.api.Test
+ import org.mockito.Mockito.*
+ import org.mockito.kotlin.any
+ import org.mockito.kotlin.eq
+ import org.mockito.kotlin.mock
+ import org.test.Message as TestMessage
+
+ class MyClassTest {
+
+ @Test
+ fun compileTestCases_AllCompilableTest() {
+ // Arrange
+ val myClass = MyClass()
+ val generatedTestCasesPaths = listOf("path1", "path2")
+ val buildPath = "buildPath"
+ val testCase1 = TestCaseGeneratedByLLM()
+ val testCase2 = TestCaseGeneratedByLLM()
+ val testCases = mutableListOf(testCase1, testCase2)
+
+ val myClassSpy = spy(myClass)
+ doReturn(Pair(true, "")).`when`(myClassSpy).compileCode(any(), eq(buildPath))
+
+ // Act
+ val result = myClassSpy.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
+
+ // Assert
+ assertTrue(result.allTestCasesCompilable)
+ assertEquals(setOf(testCase1, testCase2), result.compilableTestCases)
+ }
+
+ @Test
+ fun compileTestCases_NoneCompilableTest() {
+ // Arrange
+ val myClass = MyClass()
+ val generatedTestCasesPaths = listOf("path1", "path2")
+ val buildPath = "buildPath"
+ val testCase1 = TestCaseGeneratedByLLM()
+ val testCase2 = TestCaseGeneratedByLLM()
+ val testCases = mutableListOf(testCase1, testCase2)
+
+ val myClassSpy = spy(myClass)
+ doReturn(Pair(false, "")).`when`(myClassSpy).compileCode(any(), eq(buildPath))
+
+ // Act
+ val result = myClassSpy.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
+
+ // Assert
+ assertFalse(result.allTestCasesCompilable)
+ assertTrue(result.compilableTestCases.isEmpty())
+ }
+
+ @Test
+ fun compileTestCases_SomeCompilableTest() {
+ // Arrange
+ val myClass = MyClass()
+ val generatedTestCasesPaths = listOf("path1", "path2")
+ val buildPath = "buildPath"
+ val testCase1 = TestCaseGeneratedByLLM()
+ val testCase2 = TestCaseGeneratedByLLM()
+ val testCases = mutableListOf(testCase1, testCase2)
+
+ val myClassSpy = spy(myClass)
+ doReturn(Pair(true, "")).`when`(myClassSpy).compileCode(eq("path1"), eq(buildPath))
+ doReturn(Pair(false, "")).`when`(myClassSpy).compileCode(eq("path2"), eq(buildPath))
+
+ // Act
+ val result = myClassSpy.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
+
+ // Assert
+ assertFalse(result.allTestCasesCompilable)
+ assertEquals(setOf(testCase1), result.compilableTestCases)
+ }
+
+ @Test
+ fun compileTestCases_EmptyTestCasesTest() {
+ // Arrange
+ val myClass = MyClass()
+ val generatedTestCasesPaths = emptyList()
+ val buildPath = "buildPath"
+ val testCases = mutableListOf()
+
+ // Act
+ val result = myClass.compileTestCases(generatedTestCasesPaths, buildPath, testCases)
+
+ // Assert
+ assertTrue(result.allTestCasesCompilable)
+ assertTrue(result.compilableTestCases.isEmpty())
+ }
+
+ @Test(expected = ArithmeticException::class, Exception::class)
+ fun compileTestCases_omg() {
+ val blackHole = 1 / 0
+ }
+ }
+ ```
+ """.trimIndent()
+
+ val testBodyPrinter = KotlinTestBodyPrinter()
+ val parser =
+ KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter)
+ val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text)
+ assertNotNull(testSuite)
+ assertTrue(testSuite!!.imports.contains("import org.mockito.Mockito.*"))
+ assertTrue(testSuite.imports.contains("import org.test.Message as TestMessage"))
+ assertTrue(testSuite.imports.contains("import org.mockito.kotlin.mock"))
+
+ val expectedTestCasesNames = listOf(
+ "compileTestCases_AllCompilableTest",
+ "compileTestCases_NoneCompilableTest",
+ "compileTestCases_SomeCompilableTest",
+ "compileTestCases_EmptyTestCasesTest",
+ "compileTestCases_omg",
+ )
+
+ testSuite.testCases.forEachIndexed { index, testCase ->
+ val expected = expectedTestCasesNames[index]
+ assertEquals(expected, testCase.name) { "${index + 1}st test case has incorrect name" }
+ }
+
+ assertTrue(testSuite.testCases[4].expectedException.isNotBlank())
+ }
+
+ @Test
+ fun testParseEmptyTestSuite() {
+ val text = """
+ ```kotlin
+ package com.example.testsuite
+
+ class EmptyTestClass {
+ }
+ ```
+ """.trimIndent()
+
+ val testBodyPrinter = KotlinTestBodyPrinter()
+ val parser =
+ KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter)
+ val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text)
+ assertNotNull(testSuite)
+ assertEquals(testSuite!!.packageName, "com.example.testsuite")
+ assertTrue(testSuite.testCases.isEmpty())
+ }
+
+ @Test
+ fun testParseSingleTestCase() {
+ val text = """
+ ```kotlin
+ import org.junit.jupiter.api.Test
+
+ class SingleTestCaseClass {
+ @Test
+ fun singleTestCase() {
+ // Test case implementation
+ }
+ }
+ ```
+ """.trimIndent()
+
+ val testBodyPrinter = KotlinTestBodyPrinter()
+ val parser =
+ KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter)
+ val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text)
+ assertNotNull(testSuite)
+ assertEquals(1, testSuite!!.testCases.size)
+ assertEquals("singleTestCase", testSuite.testCases[0].name)
+ }
+
+ @Test
+ fun testParseTwoTestCases() {
+ val text = """
+ ```kotlin
+ import org.junit.jupiter.api.Test
+
+ class TwoTestCasesClass {
+ @Test
+ fun firstTestCase() {
+ // Test case implementation
+ }
+
+ @Test
+ fun secondTestCase() {
+ // Test case implementation
+ }
+ }
+ ```
+ """.trimIndent()
+
+ val testBodyPrinter = KotlinTestBodyPrinter()
+ val parser =
+ KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter)
+ val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text)
+ assertNotNull(testSuite)
+ assertEquals(2, testSuite!!.testCases.size)
+ assertEquals("firstTestCase", testSuite.testCases[0].name)
+ assertEquals("secondTestCase", testSuite.testCases[1].name)
+ }
+
+ @Test
+ fun testParseTwoTestCasesWithDifferentPackage() {
+ val code1 = """
+ ```kotlin
+ package org.pkg1
+
+ import org.junit.jupiter.api.Test
+
+ class TestCasesClass1 {
+ @Test
+ fun firstTestCase() {
+ // Test case implementation
+ }
+ }
+ ```
+ """.trimIndent()
+
+ val code2 = """
+ ```kotlin
+ package org.pkg2
+
+ import org.junit.jupiter.api.Test
+
+ class 2TestCasesClass {
+ @Test
+ fun firstTestCase() {
+ // Test case implementation
+ }
+ }
+ ```
+ """.trimIndent()
+
+ val testBodyPrinter = KotlinTestBodyPrinter()
+ val parser = KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter)
+
+ // packageName will be set to 'org.pkg1'
+ val testSuite1 = parser.parseTestSuite(code1)
+
+ val testSuite2 = parser.parseTestSuite(code2)
+
+ assertNotNull(testSuite1)
+ assertNotNull(testSuite2)
+ assertEquals("org.pkg1", testSuite1!!.packageName)
+ assertEquals("org.pkg2", testSuite2!!.packageName)
+ }
+}
diff --git a/gradle.properties b/gradle.properties
index 421a31f06..5fa69f82d 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -4,28 +4,32 @@
pluginGroup = org.jetbrains.research.testspark
pluginName = TestSpark
# SemVer format -> https://semver.org
-pluginVersion = 0.2.1
+pluginVersion = 0.3.0
+
+evosuiteVersion = 1.0.5
# See https://plugins.jetbrains.com/docs/intellij/build-number-ranges.html
# for insight into build numbers and IntelliJ Platform versions.
-pluginSinceBuild = 241
-pluginUntilBuild = 241.*
+pluginSinceBuild = 242
+pluginUntilBuild = 242.*
# IntelliJ Platform Properties -> https://github.com/JetBrains/gradle-intellij-plugin#intellij-platform-properties
platformType = IC
-platformVersion = 2024.1
+platformVersion = 2024.2.3
# Plugin Dependencies -> https://plugins.jetbrains.com/docs/intellij/plugin-dependencies.html
# Example: platformPlugins = com.intellij.java, com.jetbrains.php:203.4449.22
-platformPlugins = com.intellij.java
+platformPlugins = com.intellij.java, org.jetbrains.kotlin, org.jetbrains.idea.maven, com.intellij.gradle
# Java language level used to compile sources and to generate the files for - Java 17 is required since 2023.1
javaVersion = 17
# Gradle Releases -> https://github.com/gradle/gradle/releases
-gradleVersion = 8.2.1
+gradleVersion = 8.10.2
# Opt-out flag for bundling Kotlin standard library.
# See https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library for details.
# suppress inspection "UnusedProperty"
kotlin.stdlib.default.dependency = false
+
+jvmToolchainVersion = 17
diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties
index 070cb702f..0d1842103 100644
--- a/gradle/wrapper/gradle-wrapper.properties
+++ b/gradle/wrapper/gradle-wrapper.properties
@@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
diff --git a/java/build.gradle.kts b/java/build.gradle.kts
new file mode 100644
index 000000000..57e63b50d
--- /dev/null
+++ b/java/build.gradle.kts
@@ -0,0 +1,40 @@
+plugins {
+ kotlin("jvm")
+ id("org.jetbrains.intellij.platform")
+}
+
+repositories {
+ mavenCentral()
+ intellijPlatform {
+ defaultRepositories()
+ }
+}
+
+dependencies {
+ intellijPlatform {
+ create(rootProject.properties["platformType"].toString(), rootProject.properties["platformVersion"].toString())
+ // Plugin Dependencies. Uses `platformPlugins` property from the gradle.properties file.
+ bundledPlugins(listOf("com.intellij.java"))
+
+ instrumentationTools()
+ }
+ implementation(kotlin("stdlib"))
+
+ implementation(project(":langwrappers")) // Interfaces that cover language-specific logic
+ implementation(project(":core"))
+}
+
+intellijPlatform {
+ pluginConfiguration {
+ rootProject.properties["platformVersion"]?.let { version = it.toString() }
+ }
+}
+
+tasks.named("verifyPlugin") { enabled = false }
+tasks.named("runIde") { enabled = false }
+tasks.named("prepareJarSearchableOptions") { enabled = false }
+tasks.named("publishPlugin") { enabled = false }
+
+kotlin {
+ jvmToolchain(rootProject.properties["jvmToolchainVersion"].toString().toInt())
+}
diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt
new file mode 100644
index 000000000..9f6a5a28c
--- /dev/null
+++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt
@@ -0,0 +1,97 @@
+package org.jetbrains.research.testspark.java
+
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.vfs.VirtualFile
+import com.intellij.psi.PsiAnonymousClass
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiFile
+import com.intellij.psi.PsiModifier
+import com.intellij.psi.search.GlobalSearchScope
+import com.intellij.psi.search.searches.ClassInheritorsSearch
+import com.intellij.psi.util.PsiTypesUtil
+import org.jetbrains.research.testspark.core.data.ClassType
+import org.jetbrains.research.testspark.core.utils.javaImportPattern
+import org.jetbrains.research.testspark.core.utils.javaPackagePattern
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper
+import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor
+
+class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper {
+ override val name: String get() = psiClass.name ?: ""
+
+ override val qualifiedName: String get() = psiClass.qualifiedName ?: ""
+
+ override val text: String get() = psiClass.text
+
+ override val methods: List get() = psiClass.methods.map { JavaPsiMethodWrapper(it) }
+
+ override val allMethods: List get() = psiClass.allMethods.map { JavaPsiMethodWrapper(it) }
+
+ override val constructorSignatures: List get() = psiClass.constructors.map { JavaPsiMethodWrapper.buildSignature(it) }
+
+ override val superClass: PsiClassWrapper? get() = psiClass.superClass?.let { JavaPsiClassWrapper(it) }
+
+ override val virtualFile: VirtualFile get() = psiClass.containingFile.virtualFile
+
+ override val containingFile: PsiFile get() = psiClass.containingFile
+
+ override val fullText: String
+ get() = JavaKotlinClassTextExtractor().extract(
+ psiClass.containingFile,
+ psiClass.text,
+ javaPackagePattern,
+ javaImportPattern,
+ )
+
+ override val classType: ClassType
+ get() {
+ if (psiClass.isInterface) {
+ return ClassType.INTERFACE
+ }
+ if (psiClass.hasModifierProperty(PsiModifier.ABSTRACT)) {
+ return ClassType.ABSTRACT_CLASS
+ }
+ return ClassType.CLASS
+ }
+
+ override val rBrace: Int? = psiClass.rBrace?.textRange?.startOffset
+
+ override fun searchSubclasses(project: Project): Collection {
+ val scope = GlobalSearchScope.projectScope(project)
+ val query = ClassInheritorsSearch.search(psiClass, scope, false)
+ return query.findAll().map { JavaPsiClassWrapper(it) }
+ }
+
+ override fun getInterestingPsiClassesWithQualifiedNames(
+ psiMethod: PsiMethodWrapper,
+ ): MutableSet {
+ val interestingMethods = mutableSetOf(psiMethod as JavaPsiMethodWrapper)
+ for (currentPsiMethod in allMethods) {
+ if ((currentPsiMethod as JavaPsiMethodWrapper).isConstructor) interestingMethods.add(currentPsiMethod)
+ }
+ val interestingPsiClasses = mutableSetOf(this)
+ interestingMethods.forEach { methodIt ->
+ methodIt.parameterList.parameters.forEach { paramIt ->
+ PsiTypesUtil.getPsiClass(paramIt.type)?.let { typeIt ->
+ JavaPsiClassWrapper(typeIt).let {
+ if (it.qualifiedName != "" && !it.qualifiedName.startsWith("java.")) {
+ interestingPsiClasses.add(it)
+ }
+ }
+ }
+ }
+ }
+
+ return interestingPsiClasses.toMutableSet()
+ }
+
+ /**
+ * Checks if the constraints on the selected class are satisfied, so that EvoSuite can generate tests for it.
+ * Namely, it is not an enum and not an anonymous inner class.
+ *
+ * @return true if the constraints are satisfied, false otherwise
+ */
+ fun isTestableClass(): Boolean {
+ return !psiClass.isEnum && psiClass !is PsiAnonymousClass
+ }
+}
diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt
new file mode 100644
index 000000000..8562c382a
--- /dev/null
+++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt
@@ -0,0 +1,221 @@
+package org.jetbrains.research.testspark.java
+
+import com.intellij.openapi.actionSystem.AnActionEvent
+import com.intellij.openapi.actionSystem.CommonDataKeys
+import com.intellij.openapi.diagnostic.Logger
+import com.intellij.openapi.editor.Caret
+import com.intellij.openapi.module.ModuleUtilCore
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.util.TextRange
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiDocumentManager
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiFile
+import com.intellij.psi.PsiJavaFile
+import com.intellij.psi.PsiMethod
+import com.intellij.psi.util.PsiTreeUtil
+import com.intellij.psi.util.PsiTypesUtil
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.core.test.data.CodeType
+import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiHelper
+import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper
+
+class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper {
+
+ override val language: SupportedLanguage get() = SupportedLanguage.Java
+
+ /**
+ * When dealing with Java PSI files, we expect that only classes and their methods are tested.
+ * Therefore, we expect a **class** to surround a cursor offset.
+ *
+ * This requirement ensures that the user is not trying
+ * to generate tests for a line of code outside the class scope.
+ *
+ * @param e `AnActionEvent` representing the current action event.
+ * @return `true` if the cursor is inside a class, `false` otherwise.
+ */
+ override fun availableForGeneration(e: AnActionEvent): Boolean =
+ getCurrentListOfCodeTypes(e).any { it.first == CodeType.CLASS }
+
+ private val log = Logger.getInstance(this::class.java)
+
+ override fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String {
+ val methodDescriptor = psiMethod.methodDescriptor
+ log.info("Method description: $methodDescriptor")
+ return methodDescriptor
+ }
+
+ override fun getSurroundingClass(caretOffset: Int): PsiClassWrapper? {
+ val classElements = PsiTreeUtil.findChildrenOfAnyType(psiFile, PsiClass::class.java)
+ for (cls in classElements) {
+ if (cls.containsOffset(caretOffset)) {
+ val javaClassWrapper = JavaPsiClassWrapper(cls)
+ if (javaClassWrapper.isTestableClass()) {
+ log.info("Surrounding class for caret in $caretOffset is ${javaClassWrapper.qualifiedName}")
+ return javaClassWrapper
+ }
+ }
+ }
+ log.info("No surrounding class for caret in $caretOffset")
+ return null
+ }
+
+ override fun getSurroundingMethod(caretOffset: Int): PsiMethodWrapper? {
+ val methodElements = PsiTreeUtil.findChildrenOfAnyType(psiFile, PsiMethod::class.java)
+ for (method in methodElements) {
+ if (method.body != null && method.containsOffset(caretOffset)) {
+ val surroundingClass =
+ PsiTreeUtil.getParentOfType(method, PsiClass::class.java) ?: continue
+ val surroundingClassWrapper = JavaPsiClassWrapper(surroundingClass)
+ if (surroundingClassWrapper.isTestableClass()) {
+ val javaMethod = JavaPsiMethodWrapper(method)
+ log.info("Surrounding method for caret in $caretOffset is ${javaMethod.methodDescriptor}")
+ return javaMethod
+ }
+ }
+ }
+ log.info("No surrounding method for caret in $caretOffset")
+ return null
+ }
+
+ override fun getSurroundingLineNumber(caretOffset: Int): Int? {
+ val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null
+
+ /**
+ * See `getLineNumber`'s documentation for details on the numbering.
+ * It returns an index of the line in the document, starting from 0.
+ *
+ * Therefore, we need to increase the result by one to get the line number.
+ */
+ val selectedLine = doc.getLineNumber(caretOffset)
+ val selectedLineText =
+ doc.getText(TextRange(doc.getLineStartOffset(selectedLine), doc.getLineEndOffset(selectedLine)))
+
+ if (selectedLineText.isBlank()) {
+ log.info("Line $selectedLine at caret $caretOffset is blank")
+ return null
+ }
+ log.info("Surrounding line at caret $caretOffset is $selectedLine")
+ // increase by one is necessary due to different start of numbering
+ return selectedLine + 1
+ }
+
+ override fun collectClassesToTest(
+ project: Project,
+ classesToTest: MutableList,
+ caretOffset: Int,
+ maxPolymorphismDepth: Int,
+ ) {
+ val cutPsiClass = getSurroundingClass(caretOffset)!!
+ var currentPsiClass = cutPsiClass
+ for (index in 0 until maxPolymorphismDepth) {
+ if (!classesToTest.contains(currentPsiClass)) {
+ classesToTest.add(currentPsiClass)
+ }
+
+ if (currentPsiClass.superClass == null ||
+ currentPsiClass.superClass!!.qualifiedName.startsWith("java.")
+ ) {
+ break
+ }
+ currentPsiClass = currentPsiClass.superClass!!
+ }
+ log.info("There are ${classesToTest.size} classes to test")
+ }
+
+ override fun getInterestingPsiClassesWithQualifiedNames(
+ project: Project,
+ classesToTest: List,
+ polyDepthReducing: Int,
+ maxInputParamsDepth: Int,
+ ): MutableSet {
+ val interestingPsiClasses: MutableSet = mutableSetOf()
+
+ var currentLevelClasses =
+ mutableListOf().apply { addAll(classesToTest) }
+
+ repeat(maxInputParamsDepth) {
+ val tempListOfClasses = mutableSetOf()
+
+ currentLevelClasses.forEach { classIt ->
+ classIt.methods.forEach { methodIt ->
+ (methodIt as JavaPsiMethodWrapper).parameterList.parameters.forEach { paramIt ->
+ PsiTypesUtil.getPsiClass(paramIt.type)?.let { typeIt ->
+ JavaPsiClassWrapper(typeIt).let {
+ if (!it.qualifiedName.startsWith("java.")) {
+ interestingPsiClasses.add(it)
+ }
+ }
+ }
+ }
+ }
+ }
+ currentLevelClasses = mutableListOf().apply { addAll(tempListOfClasses) }
+ interestingPsiClasses.addAll(tempListOfClasses)
+ }
+ log.info("There are ${interestingPsiClasses.size} interesting psi classes")
+ return interestingPsiClasses.toMutableSet()
+ }
+
+ override fun getInterestingPsiClassesWithQualifiedNames(
+ cut: PsiClassWrapper?,
+ psiMethod: PsiMethodWrapper,
+ ): MutableSet {
+ // The cut is always not null for Java, because all functions are always inside the class
+ val interestingPsiClasses = cut!!.getInterestingPsiClassesWithQualifiedNames(psiMethod)
+ log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}")
+ return interestingPsiClasses
+ }
+
+ override fun getCurrentListOfCodeTypes(e: AnActionEvent): List {
+ val result: ArrayList = arrayListOf()
+ val caret: Caret =
+ e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result
+
+ val javaPsiClassWrapped = getSurroundingClass(caret.offset) as JavaPsiClassWrapper?
+ val javaPsiMethodWrapped = getSurroundingMethod(caret.offset) as JavaPsiMethodWrapper?
+ val line: Int? = getSurroundingLineNumber(caret.offset)
+
+ javaPsiClassWrapped?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) }
+ javaPsiMethodWrapped?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) }
+ line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) }
+
+ log.info(
+ "The test can be generated for: \n " +
+ " 1) Class ${javaPsiClassWrapped?.qualifiedName ?: "no class"} \n" +
+ " 2) Method ${javaPsiMethodWrapped?.name ?: "no method"} \n" +
+ " 3) Line $line",
+ )
+
+ return result
+ }
+
+ override fun getPackageName() = (psiFile as PsiJavaFile).packageName
+
+ override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!!
+
+ override fun getDocumentFromPsiFile() = psiFile.fileDocument
+
+ override fun getLineHTMLDisplayName(line: Int) = "line $line"
+
+ override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String =
+ "${psiClass.classType.representation} ${psiClass.qualifiedName}"
+
+ override fun getMethodHTMLDisplayName(psiMethod: PsiMethodWrapper): String {
+ return if ((psiMethod as JavaPsiMethodWrapper).isDefaultConstructor) {
+ "default constructor"
+ } else if (psiMethod.isConstructor) {
+ "constructor"
+ } else if (psiMethod.isMethodDefault) {
+ "default method ${psiMethod.name}"
+ } else {
+ "method ${psiMethod.name}"
+ }
+ }
+
+ private fun PsiElement.containsOffset(caretOffset: Int): Boolean {
+ return (textRange.startOffset <= caretOffset) && (textRange.endOffset >= caretOffset)
+ }
+}
diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelperProvider.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelperProvider.kt
new file mode 100644
index 000000000..ae97fcf90
--- /dev/null
+++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelperProvider.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.research.testspark.java
+
+import com.intellij.psi.PsiFile
+import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider
+
+class JavaPsiHelperProvider : PsiHelperProvider {
+ override fun getPsiHelper(file: PsiFile) = JavaPsiHelper(file)
+}
diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt
new file mode 100644
index 000000000..fd60cd488
--- /dev/null
+++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt
@@ -0,0 +1,130 @@
+package org.jetbrains.research.testspark.java
+
+import com.intellij.psi.PsiClassType
+import com.intellij.psi.PsiDocumentManager
+import com.intellij.psi.PsiFile
+import com.intellij.psi.PsiMethod
+import com.intellij.psi.PsiSubstitutor
+import com.intellij.psi.PsiType
+import com.intellij.util.containers.stream
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper
+import java.util.stream.Collectors
+
+class JavaPsiMethodWrapper(private val psiMethod: PsiMethod) : PsiMethodWrapper {
+ override val name: String get() = psiMethod.name
+
+ override val text: String? = psiMethod.text
+
+ override val containingClass: PsiClassWrapper? = psiMethod.containingClass?.let { JavaPsiClassWrapper(it) }
+
+ override val containingFile: PsiFile = psiMethod.containingFile
+
+ override val methodDescriptor: String
+ get() {
+ val parameterTypes =
+ psiMethod.getSignature(PsiSubstitutor.EMPTY)
+ .parameterTypes
+ .stream()
+ .map { i -> generateFieldType(i) }
+ .collect(Collectors.joining())
+
+ val returnType = generateReturnDescriptor(psiMethod)
+
+ return "${psiMethod.name}($parameterTypes)$returnType"
+ }
+
+ override val signature: String
+ get() = buildSignature(psiMethod)
+
+ val parameterList = psiMethod.parameterList
+
+ val isConstructor: Boolean = psiMethod.isConstructor
+
+ val isMethodDefault: Boolean
+ get() {
+ if (psiMethod.body == null) return false
+ return psiMethod.containingClass?.isInterface ?: return false
+ }
+
+ val isDefaultConstructor: Boolean get() = psiMethod.isConstructor && (psiMethod.body?.isEmpty ?: false)
+
+ override fun containsLine(lineNumber: Int): Boolean {
+ val psiFile = psiMethod.containingFile ?: return false
+ val document = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return false
+ val textRange = psiMethod.textRange
+ val startLine = document.getLineNumber(textRange.startOffset) + 1
+ val endLine = document.getLineNumber(textRange.endOffset) + 1
+ return lineNumber in startLine..endLine
+ }
+
+ /**
+ * Generates the return descriptor for a method.
+ *
+ * @param psiMethod the method
+ * @return the return descriptor
+ */
+ private fun generateReturnDescriptor(psiMethod: PsiMethod): String {
+ if (psiMethod.returnType == null || psiMethod.returnType!!.canonicalText == "void") {
+ // void method
+ return "V"
+ }
+
+ return generateFieldType(psiMethod.returnType!!)
+ }
+
+ /**
+ * Generates the field descriptor for a type.
+ * https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.3
+ *
+ * @param psiType the type to generate the descriptor for
+ * @return the field descriptor
+ */
+ private fun generateFieldType(psiType: PsiType): String {
+ // arrays (ArrayType)
+ if (psiType.arrayDimensions > 0) {
+ val arrayType = generateFieldType(psiType.deepComponentType)
+ return "[".repeat(psiType.arrayDimensions) + arrayType
+ }
+
+ // objects (ObjectType)
+ if (psiType is PsiClassType) {
+ val classType = psiType.resolve()
+ if (classType != null) {
+ val className = classType.qualifiedName?.replace('.', '/')
+
+ // no need to handle generics: they are not part of method descriptors
+
+ return "L$className;"
+ }
+ }
+
+ // primitives (BaseType)
+ psiType.canonicalText.let {
+ return when (it) {
+ "int" -> "I"
+ "long" -> "J"
+ "float" -> "F"
+ "double" -> "D"
+ "boolean" -> "Z"
+ "byte" -> "B"
+ "char" -> "C"
+ "short" -> "S"
+ else -> throw IllegalArgumentException("Unknown type: $it")
+ }
+ }
+ }
+
+ companion object {
+ /**
+ * Builds a signature for a given `PsiMethod`.
+ *
+ * @param method the PsiMethod for which to build the signature
+ * @return the method signature with the text before the method body, excluding newline characters
+ */
+ fun buildSignature(method: PsiMethod): String {
+ val bodyStart = method.body?.startOffsetInParent ?: method.textLength
+ return method.text.substring(0, bodyStart).replace("\\n", "").trim()
+ }
+ }
+}
diff --git a/java/src/main/resources/META-INF/testgenie-java.xml b/java/src/main/resources/META-INF/testgenie-java.xml
new file mode 100644
index 000000000..68bfa7c85
--- /dev/null
+++ b/java/src/main/resources/META-INF/testgenie-java.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
diff --git a/kotlin/build.gradle.kts b/kotlin/build.gradle.kts
new file mode 100644
index 000000000..2faafc052
--- /dev/null
+++ b/kotlin/build.gradle.kts
@@ -0,0 +1,47 @@
+plugins {
+ kotlin("jvm")
+ id("org.jetbrains.intellij.platform")
+}
+
+repositories {
+ mavenCentral()
+ intellijPlatform {
+ defaultRepositories()
+ }
+}
+
+dependencies {
+
+ intellijPlatform {
+ create(rootProject.properties["platformType"].toString(), rootProject.properties["platformVersion"].toString())
+ // Plugin Dependencies. Uses `platformPlugins` property from the gradle.properties file.
+ bundledPlugins(listOf("com.intellij.java", "org.jetbrains.kotlin"))
+
+ instrumentationTools()
+ }
+ implementation(kotlin("stdlib"))
+
+ implementation(project(":langwrappers")) // Interfaces that cover language-specific logic
+ implementation(project(":core"))
+}
+
+intellijPlatform {
+ pluginConfiguration {
+ rootProject.properties["platformVersion"]?.let { version = it.toString() }
+ }
+}
+
+tasks.named("verifyPlugin") { enabled = false }
+tasks.named("runIde") { enabled = false }
+tasks.named("prepareJarSearchableOptions") { enabled = false }
+tasks.named("publishPlugin") { enabled = false }
+
+tasks {
+ buildSearchableOptions {
+ enabled = false
+ }
+}
+
+kotlin {
+ jvmToolchain(rootProject.properties["jvmToolchainVersion"].toString().toInt())
+}
diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt
new file mode 100644
index 000000000..578c22729
--- /dev/null
+++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt
@@ -0,0 +1,126 @@
+package org.jetbrains.research.testspark.kotlin
+
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.vfs.VirtualFile
+import com.intellij.psi.PsiFile
+import com.intellij.psi.search.GlobalSearchScope
+import com.intellij.psi.search.searches.ClassInheritorsSearch
+import com.intellij.psi.util.PsiTreeUtil
+import org.jetbrains.kotlin.asJava.classes.KtUltraLightClass
+import org.jetbrains.kotlin.asJava.toLightClass
+import org.jetbrains.kotlin.idea.base.psi.kotlinFqName
+import org.jetbrains.kotlin.idea.caches.resolve.analyze
+import org.jetbrains.kotlin.idea.refactoring.isInterfaceClass
+import org.jetbrains.kotlin.idea.testIntegration.framework.KotlinPsiBasedTestFramework.Companion.asKtClassOrObject
+import org.jetbrains.kotlin.lexer.KtTokens
+import org.jetbrains.kotlin.psi.KtClass
+import org.jetbrains.kotlin.psi.KtClassOrObject
+import org.jetbrains.kotlin.psi.KtObjectDeclaration
+import org.jetbrains.kotlin.psi.allConstructors
+import org.jetbrains.kotlin.resolve.BindingContext
+import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils
+import org.jetbrains.research.testspark.core.data.ClassType
+import org.jetbrains.research.testspark.core.utils.kotlinImportPattern
+import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper
+import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor
+
+class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWrapper {
+ override val name: String get() = psiClass.name ?: ""
+
+ override val qualifiedName: String get() = psiClass.fqName!!.asString()
+
+ override val text: String? get() = psiClass.text
+
+ override val methods: List
+ get() = psiClass.body?.functions?.filter { it.name != null }?.map { KotlinPsiMethodWrapper(it) } ?: emptyList()
+
+ override val allMethods: List get() = methods
+
+ override val constructorSignatures: List get() = psiClass.allConstructors.map { KotlinPsiMethodWrapper.buildSignature(it) }
+
+ override val superClass: PsiClassWrapper?
+ get() {
+ // Get the superTypeListEntries of the Kotlin class
+ val superTypeListEntries = psiClass.superTypeListEntries
+ // Find the superclass entry (if any)
+ val superClassEntry = superTypeListEntries.firstOrNull()
+ // Resolve the superclass type reference to a PsiClass
+ val superClassTypeReference = superClassEntry?.typeReference
+ val superClassDescriptor = superClassTypeReference?.let {
+ val bindingContext = it.analyze()
+ bindingContext[BindingContext.TYPE, it]
+ }
+ val superClassPsiClass = superClassDescriptor?.constructor?.declarationDescriptor?.let { descriptor ->
+ DescriptorToSourceUtils.getSourceFromDescriptor(descriptor) as? KtClass
+ }
+ return if (psiClass.fqName != null) {
+ superClassPsiClass?.let { KotlinPsiClassWrapper(it) }
+ } else {
+ null
+ }
+ }
+
+ override val virtualFile: VirtualFile get() = psiClass.containingFile.virtualFile
+
+ override val containingFile: PsiFile get() = psiClass.containingFile
+
+ override val fullText: String
+ get() = JavaKotlinClassTextExtractor().extract(
+ psiClass.containingFile,
+ psiClass.text,
+ kotlinPackagePattern,
+ kotlinImportPattern,
+ )
+
+ override val classType: ClassType
+ get() {
+ return when {
+ psiClass is KtObjectDeclaration -> ClassType.OBJECT
+ psiClass.isInterfaceClass() -> ClassType.INTERFACE
+ psiClass.hasModifier(KtTokens.ABSTRACT_KEYWORD) -> ClassType.ABSTRACT_CLASS
+ psiClass.isData() -> ClassType.DATA_CLASS
+ psiClass.annotationEntries.any { it.text == "@JvmInline" } -> ClassType.INLINE_VALUE_CLASS
+ else -> ClassType.CLASS
+ }
+ }
+
+ override val rBrace: Int? = psiClass.body?.rBrace?.textRange?.startOffset
+
+ override fun searchSubclasses(project: Project): Collection {
+ val scope = GlobalSearchScope.projectScope(project)
+ val lightClass = psiClass.toLightClass()
+ return if (lightClass != null) {
+ val query = ClassInheritorsSearch.search(lightClass, scope, false)
+ query.findAll().filter { it.kotlinFqName != null }.map {
+ // If the sub-class is fetched as an ultra light class, get the KtClass
+ if (it is KtUltraLightClass) {
+ KotlinPsiClassWrapper(it.asKtClassOrObject() as KtClass)
+ } else {
+ KotlinPsiClassWrapper(it as KtClass)
+ }
+ }
+ } else {
+ emptyList()
+ }
+ }
+
+ override fun getInterestingPsiClassesWithQualifiedNames(
+ psiMethod: PsiMethodWrapper,
+ ): MutableSet {
+ val interestingPsiClasses = mutableSetOf()
+ val method = psiMethod as KotlinPsiMethodWrapper
+
+ method.psiFunction.valueParameters.forEach { parameter ->
+ val typeReference = parameter.typeReference
+ val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java)
+ if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) {
+ interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass))
+ }
+ }
+
+ interestingPsiClasses.add(this)
+ return interestingPsiClasses
+ }
+}
diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt
new file mode 100644
index 000000000..760568909
--- /dev/null
+++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt
@@ -0,0 +1,210 @@
+package org.jetbrains.research.testspark.kotlin
+
+import com.intellij.openapi.actionSystem.AnActionEvent
+import com.intellij.openapi.actionSystem.CommonDataKeys
+import com.intellij.openapi.diagnostic.Logger
+import com.intellij.openapi.editor.Caret
+import com.intellij.openapi.module.ModuleUtilCore
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.util.TextRange
+import com.intellij.psi.PsiDocumentManager
+import com.intellij.psi.PsiFile
+import com.intellij.psi.util.parentOfType
+import org.jetbrains.kotlin.psi.KtClassOrObject
+import org.jetbrains.kotlin.psi.KtFile
+import org.jetbrains.kotlin.psi.KtFunction
+import org.jetbrains.kotlin.psi.KtPsiUtil
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.core.test.data.CodeType
+import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiHelper
+import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper
+
+class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper {
+
+ override val language: SupportedLanguage get() = SupportedLanguage.Kotlin
+
+ /**
+ * When dealing with Kotlin PSI files, we expect that only classes, their methods,
+ * top-level functions are tested.
+ * Therefore, we expect either a class or a method (top-level function) to surround a cursor offset.
+ *
+ * This requirement ensures that the user is not trying
+ * to generate tests for a line of code outside the aforementioned scopes.
+ *
+ * @param e `AnActionEvent` representing the current action event.
+ * @return `true` if the cursor is inside a class or method, `false` otherwise.
+ */
+ override fun availableForGeneration(e: AnActionEvent): Boolean =
+ getCurrentListOfCodeTypes(e).any { (it.first == CodeType.CLASS) || (it.first == CodeType.METHOD) }
+
+ private val log = Logger.getInstance(this::class.java)
+
+ override fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String {
+ val methodDescriptor = psiMethod.methodDescriptor
+ log.info("Method description: $methodDescriptor")
+ return methodDescriptor
+ }
+
+ override fun getSurroundingClass(caretOffset: Int): PsiClassWrapper? {
+ val element = psiFile.findElementAt(caretOffset)
+ val cls = element?.parentOfType(withSelf = true)
+
+ if (cls != null && cls.name != null && cls.fqName != null) {
+ val kotlinClassWrapper = KotlinPsiClassWrapper(cls)
+ log.info("Surrounding class for caret in $caretOffset is ${kotlinClassWrapper.qualifiedName}")
+ return kotlinClassWrapper
+ }
+
+ log.info("No surrounding class for caret in $caretOffset")
+ return null
+ }
+
+ override fun getSurroundingMethod(caretOffset: Int): PsiMethodWrapper? {
+ val element = psiFile.findElementAt(caretOffset)
+ val method = element?.parentOfType(withSelf = true)
+
+ if (method != null && method.name != null) {
+ val wrappedMethod = KotlinPsiMethodWrapper(method)
+ log.info("Surrounding method for caret at $caretOffset is ${wrappedMethod.methodDescriptor}")
+ return wrappedMethod
+ }
+
+ log.info("No surrounding method for caret at $caretOffset")
+ return null
+ }
+
+ override fun getSurroundingLineNumber(caretOffset: Int): Int? {
+ val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null
+
+ /**
+ * See `getLineNumber`'s documentation for details on the numbering.
+ * It returns an index of the line in the document, starting from 0.
+ *
+ * Therefore, we need to increase the result by one to get the line number.
+ */
+ val selectedLine = doc.getLineNumber(caretOffset)
+ val selectedLineText =
+ doc.getText(TextRange(doc.getLineStartOffset(selectedLine), doc.getLineEndOffset(selectedLine)))
+
+ if (selectedLineText.isBlank()) {
+ log.info("Line $selectedLine at caret $caretOffset is not valid")
+ return null
+ }
+ log.info("Surrounding line at caret $caretOffset is $selectedLine")
+ return selectedLine + 1
+ }
+
+ override fun collectClassesToTest(
+ project: Project,
+ classesToTest: MutableList,
+ caretOffset: Int,
+ maxPolymorphismDepth: Int,
+ ) {
+ val cutPsiClass = getSurroundingClass(caretOffset) ?: return
+ // will be null for the top level function
+ var currentPsiClass = cutPsiClass
+ for (index in 0 until maxPolymorphismDepth) {
+ if (!classesToTest.contains(currentPsiClass)) {
+ classesToTest.add(currentPsiClass)
+ }
+
+ if (currentPsiClass.superClass == null ||
+ currentPsiClass.superClass!!.qualifiedName.startsWith("kotlin.")
+ ) {
+ break
+ }
+ currentPsiClass = currentPsiClass.superClass!!
+ }
+ log.info("There are ${classesToTest.size} classes to test")
+ }
+
+ override fun getInterestingPsiClassesWithQualifiedNames(
+ project: Project,
+ classesToTest: List,
+ polyDepthReducing: Int,
+ maxInputParamsDepth: Int,
+ ): MutableSet {
+ val interestingPsiClasses: MutableSet = mutableSetOf()
+
+ var currentLevelClasses = mutableListOf().apply { addAll(classesToTest) }
+
+ repeat(maxInputParamsDepth) {
+ val tempListOfClasses = mutableSetOf()
+ currentLevelClasses.forEach { classIt ->
+ classIt.methods.forEach { methodIt ->
+ (methodIt as KotlinPsiMethodWrapper).parameterList?.parameters?.forEach { paramIt ->
+ KtPsiUtil.getClassIfParameterIsProperty(paramIt)?.let { typeIt ->
+ KotlinPsiClassWrapper(typeIt).let {
+ if (!it.qualifiedName.startsWith("kotlin.")) {
+ interestingPsiClasses.add(it)
+ }
+ }
+ }
+ }
+ }
+ }
+ currentLevelClasses = mutableListOf().apply { addAll(tempListOfClasses) }
+ interestingPsiClasses.addAll(tempListOfClasses)
+ }
+ log.info("There are ${interestingPsiClasses.size} interesting psi classes")
+ return interestingPsiClasses.toMutableSet()
+ }
+
+ override fun getInterestingPsiClassesWithQualifiedNames(
+ cut: PsiClassWrapper?,
+ psiMethod: PsiMethodWrapper,
+ ): MutableSet {
+ val interestingPsiClasses =
+ cut?.getInterestingPsiClassesWithQualifiedNames(psiMethod)
+ ?: (psiMethod as KotlinPsiMethodWrapper).getInterestingPsiClassesWithQualifiedNames()
+ log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}")
+ return interestingPsiClasses
+ }
+
+ override fun getCurrentListOfCodeTypes(e: AnActionEvent): List {
+ val result: ArrayList = arrayListOf()
+ val caret: Caret =
+ e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result
+
+ val ktClass = getSurroundingClass(caret.offset)
+ val ktFunction = getSurroundingMethod(caret.offset)
+ val line: Int? = getSurroundingLineNumber(caret.offset)
+
+ ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) }
+ ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) }
+ line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) }
+
+ log.info(
+ "The test can be generated for: \n " +
+ " 1) Class ${ktClass?.qualifiedName ?: "no class"} \n" +
+ " 2) Method ${ktFunction?.name ?: "no method"} \n" +
+ " 3) Line $line",
+ )
+
+ return result
+ }
+
+ override fun getPackageName() = (psiFile as KtFile).packageFqName.asString()
+
+ override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!!
+
+ override fun getDocumentFromPsiFile() = psiFile.fileDocument
+
+ override fun getLineHTMLDisplayName(line: Int) = "line $line"
+
+ override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String =
+ "${psiClass.classType.representation} ${psiClass.qualifiedName}"
+
+ override fun getMethodHTMLDisplayName(psiMethod: PsiMethodWrapper): String {
+ psiMethod as KotlinPsiMethodWrapper
+ return when {
+ psiMethod.isTopLevelFunction -> "top-level function ${psiMethod.name}"
+ psiMethod.isSecondaryConstructor -> "secondary constructor"
+ psiMethod.isPrimaryConstructor -> "constructor"
+ psiMethod.isDefaultMethod -> "default method ${psiMethod.name}"
+ else -> "method ${psiMethod.name}"
+ }
+ }
+}
diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelperProvider.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelperProvider.kt
new file mode 100644
index 000000000..d9be382ad
--- /dev/null
+++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelperProvider.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.research.testspark.kotlin
+
+import com.intellij.psi.PsiFile
+import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider
+
+class KotlinPsiHelperProvider : PsiHelperProvider {
+ override fun getPsiHelper(file: PsiFile) = KotlinPsiHelper(file)
+}
diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt
new file mode 100644
index 000000000..93f39d6ba
--- /dev/null
+++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt
@@ -0,0 +1,144 @@
+package org.jetbrains.research.testspark.kotlin
+
+import com.intellij.psi.PsiDocumentManager
+import com.intellij.psi.PsiFile
+import com.intellij.psi.util.PsiTreeUtil
+import com.intellij.psi.util.parentOfType
+import org.jetbrains.kotlin.idea.refactoring.isInterfaceClass
+import org.jetbrains.kotlin.psi.KtClass
+import org.jetbrains.kotlin.psi.KtClassOrObject
+import org.jetbrains.kotlin.psi.KtFunction
+import org.jetbrains.kotlin.psi.KtPrimaryConstructor
+import org.jetbrains.kotlin.psi.KtSecondaryConstructor
+import org.jetbrains.kotlin.psi.KtTypeReference
+import org.jetbrains.kotlin.psi.psiUtil.containingClassOrObject
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper
+
+class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper {
+
+ override val name: String get() = psiFunction.name!!
+
+ override val text: String? = psiFunction.text
+
+ override val containingClass: PsiClassWrapper? = psiFunction.run {
+ parentOfType()?.let { KotlinPsiClassWrapper(it) }
+ }
+
+ override val containingFile: PsiFile = psiFunction.containingFile
+
+ override val methodDescriptor: String
+ get() = psiFunction.run {
+ val parameterTypes = valueParameters.joinToString("") { generateFieldType(it.typeReference) }
+ val returnType = generateReturnDescriptor(psiFunction)
+ return "$name($parameterTypes)$returnType"
+ }
+
+ override val signature: String
+ get() = buildSignature(psiFunction)
+
+ val parameterList = psiFunction.valueParameterList
+
+ val isPrimaryConstructor: Boolean = psiFunction is KtPrimaryConstructor
+
+ val isSecondaryConstructor: Boolean = psiFunction is KtSecondaryConstructor
+
+ val isTopLevelFunction: Boolean = psiFunction.containingClassOrObject == null
+
+ val isDefaultMethod: Boolean = psiFunction.run {
+ val containingClass = PsiTreeUtil.getParentOfType(this, KtClassOrObject::class.java)
+ val containingInterface = containingClass?.isInterfaceClass()
+ // ensure that the function is a non-abstract method defined in an interface
+ name != "" && // function is not a constructor
+ bodyExpression != null && // function has an implementation
+ containingInterface == true // function is defined within an interface
+ }
+
+ override fun containsLine(lineNumber: Int): Boolean {
+ val psiFile = psiFunction.containingFile
+ val document = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return false
+ val textRange = psiFunction.textRange
+ // increase by one is necessary due to different start of numbering
+ val startLine = document.getLineNumber(textRange.startOffset) + 1
+ // increase by one is necessary due to different start of numbering
+ val endLine = document.getLineNumber(textRange.endOffset) + 1
+ return lineNumber in startLine..endLine
+ }
+
+ /**
+ * Returns a set of `PsiClassWrapper` instances for non-standard Kotlin classes referenced by the
+ * parameters of the current function.
+ *
+ * @return A mutable set of `PsiClassWrapper` instances representing non-standard Kotlin classes.
+ */
+ fun getInterestingPsiClassesWithQualifiedNames(): MutableSet {
+ val interestingPsiClasses = mutableSetOf()
+
+ psiFunction.valueParameters.forEach { parameter ->
+ val typeReference = parameter.typeReference
+ val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java)
+ if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) {
+ interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass))
+ }
+ }
+
+ return interestingPsiClasses
+ }
+
+ /**
+ * Generates the return descriptor for a method.
+ *
+ * @param psiFunction the function
+ * @return the return descriptor
+ */
+ private fun generateReturnDescriptor(psiFunction: KtFunction): String {
+ val returnType = psiFunction.typeReference?.text ?: "Unit"
+ return generateFieldType(returnType)
+ }
+
+ /**
+ * Generates the field descriptor for a type.
+ *
+ * @param typeReference the type reference to generate the descriptor for
+ * @return the field descriptor
+ */
+ private fun generateFieldType(typeReference: KtTypeReference?): String {
+ val type = typeReference?.text ?: "Unit"
+ return generateFieldType(type)
+ }
+
+ /**
+ * Generates the field descriptor for a type.
+ * https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.3
+ *
+ * @param type the type to generate the descriptor for
+ * @return the field descriptor
+ */
+ private fun generateFieldType(type: String): String {
+ return when (type) {
+ "Int" -> "I"
+ "Long" -> "J"
+ "Float" -> "F"
+ "Double" -> "D"
+ "Boolean" -> "Z"
+ "Byte" -> "B"
+ "Char" -> "C"
+ "Short" -> "S"
+ "Unit" -> "V"
+ else -> "L${type.replace('.', '/')};"
+ }
+ }
+
+ companion object {
+ /**
+ * Builds a signature for a given Kotlin function by extracting the method body portion.
+ *
+ * @param function The Kotlin function to build the signature for.
+ * @return The signature of the function.
+ */
+ fun buildSignature(function: KtFunction) = function.run {
+ val bodyStart = bodyExpression?.startOffsetInParent ?: textLength
+ text.substring(0, bodyStart).replace('\n', ' ').trim()
+ }
+ }
+}
diff --git a/kotlin/src/main/resources/META-INF/testgenie-kotlin.xml b/kotlin/src/main/resources/META-INF/testgenie-kotlin.xml
new file mode 100644
index 000000000..6909a5b52
--- /dev/null
+++ b/kotlin/src/main/resources/META-INF/testgenie-kotlin.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/langwrappers/build.gradle.kts b/langwrappers/build.gradle.kts
new file mode 100644
index 000000000..16c6cebd1
--- /dev/null
+++ b/langwrappers/build.gradle.kts
@@ -0,0 +1,43 @@
+plugins {
+ kotlin("jvm")
+ id("org.jetbrains.intellij.platform")
+}
+
+repositories {
+ mavenCentral()
+ intellijPlatform {
+ defaultRepositories()
+ }
+}
+
+dependencies {
+
+ intellijPlatform {
+ create(rootProject.properties["platformType"].toString(), rootProject.properties["platformVersion"].toString())
+ // Plugin Dependencies. Uses `platformPlugins` property from the gradle.properties file.
+ bundledPlugins(listOf("com.intellij.java"))
+
+ instrumentationTools()
+ }
+ implementation(kotlin("stdlib"))
+
+ implementation(project(":core"))
+}
+
+intellijPlatform {
+ pluginConfiguration {
+ rootProject.properties["platformVersion"]?.let { version = it.toString() }
+ }
+// apply(plugin = "java")
+// // Apply more plugins if necessary
+// apply(plugin = "kotlin")
+}
+
+tasks.named("verifyPlugin") { enabled = false }
+tasks.named("runIde") { enabled = false }
+tasks.named("prepareJarSearchableOptions") { enabled = false }
+tasks.named("publishPlugin") { enabled = false }
+
+kotlin {
+ jvmToolchain(rootProject.properties["jvmToolchainVersion"].toString().toInt())
+}
diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt
new file mode 100644
index 000000000..0982b9ced
--- /dev/null
+++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt
@@ -0,0 +1,7 @@
+package org.jetbrains.research.testspark.langwrappers
+
+import com.intellij.psi.PsiFile
+
+interface LanguageClassTextExtractor {
+ fun extract(file: PsiFile, classText: String, packagePattern: Regex, importPattern: Regex): String
+}
diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt
new file mode 100644
index 000000000..3f7f1d0c8
--- /dev/null
+++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt
@@ -0,0 +1,233 @@
+package org.jetbrains.research.testspark.langwrappers
+
+import com.intellij.openapi.actionSystem.AnActionEvent
+import com.intellij.openapi.editor.Document
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.vfs.VirtualFile
+import com.intellij.psi.PsiFile
+import org.jetbrains.research.testspark.core.data.ClassType
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.core.test.data.CodeType
+
+typealias CodeTypeDisplayName = Pair
+
+/**
+ * Interface representing a wrapper for PSI methods,
+ * providing a common API to handle method-related data for different languages.
+ *
+ * @property name The name of a method
+ * @property methodDescriptor Human-readable method signature
+ * @property text The text of the function
+ * @property containingClass Class where the method is located
+ * @property containingFile File where the method is located
+ * */
+interface PsiMethodWrapper {
+ val name: String
+ val methodDescriptor: String
+ val signature: String
+ val text: String?
+ val containingClass: PsiClassWrapper?
+ val containingFile: PsiFile?
+
+ /**
+ * Checks if the given line number is within the range of the specified PsiMethod.
+ *
+ * @param lineNumber The line number to check.
+ * @return `true` if the line number is within the range of the method, `false` otherwise.
+ */
+ fun containsLine(lineNumber: Int): Boolean
+}
+
+/**
+ * Interface representing a wrapper for PSI classes,
+ * providing a common API to handle class-related data for different languages.
+ * @property name The name of a class
+ * @property qualifiedName The qualified name of the class.
+ * @property text The text of the class.
+ * @property methods All methods in the class
+ * @property allMethods All methods in the class and all its superclasses
+ * @property constructorSignatures The signatures of all constructors in the class
+ * @property superClass The superclass of the class
+ * @property virtualFile Virtual file where the class is located
+ * @property containingFile File where the method is located
+ * @property fullText The source code of the class (with package and imports).
+ * @property classType The type of the class
+ * @property rBrace The offset of the closing brace
+ * */
+interface PsiClassWrapper {
+ val name: String
+ val qualifiedName: String
+ val text: String?
+ val methods: List
+ val allMethods: List
+ val constructorSignatures: List
+ val superClass: PsiClassWrapper?
+ val virtualFile: VirtualFile
+ val containingFile: PsiFile
+ val fullText: String
+ val classType: ClassType
+ val rBrace: Int?
+
+ /**
+ * Searches for subclasses of the current class within the given project.
+ *
+ * @param project The project within which to search for subclasses.
+ * @return A collection of found subclasses.
+ */
+ fun searchSubclasses(project: Project): Collection
+
+ /**
+ * Retrieves a set of interesting PSI classes based on a given method.
+ *
+ * @param psiMethod The method to use for finding interesting PSI classes.
+ * @return A mutable set of interesting PSI classes.
+ */
+ fun getInterestingPsiClassesWithQualifiedNames(psiMethod: PsiMethodWrapper): MutableSet
+}
+
+/**
+ * Interface that declares all the methods needed for parsing and
+ * handling the PSI (Program Structure Interface) for different languages.
+ */
+interface PsiHelper {
+ val language: SupportedLanguage
+
+ /**
+ * Checks if a code construct is valid for unit test generation at the given caret offset.
+ *
+ * @param e The AnActionEvent representing the current action event.
+ * @return `true` if a code construct is valid for unit test generation at the caret offset, `false` otherwise.
+ */
+ fun availableForGeneration(e: AnActionEvent): Boolean
+
+ /**
+ * Returns the surrounding PsiClass object based on the caret position within the specified PsiFile.
+ * The surrounding class is determined by finding the PsiClass objects within the PsiFile and checking
+ * if the caret is within any of them.
+ *
+ * @param caretOffset The offset of the caret position within the PsiFile.
+ * @return The surrounding `PsiClass` object if found, `null` otherwise.
+ */
+ fun getSurroundingClass(caretOffset: Int): PsiClassWrapper?
+
+ /**
+ * Returns the surrounding method of the given PSI file based on the caret offset.
+ *
+ * @param caretOffset The caret offset within the PSI file.
+ * @return The surrounding method if found, otherwise null.
+ */
+ fun getSurroundingMethod(caretOffset: Int): PsiMethodWrapper?
+
+ /**
+ * Returns the line number of the selected line where the caret is positioned.
+ *
+ * The returned line number is **1-based**.
+ *
+ * @param caretOffset The caret offset within the PSI file.
+ * @return The line number of the selected line, otherwise null.
+ */
+ fun getSurroundingLineNumber(caretOffset: Int): Int?
+
+ /**
+ * Retrieves a set of interesting PsiClasses based on a given project,
+ * a list of classes to test, and a depth reducing factor.
+ *
+ * @param project The project in which to search for interesting classes.
+ * @param classesToTest The list of classes to test for interesting PsiClasses.
+ * @param polyDepthReducing The factor to reduce the polymorphism depth.
+ * @return The set of interesting PsiClasses found during the search.
+ */
+ fun getInterestingPsiClassesWithQualifiedNames(
+ project: Project,
+ classesToTest: List,
+ polyDepthReducing: Int,
+ maxInputParamsDepth: Int,
+ ): MutableSet
+
+ /**
+ * Returns a set of interesting PsiClasses based on the given PsiMethod.
+ *
+ * @param cut The class under test.
+ * @param psiMethod The PsiMethod for which to find interesting PsiClasses.
+ * @return A mutable set of interesting PsiClasses.
+ */
+ fun getInterestingPsiClassesWithQualifiedNames(
+ cut: PsiClassWrapper?,
+ psiMethod: PsiMethodWrapper,
+ ): MutableSet
+
+ /**
+ * Gets the current list of code types based on the given AnActionEvent.
+ *
+ * @param e The AnActionEvent representing the current action event.
+ * @return An array containing the current code types. If no caret or PSI file is found, an empty array is returned.
+ * The array contains the class display name, method display name (if present), and the line number (if present).
+ * The line number is prefixed with "Line".
+ */
+ fun getCurrentListOfCodeTypes(e: AnActionEvent): List
+
+ /**
+ * Helper for generating method descriptors for methods.
+ *
+ * @param psiMethod The method to extract the descriptor from.
+ * @return The method descriptor.
+ */
+ fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String
+
+ /**
+ * Fills the classesToTest variable with the data about the classes to test.
+ *
+ * @param project The project in which to collect classes to test.
+ * @param classesToTest The list of classes to test.
+ * @param caretOffset The caret offset in the file.
+ * @param maxPolymorphismDepth Check if cut has any user-defined superclass
+ */
+ fun collectClassesToTest(
+ project: Project,
+ classesToTest: MutableList,
+ caretOffset: Int,
+ maxPolymorphismDepth: Int,
+ )
+
+ /**
+ * Get the package name of the file.
+ */
+ fun getPackageName(): String
+
+ /**
+ * Get the module of the file.
+ */
+ fun getModuleFromPsiFile(): com.intellij.openapi.module.Module
+
+ /**
+ * Get the module of the file.
+ */
+ fun getDocumentFromPsiFile(): Document?
+
+ /**
+ * Gets the display line number.
+ * This is used when displaying the name of a method in the GenerateTestsActionMethod menu entry.
+ *
+ * @param line The line number.
+ * @return The display name of the line.
+ */
+ fun getLineHTMLDisplayName(line: Int): String
+
+ /**
+ * Gets the display name of a class.
+ * This is used when displaying the name of a class in the GenerateTestsActionClass menu entry.
+ *
+ * @param psiClass The PSI class of interest.
+ * @return The display name of the PSI class.
+ */
+ fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String
+
+ /**
+ * Gets the display name of a method, depending on if it is a (default) constructor or a normal method.
+ * This is used when displaying the name of a method in the GenerateTestsActionMethod menu entry.
+ *
+ * @param psiMethod The PSI method of interest.
+ * @return The display name of the PSI method.
+ */
+ fun getMethodHTMLDisplayName(psiMethod: PsiMethodWrapper): String
+}
diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiHelperProvider.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiHelperProvider.kt
new file mode 100644
index 000000000..5fa78eef2
--- /dev/null
+++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiHelperProvider.kt
@@ -0,0 +1,39 @@
+package org.jetbrains.research.testspark.langwrappers
+
+import com.intellij.lang.LanguageExtension
+import com.intellij.psi.PsiFile
+
+/**
+ * This is the provider interface for a PsiHelper. The PsiHelper allows for
+ * custom handling or manipulating PSI (Program Structure Interface) elements.
+ */
+interface PsiHelperProvider {
+
+ /**
+ * Get a PsiHelper for the given file.
+ *
+ * @param file the PsiFile to get the PsiHelper for.
+ * @return a PsiHelper object.
+ */
+ fun getPsiHelper(file: PsiFile): PsiHelper
+
+ companion object {
+ // An extension point that allows for custom PsiHelperProviders to be registered for different languages
+ private val EP = LanguageExtension("org.jetbrains.research.testgenie.psiHelperProvider")
+
+ /**
+ * Retrieves a PsiHelper for the given file based on its language.
+ *
+ * It attempts to get the PsiHelperProvider registered for the specified language.
+ * If none exists, the method will return null.
+ * Finally, it uses this PsiHelperProvider to get a PsiHelper for the file.
+ *
+ * @param file The PsiFile to get the PsiHelper for.
+ * @return The PsiHelper for the file or null if it couldn't be obtained.
+ */
+ fun getPsiHelper(file: PsiFile): PsiHelper? {
+ val language = file.language
+ return EP.forLanguage(language)?.getPsiHelper(file)
+ }
+ }
+}
diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt
new file mode 100644
index 000000000..643cdee34
--- /dev/null
+++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt
@@ -0,0 +1,39 @@
+package org.jetbrains.research.testspark.langwrappers.strategies
+
+import com.intellij.psi.PsiFile
+import org.jetbrains.research.testspark.langwrappers.LanguageClassTextExtractor
+
+/**
+Direct implementor for the Java and Kotlin PsiWrappers
+ */
+class JavaKotlinClassTextExtractor : LanguageClassTextExtractor {
+
+ override fun extract(
+ file: PsiFile,
+ classText: String,
+ packagePattern: Regex,
+ importPattern: Regex,
+ ): String {
+ var fullText = ""
+ val fileText = file.text
+
+ // get package
+ packagePattern.findAll(fileText, 0).map {
+ it.groupValues[0]
+ }.forEach {
+ fullText += "$it\n\n"
+ }
+
+ // get imports
+ importPattern.findAll(fileText, 0).map {
+ it.groupValues[0]
+ }.forEach {
+ fullText += "$it\n"
+ }
+
+ // Add class code
+ fullText += classText
+
+ return fullText
+ }
+}
diff --git a/lib/jacocoagent.jar b/lib/jacocoagent.jar
index 7799b3c02..e3c7d7d82 100644
Binary files a/lib/jacocoagent.jar and b/lib/jacocoagent.jar differ
diff --git a/lib/jacococli.jar b/lib/jacococli.jar
index 27a66d56e..4a1c11b5d 100644
Binary files a/lib/jacococli.jar and b/lib/jacococli.jar differ
diff --git a/lib/opentest4j-1.1.1.jar b/lib/opentest4j-1.1.1.jar
new file mode 100644
index 000000000..3f355292e
Binary files /dev/null and b/lib/opentest4j-1.1.1.jar differ
diff --git a/readme-images/pngs/k2-mode/disable-k2.png b/readme-images/pngs/k2-mode/disable-k2.png
new file mode 100644
index 000000000..96f1b0109
Binary files /dev/null and b/readme-images/pngs/k2-mode/disable-k2.png differ
diff --git a/runTestSparkHeadless.sh b/runTestSparkHeadless.sh
new file mode 100644
index 000000000..a5b64f474
--- /dev/null
+++ b/runTestSparkHeadless.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+# https://stackoverflow.com/a/246128
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)"
+if uname -s | grep -iq cygwin; then
+ DIR=$(cygpath -w "$DIR")
+ PWD=$(cygpath -w "$PWD")
+fi
+echo $DIR
+
+echo "Provided arguments are $@"
+
+if [ $# -ne "12" ]; then
+ echo "$# arguments provided, expected 12 arguments in the following order:
+ 1) Path to the root directory of the project under test (ProjectPath)
+ 2) Path to the target file (.java file) (it MUST be relative to the ProjectPath)
+ 3) Qualified name of the class under test (i.e., .)
+ 4) Classpaths containing the compiled project (separated by ':')
+ 5) Version of JUnit testing framework (either 4 or 5)
+ 6) Model name (e.g., GPT-4)
+ 7) Grazie token
+ 8) Filepath to a txt-file containing prompt template
+ 9) Output directory
+ 10) Enable/disable coverage computation ('true' or 'false')
+ 11) Space username
+ 12) Space password"
+ exit 1
+fi
+
+echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}"
+"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}"
diff --git a/settings.gradle.kts b/settings.gradle.kts
index a5261e06c..ca94f25ec 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -4,3 +4,6 @@ plugins {
rootProject.name = "TestSpark"
include("JUnitRunner")
include("core")
+include("langwrappers")
+include("kotlin")
+include("java")
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt
index b5e4ed008..6304b258d 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt
@@ -7,22 +7,29 @@ import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.actionSystem.CommonDataKeys
import com.intellij.openapi.project.Project
-import com.intellij.psi.PsiFile
import com.intellij.ui.components.JBScrollPane
import com.intellij.util.ui.FormBuilder
-import org.jetbrains.research.testspark.actions.evosuite.EvoSuitePanelFactory
-import org.jetbrains.research.testspark.actions.llm.LLMSampleSelectorFactory
-import org.jetbrains.research.testspark.actions.llm.LLMSetupPanelFactory
-import org.jetbrains.research.testspark.actions.template.PanelFactory
-import org.jetbrains.research.testspark.bundles.TestSparkBundle
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
+import org.jetbrains.research.testspark.actions.controllers.TestGenerationController
+import org.jetbrains.research.testspark.actions.controllers.VisibilityController
+import org.jetbrains.research.testspark.actions.evosuite.EvoSuitePanelBuilder
+import org.jetbrains.research.testspark.actions.llm.LLMSampleSelectorBuilder
+import org.jetbrains.research.testspark.actions.llm.LLMSetupPanelBuilder
+import org.jetbrains.research.testspark.actions.template.PanelBuilder
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
+import org.jetbrains.research.testspark.core.test.data.CodeType
+import org.jetbrains.research.testspark.display.TestSparkDisplayManager
import org.jetbrains.research.testspark.display.TestSparkIcons
-import org.jetbrains.research.testspark.helpers.getCurrentListOfCodeTypes
-import org.jetbrains.research.testspark.services.SettingsApplicationService
-import org.jetbrains.research.testspark.settings.SettingsApplicationState
-import org.jetbrains.research.testspark.tools.Manager
+import org.jetbrains.research.testspark.langwrappers.PsiHelper
+import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider
+import org.jetbrains.research.testspark.services.EvoSuiteSettingsService
+import org.jetbrains.research.testspark.services.LLMSettingsService
+import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState
+import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
+import org.jetbrains.research.testspark.tools.TestsExecutionResultManager
import org.jetbrains.research.testspark.tools.evosuite.EvoSuite
import org.jetbrains.research.testspark.tools.llm.Llm
+import org.jetbrains.research.testspark.tools.template.Tool
import java.awt.BorderLayout
import java.awt.CardLayout
import java.awt.Dimension
@@ -44,11 +51,12 @@ import javax.swing.JRadioButton
* It creates a dialog wrapper and displays it when the associated action is performed.
*/
class TestSparkAction : AnAction() {
- class VisibilityController {
- var isVisible = false
- }
-
+ // Controllers
private val visibilityController = VisibilityController()
+ private val testGenerationController = TestGenerationController()
+
+ private val testSparkDisplayManager = TestSparkDisplayManager()
+ private val testsExecutionResultManager = TestsExecutionResultManager()
/**
* Handles the action performed event.
@@ -60,16 +68,24 @@ class TestSparkAction : AnAction() {
* This parameter is required.
*/
override fun actionPerformed(e: AnActionEvent) {
- TestSparkActionWindow(e, visibilityController)
+ TestSparkActionWindow(e, visibilityController, testGenerationController, testSparkDisplayManager, testsExecutionResultManager)
}
/**
* Updates the state of the action based on the provided event.
*
- * @param e the AnActionEvent object representing the event
+ * @param e `AnActionEvent` object representing the event
*/
override fun update(e: AnActionEvent) {
- e.presentation.isEnabled = getCurrentListOfCodeTypes(e) != null
+ val file = e.dataContext.getData(CommonDataKeys.PSI_FILE)
+
+ if (file == null) {
+ e.presentation.isEnabledAndVisible = false
+ return
+ }
+
+ val psiHelper = PsiHelperProvider.getPsiHelper(file)
+ e.presentation.isEnabledAndVisible = (psiHelper != null) && psiHelper.availableForGeneration(e)
}
/**
@@ -77,29 +93,48 @@ class TestSparkAction : AnAction() {
*
* @property e The AnActionEvent object.
*/
- class TestSparkActionWindow(e: AnActionEvent, private val visibilityController: VisibilityController) :
+ class TestSparkActionWindow(
+ private val e: AnActionEvent,
+ private val visibilityController: VisibilityController,
+ private val testGenerationController: TestGenerationController,
+ private val testSparkDisplayManager: TestSparkDisplayManager,
+ private val testsExecutionResultManager: TestsExecutionResultManager,
+ ) :
JFrame("TestSpark") {
private val project: Project = e.project!!
- private val settingsState: SettingsApplicationState
- get() = project.getService(SettingsApplicationService::class.java).state
+
+ private val llmSettingsState: LLMSettingsState
+ get() = project.getService(LLMSettingsService::class.java).state
+ private val evoSuiteSettingsState: EvoSuiteSettingsState
+ get() = project.getService(EvoSuiteSettingsService::class.java).state
private val llmButton = JRadioButton("${Llm().name}")
private val evoSuiteButton = JRadioButton("${EvoSuite().name}")
private val testGeneratorButtonGroup = ButtonGroup()
- private val codeTypes = getCurrentListOfCodeTypes(e)!!
- private val psiFile: PsiFile = e.dataContext.getData(CommonDataKeys.PSI_FILE)!!
+
+ private val psiHelper: PsiHelper
+ get() {
+ val file = e.dataContext.getData(CommonDataKeys.PSI_FILE)!!
+ val psiHelper = PsiHelperProvider.getPsiHelper(file)
+ if (psiHelper == null) {
+ // TODO exception
+ }
+ return psiHelper!!
+ }
+
+ private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e)
private val caretOffset: Int = e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret!!.offset
private val fileUrl = e.dataContext.getData(CommonDataKeys.VIRTUAL_FILE)!!.presentableUrl
- private val codeTypeButtons: MutableList = mutableListOf()
+
+ private val codeTypeButtons: MutableList> = mutableListOf()
private val codeTypeButtonGroup = ButtonGroup()
- private val nextButton = JButton(TestSparkLabelsBundle.defaultValue("next"))
+ private val nextButton = JButton(PluginLabelsBundle.get("next"))
private val cardLayout = CardLayout()
-
- private val llmSetupPanelFactory = LLMSetupPanelFactory(e, project)
- private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project)
- private val evoSuitePanelFactory = EvoSuitePanelFactory(project)
+ private val llmSetupPanelFactory = LLMSetupPanelBuilder(e, project)
+ private val llmSampleSelectorFactory = LLMSampleSelectorBuilder(project, psiHelper.language)
+ private val evoSuitePanelFactory = EvoSuitePanelBuilder(project)
init {
if (!visibilityController.isVisible) {
@@ -135,19 +170,19 @@ class TestSparkAction : AnAction() {
NotificationGroupManager.getInstance()
.getNotificationGroup("Generation Error")
.createNotification(
- TestSparkBundle.message("generationWindowWarningTitle"),
- TestSparkBundle.message("generationWindowWarningMessage"),
+ PluginMessagesBundle.get("generationWindowWarningTitle"),
+ PluginMessagesBundle.get("generationWindowWarningMessage"),
NotificationType.WARNING,
)
.notify(e.project)
}
}
- private fun createCardPanel(toolPanelFactory: PanelFactory): JPanel {
+ private fun createCardPanel(toolPanelBuilder: PanelBuilder): JPanel {
val cardPanel = JPanel(BorderLayout())
- cardPanel.add(toolPanelFactory.getTitlePanel(), BorderLayout.NORTH)
- cardPanel.add(toolPanelFactory.getMiddlePanel(), BorderLayout.CENTER)
- cardPanel.add(toolPanelFactory.getBottomPanel(), BorderLayout.SOUTH)
+ cardPanel.add(toolPanelBuilder.getTitlePanel(), BorderLayout.NORTH)
+ cardPanel.add(toolPanelBuilder.getMiddlePanel(), BorderLayout.CENTER)
+ cardPanel.add(toolPanelBuilder.getBottomPanel(), BorderLayout.SOUTH)
return cardPanel
}
@@ -174,16 +209,19 @@ class TestSparkAction : AnAction() {
testGeneratorPanel.add(llmButton)
testGeneratorPanel.add(evoSuiteButton)
- for (codeType in codeTypes) {
- val button = JRadioButton(codeType as String)
- codeTypeButtons.add(button)
+ for ((codeType, codeTypeName) in codeTypes) {
+ val button = JRadioButton(codeTypeName)
+ codeTypeButtons.add(codeType to button)
codeTypeButtonGroup.add(button)
}
val codesToTestPanel = JPanel()
codesToTestPanel.add(JLabel("Select the code type:"))
- if (codeTypeButtons.size == 1) codeTypeButtons[0].isSelected = true
- for (button in codeTypeButtons) codesToTestPanel.add(button)
+ if (codeTypeButtons.size == 1) {
+ // A single button is selected by default
+ codeTypeButtons[0].second.isSelected = true
+ }
+ for ((_, button) in codeTypeButtons) codesToTestPanel.add(button)
val middlePanel = FormBuilder.createFormBuilder()
.setFormLeftIndent(10)
@@ -229,7 +267,7 @@ class TestSparkAction : AnAction() {
updateNextButton()
}
- for (button in codeTypeButtons) {
+ for ((_, button) in codeTypeButtons) {
button.addActionListener {
llmSetupPanelFactory.setPromptEditorType(button.text)
updateNextButton()
@@ -237,9 +275,9 @@ class TestSparkAction : AnAction() {
}
nextButton.addActionListener {
- if (llmButton.isSelected && !settingsState.llmSetupCheckBoxSelected && !settingsState.provideTestSamplesCheckBoxSelected) {
+ if (llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected && !llmSettingsState.provideTestSamplesCheckBoxSelected) {
startLLMGeneration()
- } else if (llmButton.isSelected && !settingsState.llmSetupCheckBoxSelected) {
+ } else if (llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected) {
cardLayout.next(panel)
cardLayout.next(panel)
cardLayout.next(panel)
@@ -248,7 +286,7 @@ class TestSparkAction : AnAction() {
cardLayout.next(panel)
cardLayout.next(panel)
pack()
- } else if (evoSuiteButton.isSelected && !settingsState.evosuiteSetupCheckBoxSelected) {
+ } else if (evoSuiteButton.isSelected && !evoSuiteSettingsState.evosuiteSetupCheckBoxSelected) {
startEvoSuiteGeneration()
} else {
cardLayout.next(panel)
@@ -269,7 +307,7 @@ class TestSparkAction : AnAction() {
llmSetupPanelFactory.getFinishedButton().addActionListener {
llmSetupPanelFactory.applyUpdates()
- if (settingsState.provideTestSamplesCheckBoxSelected) {
+ if (llmSettingsState.provideTestSamplesCheckBoxSelected) {
cardLayout.next(panel)
} else {
startLLMGeneration()
@@ -281,7 +319,7 @@ class TestSparkAction : AnAction() {
}
llmSampleSelectorFactory.getBackButton().addActionListener {
- if (settingsState.llmSetupCheckBoxSelected) {
+ if (llmSettingsState.llmSetupCheckBoxSelected) {
cardLayout.previous(panel)
} else {
cardLayout.previous(panel)
@@ -302,35 +340,55 @@ class TestSparkAction : AnAction() {
}
}
- private fun startEvoSuiteGeneration() {
- val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode()
-
- if (codeTypeButtons[0].isSelected) {
- Manager.generateTestsForClassByEvoSuite(project, psiFile, caretOffset, fileUrl, testSamplesCode)
- } else if (codeTypeButtons[1].isSelected) {
- Manager.generateTestsForMethodByEvoSuite(project, psiFile, caretOffset, fileUrl, testSamplesCode)
- } else if (codeTypeButtons[2].isSelected) {
- Manager.generateTestsForLineByEvoSuite(project, psiFile, caretOffset, fileUrl, testSamplesCode)
+ private fun startUnitTestGenerationTool(tool: Tool) {
+ if (!testGenerationController.isGeneratorRunning(project)) {
+ val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode()
+
+ for ((codeType, button) in codeTypeButtons) {
+ if (button.isSelected) {
+ when (codeType) {
+ CodeType.CLASS -> tool.generateTestsForClass(
+ project,
+ psiHelper,
+ caretOffset,
+ fileUrl,
+ testSamplesCode,
+ testGenerationController,
+ testSparkDisplayManager,
+ testsExecutionResultManager,
+ )
+ CodeType.METHOD -> tool.generateTestsForMethod(
+ project,
+ psiHelper,
+ caretOffset,
+ fileUrl,
+ testSamplesCode,
+ testGenerationController,
+ testSparkDisplayManager,
+ testsExecutionResultManager,
+ )
+ CodeType.LINE -> tool.generateTestsForLine(
+ project,
+ psiHelper,
+ caretOffset,
+ fileUrl,
+ testSamplesCode,
+ testGenerationController,
+ testSparkDisplayManager,
+ testsExecutionResultManager,
+ )
+ }
+ break
+ }
+ }
}
visibilityController.isVisible = false
dispose()
}
- private fun startLLMGeneration() {
- val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode()
-
- if (codeTypeButtons[0].isSelected) {
- Manager.generateTestsForClassByLlm(project, psiFile, caretOffset, fileUrl, testSamplesCode)
- } else if (codeTypeButtons[1].isSelected) {
- Manager.generateTestsForMethodByLlm(project, psiFile, caretOffset, fileUrl, testSamplesCode)
- } else if (codeTypeButtons[2].isSelected) {
- Manager.generateTestsForLineByLlm(project, psiFile, caretOffset, fileUrl, testSamplesCode)
- }
-
- visibilityController.isVisible = false
- dispose()
- }
+ private fun startEvoSuiteGeneration() = startUnitTestGenerationTool(tool = EvoSuite())
+ private fun startLLMGeneration() = startUnitTestGenerationTool(tool = Llm())
/**
* Updates the state of the "Next" button based on the selected options.
@@ -341,18 +399,15 @@ class TestSparkAction : AnAction() {
*/
private fun updateNextButton() {
val isTestGeneratorButtonGroupSelected = llmButton.isSelected || evoSuiteButton.isSelected
- var isCodeTypeButtonGroupSelected = false
- for (button in codeTypeButtons) {
- isCodeTypeButtonGroupSelected = isCodeTypeButtonGroupSelected || button.isSelected
- }
+ val isCodeTypeButtonGroupSelected = codeTypeButtons.any { it.second.isSelected }
nextButton.isEnabled = isTestGeneratorButtonGroupSelected && isCodeTypeButtonGroupSelected
- if ((llmButton.isSelected && !settingsState.llmSetupCheckBoxSelected && !settingsState.provideTestSamplesCheckBoxSelected) ||
- (evoSuiteButton.isSelected && !settingsState.evosuiteSetupCheckBoxSelected)
+ if ((llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected && !llmSettingsState.provideTestSamplesCheckBoxSelected) ||
+ (evoSuiteButton.isSelected && !evoSuiteSettingsState.evosuiteSetupCheckBoxSelected)
) {
- nextButton.text = TestSparkLabelsBundle.defaultValue("ok")
+ nextButton.text = PluginLabelsBundle.get("ok")
} else {
- nextButton.text = TestSparkLabelsBundle.defaultValue("next")
+ nextButton.text = PluginLabelsBundle.get("next")
}
}
}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/controllers/TestGenerationController.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/controllers/TestGenerationController.kt
new file mode 100644
index 000000000..618adba20
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/controllers/TestGenerationController.kt
@@ -0,0 +1,73 @@
+package org.jetbrains.research.testspark.actions.controllers
+
+import com.intellij.notification.NotificationGroupManager
+import com.intellij.notification.NotificationType
+import com.intellij.openapi.actionSystem.AnAction
+import com.intellij.openapi.actionSystem.AnActionEvent
+import com.intellij.openapi.project.Project
+import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
+import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
+import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
+import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+
+/**
+ * Manager used for monitoring the unit test generation process.
+ * It also limits TestSpark to generate tests only once at a time.
+ */
+class TestGenerationController {
+ var indicator: CustomProgressIndicator? = null
+
+ // errorMonitor is passed in many places in the project
+ // and reflects if any bug happened in the test generation process
+ val errorMonitor: ErrorMonitor = DefaultErrorMonitor()
+
+ /**
+ * Method to show notification that test generation is already running.
+ */
+ private fun showGenerationRunningNotification(project: Project) {
+ val terminateButton: AnAction = object : AnAction("Terminate") {
+ override fun actionPerformed(e: AnActionEvent) {
+ indicator?.stop()
+ errorMonitor.notifyErrorOccurrence()
+ }
+ }
+
+ val notification = NotificationGroupManager.getInstance()
+ .getNotificationGroup("Execution Error")
+ .createNotification(
+ PluginMessagesBundle.get("alreadyRunningNotificationTitle"),
+ PluginMessagesBundle.get("alreadyRunningTextNotificationText"),
+ NotificationType.WARNING,
+ )
+
+ notification.addAction(terminateButton)
+
+ notification.notify(project)
+ }
+
+ fun finished() {
+ if (indicator != null &&
+ indicator!!.isRunning()
+ ) {
+ indicator?.stop()
+ }
+ }
+
+ /**
+ * Check if generator is running.
+ *
+ * @return true if it is already running
+ */
+ fun isGeneratorRunning(project: Project): Boolean {
+ // If indicator is null, we have never initiated an indicator before and there is no running test generation
+ if (indicator == null) {
+ return false
+ }
+
+ if (indicator!!.isRunning()) {
+ showGenerationRunningNotification(project)
+ return true
+ }
+ return false
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/controllers/VisibilityController.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/controllers/VisibilityController.kt
new file mode 100644
index 000000000..f8461c762
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/controllers/VisibilityController.kt
@@ -0,0 +1,5 @@
+package org.jetbrains.research.testspark.actions.controllers
+
+class VisibilityController {
+ var isVisible = false
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/evosuite/EvoSuitePanelBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/evosuite/EvoSuitePanelBuilder.kt
new file mode 100644
index 000000000..2fc795536
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/evosuite/EvoSuitePanelBuilder.kt
@@ -0,0 +1,86 @@
+package org.jetbrains.research.testspark.actions.evosuite
+
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.ui.ComboBox
+import com.intellij.ui.components.JBLabel
+import com.intellij.util.ui.FormBuilder
+import org.jetbrains.research.testspark.actions.template.PanelBuilder
+import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteLabelsBundle
+import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteSettingsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.data.evosuite.ContentDigestAlgorithm
+import org.jetbrains.research.testspark.services.EvoSuiteSettingsService
+import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState
+import java.awt.Font
+import javax.swing.JButton
+import javax.swing.JLabel
+import javax.swing.JPanel
+import javax.swing.JTextField
+
+class EvoSuitePanelBuilder(private val project: Project) : PanelBuilder {
+ private val evoSuiteSettingsState: EvoSuiteSettingsState
+ get() = project.getService(EvoSuiteSettingsService::class.java).state
+
+ // init components
+ private var javaPathTextField = JTextField(30)
+ private var algorithmSelector = ComboBox(ContentDigestAlgorithm.entries.toTypedArray())
+ private val backEvoSuiteButton = JButton(PluginLabelsBundle.get("back"))
+ private val okEvoSuiteButton = JButton(PluginLabelsBundle.get("ok"))
+
+ override fun getTitlePanel(): JPanel {
+ val textTitle = JLabel(PluginLabelsBundle.get("evosuiteSetup"))
+ textTitle.font = Font("Monochrome", Font.BOLD, 20)
+
+ val titlePanel = JPanel()
+ titlePanel.add(textTitle)
+
+ return titlePanel
+ }
+
+ override fun getMiddlePanel(): JPanel {
+ javaPathTextField.toolTipText = EvoSuiteSettingsBundle.get("javaPath")
+ javaPathTextField.text = evoSuiteSettingsState.javaPath
+
+ algorithmSelector.setMinimumAndPreferredWidth(300)
+ algorithmSelector.selectedItem = evoSuiteSettingsState.algorithm
+
+ return FormBuilder.createFormBuilder()
+ .setFormLeftIndent(10)
+ .addLabeledComponent(
+ JBLabel(EvoSuiteLabelsBundle.get("javaPath")),
+ javaPathTextField,
+ 10,
+ false,
+ )
+ .addLabeledComponent(
+ JBLabel(EvoSuiteLabelsBundle.get("defaultSearch")),
+ algorithmSelector,
+ 10,
+ false,
+ )
+ .panel
+ }
+
+ override fun getBottomPanel(): JPanel {
+ val bottomButtons = JPanel()
+
+ backEvoSuiteButton.isOpaque = false
+ backEvoSuiteButton.isContentAreaFilled = false
+ bottomButtons.add(backEvoSuiteButton)
+
+ okEvoSuiteButton.isOpaque = false
+ okEvoSuiteButton.isContentAreaFilled = false
+ bottomButtons.add(okEvoSuiteButton)
+
+ return bottomButtons
+ }
+
+ override fun getBackButton() = backEvoSuiteButton
+
+ override fun getFinishedButton() = okEvoSuiteButton
+
+ override fun applyUpdates() {
+ evoSuiteSettingsState.javaPath = javaPathTextField.text
+ evoSuiteSettingsState.algorithm = algorithmSelector.selectedItem!! as ContentDigestAlgorithm
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/evosuite/EvoSuitePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/evosuite/EvoSuitePanelFactory.kt
deleted file mode 100644
index 4cd109017..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/actions/evosuite/EvoSuitePanelFactory.kt
+++ /dev/null
@@ -1,112 +0,0 @@
-package org.jetbrains.research.testspark.actions.evosuite
-
-import com.intellij.openapi.project.Project
-import com.intellij.openapi.ui.ComboBox
-import com.intellij.ui.components.JBLabel
-import com.intellij.util.ui.FormBuilder
-import org.jetbrains.research.testspark.actions.template.PanelFactory
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
-import org.jetbrains.research.testspark.bundles.TestSparkToolTipsBundle
-import org.jetbrains.research.testspark.data.ContentDigestAlgorithm
-import org.jetbrains.research.testspark.services.SettingsApplicationService
-import org.jetbrains.research.testspark.settings.SettingsApplicationState
-import java.awt.Font
-import javax.swing.JButton
-import javax.swing.JLabel
-import javax.swing.JPanel
-import javax.swing.JTextField
-
-class EvoSuitePanelFactory(private val project: Project) : PanelFactory {
- private val settingsState: SettingsApplicationState
- get() = project.getService(SettingsApplicationService::class.java).state
-
- private var javaPathTextField = JTextField(30)
- private var algorithmSelector = ComboBox(ContentDigestAlgorithm.values())
- private val backEvoSuiteButton = JButton(TestSparkLabelsBundle.defaultValue("back"))
- private val okEvoSuiteButton = JButton(TestSparkLabelsBundle.defaultValue("ok"))
-
- /**
- * Returns the title panel for the component.
- *
- * @return the title panel as a JPanel instance.
- */
- override fun getTitlePanel(): JPanel {
- val textTitle = JLabel(TestSparkLabelsBundle.defaultValue("evosuiteSetup"))
- textTitle.font = Font("Monochrome", Font.BOLD, 20)
-
- val titlePanel = JPanel()
- titlePanel.add(textTitle)
-
- return titlePanel
- }
-
- /**
- * Returns the middle panel.
- *
- * @return the middle panel as a JPanel.
- */
- override fun getMiddlePanel(): JPanel {
- javaPathTextField.toolTipText = TestSparkToolTipsBundle.defaultValue("javaPath")
- javaPathTextField.text = settingsState.javaPath
-
- algorithmSelector.setMinimumAndPreferredWidth(300)
- algorithmSelector.selectedItem = settingsState.algorithm
-
- return FormBuilder.createFormBuilder()
- .setFormLeftIndent(10)
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("javaPath")),
- javaPathTextField,
- 10,
- false,
- )
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("defaultSearch")),
- algorithmSelector,
- 10,
- false,
- )
- .panel
- }
-
- /**
- * Returns the bottom panel for the current view.
- *
- * @return The bottom panel for the current view.
- */
- override fun getBottomPanel(): JPanel {
- val bottomButtons = JPanel()
-
- backEvoSuiteButton.isOpaque = false
- backEvoSuiteButton.isContentAreaFilled = false
- bottomButtons.add(backEvoSuiteButton)
-
- okEvoSuiteButton.isOpaque = false
- okEvoSuiteButton.isContentAreaFilled = false
- bottomButtons.add(okEvoSuiteButton)
-
- return bottomButtons
- }
-
- /**
- * Retrieves the back button.
- *
- * @return The back button.
- */
- override fun getBackButton() = backEvoSuiteButton
-
- /**
- * Retrieves the reference to the "OK" button.
- *
- * @return The reference to the "OK" button.
- */
- override fun getFinishedButton() = okEvoSuiteButton
-
- /**
- * Updates the state of the settings.
- */
- override fun applyUpdates() {
- settingsState.javaPath = javaPathTextField.text
- settingsState.algorithm = algorithmSelector.selectedItem!! as ContentDigestAlgorithm
- }
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorBuilder.kt
new file mode 100644
index 000000000..8cd42364e
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorBuilder.kt
@@ -0,0 +1,236 @@
+package org.jetbrains.research.testspark.actions.llm
+
+import com.intellij.openapi.fileTypes.FileType
+import com.intellij.openapi.fileTypes.FileTypeManager
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.roots.ProjectFileIndex
+import com.intellij.openapi.roots.ProjectRootManager
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiJavaFile
+import com.intellij.psi.PsiManager
+import com.intellij.psi.PsiMethod
+import com.intellij.util.containers.stream
+import com.intellij.util.ui.FormBuilder
+import org.jetbrains.research.testspark.actions.template.PanelBuilder
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import java.awt.Font
+import javax.swing.ButtonGroup
+import javax.swing.JButton
+import javax.swing.JLabel
+import javax.swing.JPanel
+import javax.swing.JRadioButton
+
+class LLMSampleSelectorBuilder(private val project: Project, private val language: SupportedLanguage) : PanelBuilder {
+ // init components
+ private val selectionTypeButtons: MutableList = mutableListOf(
+ JRadioButton(PluginLabelsBundle.get("provideTestSample")),
+ JRadioButton(PluginLabelsBundle.get("noTestSample")),
+ )
+ private val selectionTypeButtonGroup = ButtonGroup()
+ private val radioButtonsPanel = JPanel()
+
+ private val defaultTestName = "provide manually"
+ private val defaultTestCode = "// provide test method code here"
+ private val testNames = mutableListOf(defaultTestName)
+ private val initialTestCodes = mutableListOf(createTestSampleClass("", defaultTestCode))
+ private val testSamplePanelFactories: MutableList = mutableListOf()
+ private var testSamplesCode: String = ""
+
+ private val addButtonPanel = JPanel()
+ private val addButton = JButton(PluginLabelsBundle.get("addTestSample"))
+
+ private val nextButton = JButton(PluginLabelsBundle.get("ok"))
+ private val backLlmButton = JButton(PluginLabelsBundle.get("back"))
+
+ private var formBuilder = FormBuilder.createFormBuilder()
+ .setFormLeftIndent(10)
+ .addComponent(JPanel(), 0)
+ .addComponent(radioButtonsPanel, 10)
+ .addComponent(addButtonPanel, 10)
+
+ private var middlePanel = formBuilder.panel
+
+ init {
+ addListeners()
+
+ collectTestSamples(project, testNames, initialTestCodes)
+ }
+
+ override fun getTitlePanel(): JPanel {
+ val textTitle = JLabel(PluginLabelsBundle.get("llmSampleSelectorFactory"))
+ textTitle.font = Font("Monochrome", Font.BOLD, 20)
+
+ val titlePanel = JPanel()
+ titlePanel.add(textTitle)
+
+ return titlePanel
+ }
+
+ override fun getMiddlePanel(): JPanel {
+ for (button in selectionTypeButtons) {
+ selectionTypeButtonGroup.add(button)
+ radioButtonsPanel.add(button)
+ }
+
+ selectionTypeButtons[1].isSelected = true
+
+ addButtonPanel.add(addButton)
+
+ enabledComponents(false)
+
+ middlePanel.revalidate()
+
+ return middlePanel
+ }
+
+ override fun getBottomPanel(): JPanel {
+ val bottomPanel = JPanel()
+ backLlmButton.isOpaque = false
+ backLlmButton.isContentAreaFilled = false
+ bottomPanel.add(backLlmButton)
+ nextButton.isOpaque = false
+ nextButton.isContentAreaFilled = false
+ bottomPanel.add(nextButton)
+
+ return bottomPanel
+ }
+
+ override fun getBackButton() = backLlmButton
+
+ override fun getFinishedButton() = nextButton
+
+ override fun applyUpdates() {
+ if (selectionTypeButtons[0].isSelected) {
+ for (index in testSamplePanelFactories.indices) {
+ testSamplesCode += "Test sample number ${index + 1}\n```\n${testSamplePanelFactories[index].getCode()}\n```\n"
+ }
+ }
+ }
+
+ /**
+ * Retrieves the add button.
+ *
+ * @return The add button.
+ */
+ fun getAddButton(): JButton = addButton
+
+ /**
+ * Retrieves the test samples code.
+ *
+ * @return The test samples code.
+ */
+ fun getTestSamplesCode(): String = testSamplesCode
+
+ /**
+ * Adds action listeners to the selectionTypeButtons array to enable the nextButton if any button is selected.
+ */
+ private fun addListeners() {
+ selectionTypeButtons[0].addActionListener {
+ updateNextButton()
+ enabledComponents(true)
+ }
+
+ selectionTypeButtons[1].addActionListener {
+ updateNextButton()
+ enabledComponents(false)
+ }
+
+ addButton.addActionListener {
+ val testSamplePanelBuilder =
+ TestSamplePanelBuilder(project, middlePanel, testNames, initialTestCodes, language)
+ testSamplePanelFactories.add(testSamplePanelBuilder)
+ val testSamplePanel = testSamplePanelBuilder.getTestSamplePanel()
+ val codeScrollPanel = testSamplePanelBuilder.getCodeScrollPanel()
+ formBuilder = formBuilder
+ .addComponent(testSamplePanel, 10)
+ .addComponent(codeScrollPanel, 10)
+ middlePanel = formBuilder.panel
+ middlePanel.revalidate()
+
+ testSamplePanelBuilder.getRemoveButton().addActionListener {
+ testSamplePanelFactories.remove(testSamplePanelBuilder)
+ middlePanel.remove(testSamplePanel)
+ middlePanel.remove(codeScrollPanel)
+ middlePanel.revalidate()
+
+ updateNextButton()
+ }
+
+ updateNextButton()
+ }
+ }
+
+ /**
+ * Updates next button.
+ */
+ private fun updateNextButton() {
+ if (selectionTypeButtons[0].isSelected) {
+ nextButton.isEnabled = testSamplePanelFactories.isNotEmpty()
+ } else {
+ nextButton.isEnabled = true
+ }
+ }
+
+ /**
+ * Enables and disables the components in the panel in case of type button selection.
+ */
+ private fun enabledComponents(isEnabled: Boolean) {
+ addButton.isEnabled = isEnabled
+
+ for (testSamplePanelFactory in testSamplePanelFactories) {
+ testSamplePanelFactory.enabledComponents(isEnabled)
+ }
+ }
+
+ /**
+ * Retrieves a list of test samples from the given project.
+ *
+ * @return A list of strings, representing the names of the test samples.
+ */
+ private fun collectTestSamples(project: Project, testNames: MutableList, initialTestCodes: MutableList) {
+ val projectFileIndex: ProjectFileIndex = ProjectRootManager.getInstance(project).fileIndex
+ val javaFileType: FileType = FileTypeManager.getInstance().getFileTypeByExtension("java")
+
+ projectFileIndex.iterateContent { file ->
+ if (file.fileType === javaFileType) {
+ try {
+ val psiJavaFile = (PsiManager.getInstance(project).findFile(file) as PsiJavaFile)
+ val psiClass = psiJavaFile.classes[
+ psiJavaFile.classes.stream().map { it.name }.toArray()
+ .indexOf(psiJavaFile.name.removeSuffix(".java")),
+ ]
+ var imports = psiJavaFile.importList?.allImportStatements?.map { it.text }?.toList()
+ ?.joinToString("\n") ?: ""
+ if (psiClass.qualifiedName != null && psiClass.qualifiedName!!.contains(".")) {
+ imports += "\nimport ${psiClass.qualifiedName?.substringBeforeLast(".") + ".*"};"
+ }
+ psiClass.allMethods.forEach { method ->
+ val annotations = method.modifierList.annotations
+ annotations.forEach { annotation ->
+ if (annotation.qualifiedName == "org.junit.jupiter.api.Test" || annotation.qualifiedName == "org.junit.Test") {
+ val code: String = createTestSampleClass(imports, method.text)
+ testNames.add(createMethodName(psiClass, method))
+ initialTestCodes.add(code)
+ }
+ }
+ }
+ } catch (_: Exception) {
+ }
+ }
+ true
+ }
+ }
+
+ private fun createTestSampleClass(imports: String, methodCode: String): String {
+ var normalizedImports = imports
+ if (normalizedImports.isNotBlank()) normalizedImports += "\n\n"
+ return normalizedImports +
+ "public class TestSample {\n" +
+ " $methodCode\n" +
+ "}"
+ }
+
+ private fun createMethodName(psiClass: PsiClass, method: PsiMethod): String =
+ "${psiClass.qualifiedName}#${method.name}"
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt
deleted file mode 100644
index edb458c1c..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt
+++ /dev/null
@@ -1,195 +0,0 @@
-package org.jetbrains.research.testspark.actions.llm
-
-import com.intellij.openapi.components.service
-import com.intellij.openapi.project.Project
-import com.intellij.util.ui.FormBuilder
-import org.jetbrains.research.testspark.actions.template.PanelFactory
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
-import org.jetbrains.research.testspark.services.LLMTestSampleService
-import java.awt.Font
-import javax.swing.ButtonGroup
-import javax.swing.JButton
-import javax.swing.JLabel
-import javax.swing.JPanel
-import javax.swing.JRadioButton
-
-class LLMSampleSelectorFactory(private val project: Project) : PanelFactory {
- private val selectionTypeButtons: MutableList = mutableListOf(
- JRadioButton(TestSparkLabelsBundle.defaultValue("provideTestSample")),
- JRadioButton(TestSparkLabelsBundle.defaultValue("noTestSample")),
- )
- private val selectionTypeButtonGroup = ButtonGroup()
- private val radioButtonsPanel = JPanel()
-
- private val addButtonPanel = JPanel()
- private val addButton = JButton(TestSparkLabelsBundle.defaultValue("addTestSample"))
-
- private val backLlmButton = JButton(TestSparkLabelsBundle.defaultValue("back"))
- private val nextButton = JButton(TestSparkLabelsBundle.defaultValue("ok"))
-
- private val defaultTestName = "provide manually"
- private val defaultTestCode = "// provide test method code here"
-
- private val testNames = mutableListOf(defaultTestName)
- private val initialTestCodes =
- mutableListOf(project.service().createTestSampleClass("", defaultTestCode))
-
- private val testSamplePanelFactories: MutableList = mutableListOf()
-
- private var testSamplesCode: String = ""
-
- private var formBuilder = FormBuilder.createFormBuilder()
- .setFormLeftIndent(10)
- .addComponent(JPanel(), 0)
- .addComponent(radioButtonsPanel, 10)
- .addComponent(addButtonPanel, 10)
-
- private var middlePanel = formBuilder.panel
-
- init {
- addListeners()
-
- project.service().collectTestSamples(project, testNames, initialTestCodes)
- }
-
- /**
- * Adds action listeners to the selectionTypeButtons array to enable the nextButton if any button is selected.
- */
- private fun addListeners() {
- selectionTypeButtons[0].addActionListener {
- updateNextButton()
- enabledComponents(true)
- }
-
- selectionTypeButtons[1].addActionListener {
- updateNextButton()
- enabledComponents(false)
- }
-
- addButton.addActionListener {
- val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes)
- testSamplePanelFactories.add(testSamplePanelFactory)
- val testSamplePanel = testSamplePanelFactory.getTestSamplePanel()
- val codeScrollPanel = testSamplePanelFactory.getCodeScrollPanel()
- formBuilder = formBuilder
- .addComponent(testSamplePanel, 10)
- .addComponent(codeScrollPanel, 10)
- middlePanel = formBuilder.panel
- middlePanel.revalidate()
-
- testSamplePanelFactory.getRemoveButton().addActionListener {
- testSamplePanelFactories.remove(testSamplePanelFactory)
- middlePanel.remove(testSamplePanel)
- middlePanel.remove(codeScrollPanel)
- middlePanel.revalidate()
-
- updateNextButton()
- }
-
- updateNextButton()
- }
- }
-
- private fun updateNextButton() {
- if (selectionTypeButtons[0].isSelected) {
- nextButton.isEnabled = testSamplePanelFactories.isNotEmpty()
- } else {
- nextButton.isEnabled = true
- }
- }
-
- /**
- * Returns a JPanel object representing the title panel.
- * The panel contains a JLabel with the text "llmSampleSelectorFactory",
- * rendered in a bold 20pt Monochrome font.
- *
- * @return a JPanel object representing the title panel
- */
- override fun getTitlePanel(): JPanel {
- val textTitle = JLabel(TestSparkLabelsBundle.defaultValue("llmSampleSelectorFactory"))
- textTitle.font = Font("Monochrome", Font.BOLD, 20)
-
- val titlePanel = JPanel()
- titlePanel.add(textTitle)
-
- return titlePanel
- }
-
- /**
- * Returns the middle panel containing radio buttons, a test samples selector, and a language text field scroll pane.
- *
- * @return the middle panel as a JPanel
- */
- override fun getMiddlePanel(): JPanel {
- for (button in selectionTypeButtons) {
- selectionTypeButtonGroup.add(button)
- radioButtonsPanel.add(button)
- }
-
- selectionTypeButtons[1].isSelected = true
-
- addButtonPanel.add(addButton)
-
- enabledComponents(false)
-
- middlePanel.revalidate()
-
- return middlePanel
- }
-
- /**
- * Retrieves the bottom panel containing the back and next buttons.
- *
- * @return The JPanel containing the back and next buttons.
- */
- override fun getBottomPanel(): JPanel {
- val bottomPanel = JPanel()
- backLlmButton.isOpaque = false
- backLlmButton.isContentAreaFilled = false
- bottomPanel.add(backLlmButton)
- nextButton.isOpaque = false
- nextButton.isContentAreaFilled = false
- bottomPanel.add(nextButton)
-
- return bottomPanel
- }
-
- /**
- * Retrieves the back button.
- *
- * @return The back button.
- */
- override fun getBackButton() = backLlmButton
-
- /**
- * Retrieves the add button.
- *
- * @return The add button.
- */
- fun getAddButton() = addButton
-
- /**
- * Retrieves the reference to the "OK" button.
- *
- * @return The reference to the "OK" button.
- */
- override fun getFinishedButton() = nextButton
-
- override fun applyUpdates() {
- if (selectionTypeButtons[0].isSelected) {
- for (index in testSamplePanelFactories.indices) {
- testSamplesCode += "Test sample number ${index + 1}\n```\n${testSamplePanelFactories[index].getCode()}\n```\n"
- }
- }
- }
-
- private fun enabledComponents(isEnabled: Boolean) {
- addButton.isEnabled = isEnabled
-
- for (testSamplePanelFactory in testSamplePanelFactories) {
- testSamplePanelFactory.enabledComponents(isEnabled)
- }
- }
-
- fun getTestSamplesCode(): String = testSamplesCode
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt
new file mode 100644
index 000000000..873c634c9
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt
@@ -0,0 +1,225 @@
+package org.jetbrains.research.testspark.actions.llm
+
+import com.intellij.openapi.actionSystem.AnActionEvent
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.ui.ComboBox
+import com.intellij.ui.components.JBLabel
+import com.intellij.util.ui.FormBuilder
+import org.jetbrains.research.testspark.actions.template.PanelBuilder
+import org.jetbrains.research.testspark.bundles.llm.LLMLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.core.data.JUnitVersion
+import org.jetbrains.research.testspark.data.llm.JsonEncoding
+import org.jetbrains.research.testspark.data.llm.PromptEditorType
+import org.jetbrains.research.testspark.display.TestSparkIcons
+import org.jetbrains.research.testspark.display.custom.JUnitCombobox
+import org.jetbrains.research.testspark.helpers.LLMHelper
+import org.jetbrains.research.testspark.helpers.PromptParserHelper
+import org.jetbrains.research.testspark.services.LLMSettingsService
+import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
+import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform
+import java.awt.FlowLayout
+import java.awt.Font
+import javax.swing.DefaultComboBoxModel
+import javax.swing.JButton
+import javax.swing.JLabel
+import javax.swing.JPanel
+import javax.swing.JTextField
+
+class LLMSetupPanelBuilder(e: AnActionEvent, private val project: Project) : PanelBuilder {
+ private val llmSettingsState: LLMSettingsState
+ get() = project.getService(LLMSettingsService::class.java).state
+
+ // init components
+ private val defaultModulesArray = arrayOf("")
+ private var modelSelector = ComboBox(defaultModulesArray)
+ private var llmUserTokenField = JTextField(30)
+ private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName))
+ private val backLlmButton = JButton(PluginLabelsBundle.get("back"))
+ private val okLlmButton = JButton(PluginLabelsBundle.get("next"))
+ private val junitSelector = JUnitCombobox(e)
+
+ private val llmPlatforms: List = LLMHelper.getLLLMPlatforms()
+
+ private var promptEditorType: PromptEditorType = PromptEditorType.CLASS
+ private val promptTemplateNames = ComboBox(arrayOf(""))
+ private var prompts: String = llmSettingsState.classPrompts
+ private var promptNames: String = llmSettingsState.classPromptNames
+ private var currentDefaultPromptIndex: Int = llmSettingsState.classCurrentDefaultPromptIndex
+ private var showCodeJLabel: JLabel = JLabel(TestSparkIcons.showCode)
+
+ init {
+ LLMHelper.addLLMPanelListeners(
+ platformSelector,
+ modelSelector,
+ llmUserTokenField,
+ llmPlatforms,
+ llmSettingsState,
+ )
+
+ addListeners()
+ }
+
+ override fun getTitlePanel(): JPanel {
+ val textTitle = JLabel(PluginLabelsBundle.get("llmSetup"))
+ textTitle.font = Font("Monochrome", Font.BOLD, 20)
+
+ val titlePanel = JPanel()
+ titlePanel.add(textTitle)
+
+ return titlePanel
+ }
+
+ override fun getMiddlePanel(): JPanel {
+ LLMHelper.stylizeMainComponents(platformSelector, modelSelector, llmUserTokenField, llmPlatforms, llmSettingsState)
+
+ updatePromptSelectionPanel()
+
+ return FormBuilder.createFormBuilder()
+ .setFormLeftIndent(10)
+ .addLabeledComponent(
+ JBLabel(LLMLabelsBundle.get("llmPlatform")),
+ platformSelector,
+ 10,
+ false,
+ )
+ .addLabeledComponent(
+ JBLabel(LLMLabelsBundle.get("llmToken")),
+ llmUserTokenField,
+ 10,
+ false,
+ )
+ .addLabeledComponent(
+ JBLabel(LLMLabelsBundle.get("model")),
+ modelSelector,
+ 10,
+ false,
+ )
+ .addLabeledComponent(
+ JBLabel(LLMLabelsBundle.get("junitVersion")),
+ junitSelector,
+ 10,
+ false,
+ )
+ .addLabeledComponent(
+ JBLabel(LLMLabelsBundle.get("selectPrompt")),
+ getPromptSelectionPanel(),
+ 10,
+ false,
+ )
+ .panel
+ }
+
+ override fun getBottomPanel(): JPanel {
+ val bottomPanel = JPanel()
+
+ backLlmButton.isOpaque = false
+ backLlmButton.isContentAreaFilled = false
+ bottomPanel.add(backLlmButton)
+
+ okLlmButton.isOpaque = false
+ okLlmButton.isContentAreaFilled = false
+ if (!llmSettingsState.provideTestSamplesCheckBoxSelected) {
+ okLlmButton.text = PluginLabelsBundle.get("ok")
+ }
+ bottomPanel.add(okLlmButton)
+
+ return bottomPanel
+ }
+
+ override fun getBackButton() = backLlmButton
+
+ override fun getFinishedButton() = okLlmButton
+
+ override fun applyUpdates() {
+ llmSettingsState.currentLLMPlatformName = platformSelector.selectedItem!!.toString()
+ for (index in llmPlatforms.indices) {
+ if (llmPlatforms[index].name == llmSettingsState.openAIName) {
+ llmSettingsState.openAIToken = llmPlatforms[index].token
+ llmSettingsState.openAIModel = llmPlatforms[index].model
+ }
+ if (llmPlatforms[index].name == llmSettingsState.grazieName) {
+ llmSettingsState.grazieToken = llmPlatforms[index].token
+ llmSettingsState.grazieModel = llmPlatforms[index].model
+ }
+ if (llmPlatforms[index].name == llmSettingsState.huggingFaceName) {
+ llmSettingsState.huggingFaceToken = llmPlatforms[index].token
+ llmSettingsState.huggingFaceModel = llmPlatforms[index].model
+ }
+ }
+ llmSettingsState.junitVersion = junitSelector.selectedItem!! as JUnitVersion
+
+ when (promptEditorType) {
+ PromptEditorType.CLASS -> llmSettingsState.classCurrentDefaultPromptIndex = JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())
+ PromptEditorType.METHOD -> llmSettingsState.methodCurrentDefaultPromptIndex = JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())
+ PromptEditorType.LINE -> llmSettingsState.lineCurrentDefaultPromptIndex = JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())
+ }
+ }
+
+ /**
+ * Set promptEditorType variable.
+ */
+ fun setPromptEditorType(codeType: String) {
+ if (codeType.contains("class") || codeType.contains("interface")) promptEditorType = PromptEditorType.CLASS
+ if (codeType.contains("method") || codeType.contains("constructor")) promptEditorType = PromptEditorType.METHOD
+ if (codeType.contains("line")) promptEditorType = PromptEditorType.LINE
+
+ updatePromptSelectionPanel()
+ }
+
+ /**
+ * Update prompts, promptNames, currentDefaultPromptIndex vars.
+ */
+ private fun updatePromptSelectionPanel() {
+ when (promptEditorType) {
+ PromptEditorType.CLASS -> {
+ prompts = llmSettingsState.classPrompts
+ promptNames = llmSettingsState.classPromptNames
+ currentDefaultPromptIndex = llmSettingsState.classCurrentDefaultPromptIndex
+ }
+ PromptEditorType.METHOD -> {
+ prompts = llmSettingsState.methodPrompts
+ promptNames = llmSettingsState.methodPromptNames
+ currentDefaultPromptIndex = llmSettingsState.methodCurrentDefaultPromptIndex
+ }
+ PromptEditorType.LINE -> {
+ prompts = llmSettingsState.linePrompts
+ promptNames = llmSettingsState.linePromptNames
+ currentDefaultPromptIndex = llmSettingsState.lineCurrentDefaultPromptIndex
+ }
+ }
+
+ val names = JsonEncoding.decode(promptNames)
+ var normalizedNames = arrayOf()
+ for (i in names.indices) {
+ val prompt = JsonEncoding.decode(prompts)[i]
+ if (PromptParserHelper.isPromptValid(prompt)) {
+ normalizedNames += names[i]
+ }
+ }
+
+ promptTemplateNames.model = DefaultComboBoxModel(normalizedNames)
+ promptTemplateNames.selectedItem = names[currentDefaultPromptIndex]
+ }
+
+ /**
+ * @return prompt selection panel
+ */
+ private fun getPromptSelectionPanel(): JPanel {
+ val panel = JPanel(FlowLayout(FlowLayout.LEFT))
+
+ panel.add(promptTemplateNames)
+ panel.add(showCodeJLabel)
+
+ return panel
+ }
+
+ /**
+ * Add listener to a promptTemplateNames.
+ */
+ private fun addListeners() {
+ promptTemplateNames.addActionListener {
+ showCodeJLabel.toolTipText = JsonEncoding.decode(prompts)[JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())]
+ }
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt
deleted file mode 100644
index 6f3779eb8..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt
+++ /dev/null
@@ -1,246 +0,0 @@
-package org.jetbrains.research.testspark.actions.llm
-
-import com.intellij.openapi.actionSystem.AnActionEvent
-import com.intellij.openapi.components.service
-import com.intellij.openapi.project.Project
-import com.intellij.openapi.ui.ComboBox
-import com.intellij.ui.components.JBLabel
-import com.intellij.util.ui.FormBuilder
-import org.jetbrains.research.testspark.actions.template.PanelFactory
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
-import org.jetbrains.research.testspark.core.data.JUnitVersion
-import org.jetbrains.research.testspark.data.JsonEncoding
-import org.jetbrains.research.testspark.display.JUnitCombobox
-import org.jetbrains.research.testspark.display.TestSparkIcons
-import org.jetbrains.research.testspark.helpers.addLLMPanelListeners
-import org.jetbrains.research.testspark.helpers.getLLLMPlatforms
-import org.jetbrains.research.testspark.helpers.stylizeMainComponents
-import org.jetbrains.research.testspark.services.PromptParserService
-import org.jetbrains.research.testspark.services.SettingsApplicationService
-import org.jetbrains.research.testspark.settings.SettingsApplicationState
-import org.jetbrains.research.testspark.settings.llm.PromptEditorType
-import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform
-import java.awt.FlowLayout
-import java.awt.Font
-import javax.swing.DefaultComboBoxModel
-import javax.swing.JButton
-import javax.swing.JLabel
-import javax.swing.JPanel
-import javax.swing.JTextField
-
-class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : PanelFactory {
- private val settingsState: SettingsApplicationState
- get() = project.getService(SettingsApplicationService::class.java).state
-
- private val defaultModulesArray = arrayOf("")
- private var modelSelector = ComboBox(defaultModulesArray)
- private var llmUserTokenField = JTextField(30)
- private var platformSelector = ComboBox(arrayOf(settingsState.openAIName))
- private val backLlmButton = JButton(TestSparkLabelsBundle.defaultValue("back"))
- private val okLlmButton = JButton(TestSparkLabelsBundle.defaultValue("next"))
- private val junitSelector = JUnitCombobox(e)
-
- private val llmPlatforms: List = getLLLMPlatforms()
-
- private var promptEditorType: PromptEditorType = PromptEditorType.CLASS
- private val promptTemplateNames = ComboBox(arrayOf(""))
- private var prompts: String = settingsState.classPrompts
- private var promptNames: String = settingsState.classPromptNames
- private var currentDefaultPromptIndex: Int = settingsState.classCurrentDefaultPromptIndex
- private var showCodeJLabel: JLabel = JLabel(TestSparkIcons.showCode)
-
- init {
- addLLMPanelListeners(
- platformSelector,
- modelSelector,
- llmUserTokenField,
- llmPlatforms,
- settingsState,
- )
-
- addListeners()
- }
-
- /**
- * Returns the title panel for the setup.
- *
- * @return the title panel containing the setup title label.
- */
- override fun getTitlePanel(): JPanel {
- val textTitle = JLabel(TestSparkLabelsBundle.defaultValue("llmSetup"))
- textTitle.font = Font("Monochrome", Font.BOLD, 20)
-
- val titlePanel = JPanel()
- titlePanel.add(textTitle)
-
- return titlePanel
- }
-
- /**
- * Retrieves the middle panel of the UI.
- *
- * This method returns a JPanel object that represents the middle panel of the user interface.
- * The middle panel contains several components including a platform selector, a model selector,
- * and a user token field. These components are stylized using the `stylizeMainComponents` method.
- * The UI labels for the platform, token, and model components are retrieved using the
- * `TestSpark*/
- override fun getMiddlePanel(): JPanel {
- stylizeMainComponents(platformSelector, modelSelector, llmUserTokenField, llmPlatforms, settingsState)
-
- updatePromptSelectionPanel()
-
- return FormBuilder.createFormBuilder()
- .setFormLeftIndent(10)
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("llmPlatform")),
- platformSelector,
- 10,
- false,
- )
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("llmToken")),
- llmUserTokenField,
- 10,
- false,
- )
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("model")),
- modelSelector,
- 10,
- false,
- )
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("junitVersion")),
- junitSelector,
- 10,
- false,
- )
- .addLabeledComponent(
- JBLabel(TestSparkLabelsBundle.defaultValue("selectPrompt")),
- getPromptSelectionPanel(),
- 10,
- false,
- )
- .panel
- }
-
- private fun updatePromptSelectionPanel() {
- when (promptEditorType) {
- PromptEditorType.CLASS -> {
- prompts = settingsState.classPrompts
- promptNames = settingsState.classPromptNames
- currentDefaultPromptIndex = settingsState.classCurrentDefaultPromptIndex
- }
- PromptEditorType.METHOD -> {
- prompts = settingsState.methodPrompts
- promptNames = settingsState.methodPromptNames
- currentDefaultPromptIndex = settingsState.methodCurrentDefaultPromptIndex
- }
- PromptEditorType.LINE -> {
- prompts = settingsState.linePrompts
- promptNames = settingsState.linePromptNames
- currentDefaultPromptIndex = settingsState.lineCurrentDefaultPromptIndex
- }
- }
-
- val names = JsonEncoding.decode(promptNames)
- var normalizedNames = arrayOf()
- for (i in names.indices) {
- val prompt = JsonEncoding.decode(prompts)[i]
- if (service().isPromptValid(prompt)) {
- normalizedNames += names[i]
- }
- }
-
- promptTemplateNames.model = DefaultComboBoxModel(normalizedNames)
- promptTemplateNames.selectedItem = names[currentDefaultPromptIndex]
- }
-
- private fun getPromptSelectionPanel(): JPanel {
- val panel = JPanel(FlowLayout(FlowLayout.LEFT))
-
- panel.add(promptTemplateNames)
- panel.add(showCodeJLabel)
-
- return panel
- }
-
- fun setPromptEditorType(codeType: String) {
- if (codeType.contains("class") || codeType.contains("interface")) promptEditorType = PromptEditorType.CLASS
- if (codeType.contains("method") || codeType.contains("constructor")) promptEditorType = PromptEditorType.METHOD
- if (codeType.contains("line")) promptEditorType = PromptEditorType.LINE
-
- updatePromptSelectionPanel()
- }
-
- /**
- * Returns the bottom panel for the UI.
- *
- * @return The JPanel representing the bottom panel of the UI.
- */
- override fun getBottomPanel(): JPanel {
- val bottomPanel = JPanel()
-
- backLlmButton.isOpaque = false
- backLlmButton.isContentAreaFilled = false
- bottomPanel.add(backLlmButton)
-
- okLlmButton.isOpaque = false
- okLlmButton.isContentAreaFilled = false
- if (!settingsState.provideTestSamplesCheckBoxSelected) {
- okLlmButton.text = TestSparkLabelsBundle.defaultValue("ok")
- }
- bottomPanel.add(okLlmButton)
-
- return bottomPanel
- }
-
- /**
- * Retrieves the back button.
- *
- * @return The back button.
- */
- override fun getBackButton() = backLlmButton
-
- /**
- * Retrieves the reference to the "OK" button.
- *
- * @return The reference to the "OK" button.
- */
- override fun getFinishedButton() = okLlmButton
-
- private fun addListeners() {
- promptTemplateNames.addActionListener {
- showCodeJLabel.toolTipText = JsonEncoding.decode(prompts)[JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())]
- }
- }
-
- /**
- * Updates the settings state based on the selected values from the UI components.
- *
- * This method sets the `llmPlatform`, `llmUserToken`, and `model` properties of the `settingsState` object
- * based on the currently selected values from the UI components.
- *
- * Note: This method assumes all the required UI components (`platformSelector`, `llmUserTokenField`, and `modelSelector`) are properly initialized and have values selected.
- */
- override fun applyUpdates() {
- settingsState.currentLLMPlatformName = platformSelector.selectedItem!!.toString()
- for (index in llmPlatforms.indices) {
- if (llmPlatforms[index].name == settingsState.openAIName) {
- settingsState.openAIToken = llmPlatforms[index].token
- settingsState.openAIModel = llmPlatforms[index].model
- }
- if (llmPlatforms[index].name == settingsState.grazieName) {
- settingsState.grazieToken = llmPlatforms[index].token
- settingsState.grazieModel = llmPlatforms[index].model
- }
- }
- settingsState.junitVersion = junitSelector.selectedItem!! as JUnitVersion
-
- when (promptEditorType) {
- PromptEditorType.CLASS -> settingsState.classCurrentDefaultPromptIndex = JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())
- PromptEditorType.METHOD -> settingsState.methodCurrentDefaultPromptIndex = JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())
- PromptEditorType.LINE -> settingsState.lineCurrentDefaultPromptIndex = JsonEncoding.decode(promptNames).indexOf(promptTemplateNames.selectedItem!!.toString())
- }
- }
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelBuilder.kt
similarity index 80%
rename from src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelBuilder.kt
index 15b516062..d0ac2c325 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelBuilder.kt
@@ -9,44 +9,42 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.ComboBox
import com.intellij.ui.LanguageTextField
import com.intellij.ui.components.JBScrollPane
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.display.TestCaseDocumentCreator
import org.jetbrains.research.testspark.display.TestSparkIcons
-import org.jetbrains.research.testspark.display.createButton
-import org.jetbrains.research.testspark.display.getModifiedLines
+import org.jetbrains.research.testspark.display.utils.IconButtonCreator
+import org.jetbrains.research.testspark.display.utils.ModifiedLinesGetter
import javax.swing.BoxLayout
import javax.swing.DefaultComboBoxModel
import javax.swing.JButton
import javax.swing.JPanel
import javax.swing.ScrollPaneConstants
-class TestSamplePanelFactory(
+class TestSamplePanelBuilder(
project: Project,
private val middlePanel: JPanel,
private val testNames: MutableList,
private val initialTestCodes: MutableList,
+ private val language: SupportedLanguage,
) {
+ // init components
private val currentTestCodes = initialTestCodes.toMutableList()
-
private val languageTextField = LanguageTextField(
- Language.findLanguageByID("JAVA"),
+ Language.findLanguageByID(language.languageId),
project,
initialTestCodes[0],
TestCaseDocumentCreator("TestSample"),
false,
)
-
private var languageTextFieldScrollPane = JBScrollPane(
languageTextField,
ScrollPaneConstants.VERTICAL_SCROLLBAR_ALWAYS,
ScrollPaneConstants.HORIZONTAL_SCROLLBAR_ALWAYS,
)
-
private var testSamplesSelector = ComboBox(arrayOf(""))
-
- private val resetButton = createButton(TestSparkIcons.reset, TestSparkLabelsBundle.defaultValue("resetTip"))
-
- private val removeButton = createButton(TestSparkIcons.remove, TestSparkLabelsBundle.defaultValue("removeTip"))
+ private val resetButton = IconButtonCreator.getButton(TestSparkIcons.reset, PluginLabelsBundle.get("resetTip"))
+ private val removeButton = IconButtonCreator.getButton(TestSparkIcons.remove, PluginLabelsBundle.get("removeTip"))
init {
addListeners()
@@ -54,6 +52,9 @@ class TestSamplePanelFactory(
testSamplesSelector.model = DefaultComboBoxModel(testNames.toTypedArray())
}
+ /**
+ * Add listeners.
+ */
private fun addListeners() {
languageTextField.document.addDocumentListener(object : DocumentListener {
override fun documentChanged(event: DocumentEvent) {
@@ -63,7 +64,7 @@ class TestSamplePanelFactory(
if (testNames[index] == testSamplesSelector.selectedItem) {
currentTestCodes[index] = languageTextField.text
- val modifiedLineIndexes = getModifiedLines(
+ val modifiedLineIndexes = ModifiedLinesGetter.getLines(
initialTestCodes[index].split("\n"),
currentTestCodes[index].split("\n"),
)
@@ -103,8 +104,14 @@ class TestSamplePanelFactory(
}
}
+ /**
+ * @return removeButton
+ */
fun getRemoveButton(): JButton = removeButton
+ /**
+ *
+ */
fun getTestSamplePanel(): JPanel {
val testSamplePanel = JPanel()
@@ -116,10 +123,16 @@ class TestSamplePanelFactory(
return testSamplePanel
}
+ /**
+ * @return languageTextFieldScrollPane
+ */
fun getCodeScrollPanel(): JBScrollPane {
return languageTextFieldScrollPane
}
+ /**
+ * Enables and disables the components in the panel in case of type button selection.
+ */
fun enabledComponents(isEnabled: Boolean) {
resetButton.isEnabled = false
for (index in testNames.indices) {
@@ -133,5 +146,8 @@ class TestSamplePanelFactory(
languageTextFieldScrollPane.isEnabled = isEnabled
}
+ /**
+ * @return code of the sample
+ */
fun getCode(): String = languageTextField.text
}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/template/PanelBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/template/PanelBuilder.kt
new file mode 100644
index 000000000..59f8717dd
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/template/PanelBuilder.kt
@@ -0,0 +1,46 @@
+package org.jetbrains.research.testspark.actions.template
+
+import javax.swing.JButton
+import javax.swing.JPanel
+
+interface PanelBuilder {
+ /**
+ * Returns a JPanel object containing the title panel.
+ *
+ * @return a JPanel object representing the title panel
+ */
+ fun getTitlePanel(): JPanel
+
+ /**
+ * Returns the middle panel containing main components.
+ *
+ * @return the middle panel as a JPanel.
+ */
+ fun getMiddlePanel(): JPanel
+
+ /**
+ * Retrieves the bottom panel containing the back and next buttons.
+ *
+ * @return The JPanel containing the back and next buttons.
+ */
+ fun getBottomPanel(): JPanel
+
+ /**
+ * Retrieves the back button.
+ *
+ * @return The back button.
+ */
+ fun getBackButton(): JButton
+
+ /**
+ * Retrieves the reference to the "OK" button.
+ *
+ * @return The reference to the "OK" button.
+ */
+ fun getFinishedButton(): JButton
+
+ /**
+ * Updates the settings state based on the selected values from the UI components.
+ */
+ fun applyUpdates()
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/template/PanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/template/PanelFactory.kt
deleted file mode 100644
index 3e99d2ccb..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/actions/template/PanelFactory.kt
+++ /dev/null
@@ -1,18 +0,0 @@
-package org.jetbrains.research.testspark.actions.template
-
-import javax.swing.JButton
-import javax.swing.JPanel
-
-interface PanelFactory {
- fun getTitlePanel(): JPanel
-
- fun getMiddlePanel(): JPanel
-
- fun getBottomPanel(): JPanel
-
- fun getBackButton(): JButton
-
- fun getFinishedButton(): JButton
-
- fun applyUpdates()
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/ProjectApplicationUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/ProjectApplicationUtils.kt
new file mode 100644
index 000000000..4fd7a12fa
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/ProjectApplicationUtils.kt
@@ -0,0 +1,224 @@
+package org.jetbrains.research.testspark.appstarter
+
+import com.intellij.conversion.ConversionListener
+import com.intellij.conversion.ConversionService
+import com.intellij.ide.CommandLineInspectionProgressReporter
+import com.intellij.ide.CommandLineInspectionProjectConfigurator
+import com.intellij.ide.CommandLineInspectionProjectConfigurator.ConfiguratorContext
+import com.intellij.ide.impl.PatchProjectUtil
+import com.intellij.ide.impl.ProjectUtil.openOrImport
+import com.intellij.openapi.Disposable
+import com.intellij.openapi.application.ApplicationManager
+import com.intellij.openapi.application.ModalityState
+import com.intellij.openapi.application.ex.ApplicationManagerEx
+import com.intellij.openapi.progress.util.ProgressIndicatorBase
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.util.Disposer
+import com.intellij.openapi.util.io.FileUtil
+import com.intellij.openapi.vfs.LocalFileSystem
+import com.intellij.openapi.vfs.VirtualFile
+import com.intellij.openapi.vfs.VirtualFileManager
+import kotlinx.coroutines.runBlocking
+import org.jetbrains.idea.maven.MavenCommandLineInspectionProjectConfigurator
+import org.jetbrains.idea.maven.project.MavenProjectsManager
+import org.jetbrains.plugins.gradle.GradleCommandLineProjectConfigurator
+import org.slf4j.LoggerFactory
+import java.nio.file.Path
+import java.util.function.Predicate
+import kotlin.coroutines.suspendCoroutine
+
+class ProjectConfiguratorException : Exception {
+ constructor(message: String) : super(message)
+
+ @Suppress("unused")
+ constructor(message: String, cause: Throwable) : super(message, cause)
+}
+
+class ConversionListenerImpl : ConversionListener {
+ private val logger = LoggerFactory.getLogger(javaClass)
+
+ override fun conversionNeeded() {
+ logger.info("Conversion is needed for project.")
+ }
+
+ override fun successfullyConverted(backupDir: Path) {
+ logger.info("Project successfully converted.")
+ }
+
+ override fun error(message: String) {
+ throw ProjectConfiguratorException(message)
+ }
+
+ override fun cannotWriteToFiles(readonlyFiles: List) {
+ throw ProjectConfiguratorException("Can not write to files ${readonlyFiles.joinToString { it.fileName.toString() }}")
+ }
+}
+
+class ConfiguratorContextImpl(
+ private val projectRoot: Path,
+ private val indicator: ProgressIndicatorBase = ProgressIndicatorBase(),
+ private val filesFilter: Predicate = Predicate { true },
+ private val virtualFilesFilter: Predicate = Predicate { true },
+) : ConfiguratorContext {
+ private val logger = LoggerFactory.getLogger(javaClass)
+ override fun getProgressIndicator() = indicator
+ override fun getLogger() = object : CommandLineInspectionProgressReporter {
+ override fun reportError(message: String) {
+ logger.warn("ERROR: $message")
+ }
+
+ override fun reportMessage(minVerboseLevel: Int, message: String) {
+ logger.info("PROGRESS: $message")
+ }
+ }
+
+ override fun getProjectPath() = projectRoot
+ override fun getFilesFilter(): Predicate = filesFilter
+ override fun getVirtualFilesFilter(): Predicate = virtualFilesFilter
+}
+
+object ProjectApplicationUtils {
+
+ private val logger = LoggerFactory.getLogger(javaClass)
+
+ /**
+ * Rewritten from {@link com.intellij.codeInspection.InspectionApplicationBase}.
+ * Implementation which reuse InspectionApplicationBase:
+ *
+ * val app = object : InspectionApplicationBase() {
+ * fun open(): Project? {
+ * return this.openProject(projectPath, parentDisposable)
+ * }
+ * }
+ *
+ * return app.open() ?: throw ProjectApplicationException("Can not open project")
+ */
+ suspend fun openProject(projectPath: Path, parentDisposable: Disposable, projectToClose: Project? = null): Project {
+ ApplicationManager.getApplication().assertIsNonDispatchThread()
+ ApplicationManagerEx.getApplicationEx().isSaveAllowed = false
+
+ LocalFileSystem.getInstance().refreshAndFindFileByPath(
+ FileUtil.toSystemIndependentName(projectPath.toString()),
+ ) ?: throw ProjectConfiguratorException("Project directory not found.")
+
+ convertProject(projectPath)
+
+ configureProjectEnvironment(projectPath)
+
+ val project = openOrImport(projectPath, projectToClose, forceOpenInNewFrame = true)
+ ?: throw ProjectConfiguratorException("Can not open or import project from $projectPath.")
+ Disposer.register(parentDisposable) { closeProject(project) }
+
+ waitAllStartupActivitiesPassed(project)
+
+ ApplicationManager.getApplication().invokeAndWait {
+ VirtualFileManager.getInstance().refreshWithoutFileWatcher(false)
+ }
+
+ ApplicationManager.getApplication().invokeAndWait {
+ PatchProjectUtil.patchProject(project)
+ }
+
+ waitForInvokeLaterActivities()
+
+ return project
+ }
+
+ private suspend fun convertProject(projectPath: Path) {
+ val conversionService = ConversionService.getInstance()
+ ?: throw ProjectConfiguratorException("Can not convert project $projectPath")
+ val conversionResult = conversionService.convertSilently(projectPath, ConversionListenerImpl())
+ if (conversionResult.openingIsCanceled()) {
+ throw ProjectConfiguratorException("Project opening is canceled $projectPath")
+ }
+ }
+
+ private fun configureProjectEnvironment(projectPath: Path) {
+ for (configurator in CommandLineInspectionProjectConfigurator.EP_NAME.extensionList) {
+ val context = ConfiguratorContextImpl(projectPath)
+ if (configurator.isApplicable(context)) {
+ logger.info("Applying configurator ${configurator.name} to configure project environment $projectPath.")
+ configurator.configureEnvironment(context)
+ }
+ }
+ }
+
+ fun resolveProject(
+ project: Project,
+ configurator: CommandLineInspectionProjectConfigurator,
+ context: ConfiguratorContext,
+ ) {
+ logger.info("Resolving project ${project.name}...")
+ logger.info("Applying configurator ${configurator.name} to resolve project ${project.name}...")
+ configurator.preConfigureProject(project, context)
+ configurator.configureProject(project, context)
+ waitForInvokeLaterActivities()
+ logger.info("Project ${project.name} was successfully resolved with configurator ${configurator.name}!")
+ }
+
+ private fun closeProject(project: Project) {
+ logger.info("Closing project $project...")
+ ApplicationManager.getApplication().assertIsNonDispatchThread()
+ // ToDo: move headless mode to another branch
+// ApplicationManager.getApplication().invokeAndWait {
+// ProjectManagerEx.getInstanceEx().forceCloseProject(project)
+// }
+ }
+
+ private suspend fun waitAllStartupActivitiesPassed(project: Project): Unit = suspendCoroutine {
+ logger.info("Waiting all startup activities passed $project...")
+ // ToDo: move headless mode to another branch
+// StartupManager.getInstance(project).runAfterOpened { it.resume(Unit) }
+ waitForInvokeLaterActivities()
+ }
+
+ /**
+ * Magic loop is used to wait all invoke later activities passed.
+ * Without this loop we can run into problems, because some of the project opening activities may not be finished
+ * when we need them.
+ */
+ private fun waitForInvokeLaterActivities() {
+ logger.info("Waiting all invoked later activities...")
+ repeat(10) {
+ ApplicationManager.getApplication().invokeAndWait({}, ModalityState.any())
+ }
+ }
+}
+
+class JvmProjectResolver {
+ private val logger = LoggerFactory.getLogger(javaClass)
+
+ fun resolveProject(project: Project) {
+ logger.info("Started to resolve project ${project.name}.")
+ val configurator = getProjectConfigurator(project)
+ val projectPath = project.basePath?.let { Path.of(it) }
+ ?: throw ProjectConfiguratorException("Undefined base path for project ${project.name}")
+ val context = ConfiguratorContextImpl(projectPath)
+
+ ProjectApplicationUtils.resolveProject(project, configurator, context)
+ }
+
+ private fun getProjectConfigurator(project: Project): CommandLineInspectionProjectConfigurator {
+ return if (MavenProjectsManager.getInstance(project).isMavenizedProject) {
+ logger.info("Project ${project.name} considered to be maven")
+ MavenCommandLineInspectionProjectConfigurator()
+ } else {
+ logger.info("Project ${project.name} considered to be gradle")
+ GradleCommandLineProjectConfigurator()
+ }
+ }
+}
+
+class JvmProjectConfigurator {
+ private val projectResolver = JvmProjectResolver()
+
+ fun openProject(projectPath: Path, fullResolve: Boolean, parentDisposable: Disposable): Project = runBlocking {
+ val project = ProjectApplicationUtils.openProject(projectPath, parentDisposable)
+
+ if (fullResolve) {
+ projectResolver.resolveProject(project)
+ }
+
+ project
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt
new file mode 100644
index 000000000..af2bdef41
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt
@@ -0,0 +1,304 @@
+package org.jetbrains.research.testspark.appstarter
+
+import com.intellij.openapi.application.ApplicationManager
+import com.intellij.openapi.application.ApplicationStarter
+import com.intellij.openapi.components.service
+import com.intellij.openapi.project.DumbService
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.project.ProjectManager
+import com.intellij.openapi.roots.ProjectFileIndex
+import com.intellij.openapi.roots.ProjectRootManager
+import com.intellij.openapi.util.Disposer
+import com.intellij.openapi.vfs.LocalFileSystem
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiJavaFile
+import com.intellij.psi.PsiManager
+import kotlinx.serialization.ExperimentalSerializationApi
+import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle
+import org.jetbrains.research.testspark.core.data.JUnitVersion
+import org.jetbrains.research.testspark.core.data.TestGenerationData
+import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
+import org.jetbrains.research.testspark.core.test.TestCompiler
+import org.jetbrains.research.testspark.core.test.data.CodeType
+import org.jetbrains.research.testspark.data.FragmentToTestData
+import org.jetbrains.research.testspark.data.ProjectContext
+import org.jetbrains.research.testspark.data.llm.JsonEncoding
+import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider
+import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator
+import org.jetbrains.research.testspark.services.LLMSettingsService
+import org.jetbrains.research.testspark.services.PluginSettingsService
+import org.jetbrains.research.testspark.tools.TestProcessor
+import org.jetbrains.research.testspark.tools.TestsExecutionResultManager
+import org.jetbrains.research.testspark.tools.ToolUtils
+import org.jetbrains.research.testspark.tools.factories.TestCompilerFactory
+import org.jetbrains.research.testspark.tools.llm.Llm
+import java.io.File
+import java.nio.file.Path
+import java.nio.file.Paths
+import kotlin.system.exitProcess
+
+/**
+ * This class is responsible for generating and running tests based on the provided arguments in headless mode.
+ */
+class TestSparkStarter : ApplicationStarter {
+ @Deprecated("Specify it as `id` for extension definition in a plugin descriptor")
+ override val commandName: String = "testspark"
+
+ /** Sets main (start) thread for IDE in headless as not edt. */
+ override val requiredModality: Int = ApplicationStarter.NOT_IN_EDT
+
+ @Suppress("TooGenericExceptionCaught")
+ @OptIn(ExperimentalSerializationApi::class)
+ override fun main(args: List) {
+ // Project path
+ val projectPath = args[1]
+ // Path to the target file (.java file)
+ val cutSourceFilePath = Paths.get(projectPath, args[2]).toAbsolutePath()
+ // CUT name (.)
+ val classUnderTestName = args[3]
+ // Paths to compilation output of the project under test (seperated by ':')
+ val projectClassPath = args[4]
+ val classPath = "$projectPath${ToolUtils.pathSep}$projectClassPath"
+ // JUnit Version
+ val jUnitVersion = args[5]
+ // Selected model
+ val model = args[6]
+ // Token
+ val token = args[7]
+ // Filepath to a file containing the prompt template
+ val promptTemplateFile = args[8]
+ // Output directory
+ val output = args[9]
+ // Run coverage
+ val runCoverage = args[10].toBoolean()
+
+ val testsExecutionResultManager = TestsExecutionResultManager()
+
+ println("Test generation requested for $projectPath")
+
+ // remove the `.idea` folder in the $projectPath if exists
+ val ideaFolderPath = "$projectPath${File.separator}.idea"
+ val ideaFolder = File(ideaFolderPath)
+ if (ideaFolder.exists()) {
+ ideaFolder.deleteRecursively()
+ }
+
+ // open and resolve the project
+ val project = try {
+ JvmProjectConfigurator().openProject(
+ Paths.get(projectPath),
+ fullResolve = true,
+ parentDisposable = Disposer.newDisposable(),
+ )
+ } catch (e: Throwable) {
+ e.printStackTrace(System.err)
+ exitProcess(1)
+ }
+
+ ApplicationManager.getApplication().invokeAndWait {
+ println("Detected project: $project")
+ // Continue when the project is indexed
+ println("Indexing project...")
+ project.let {
+ DumbService.getInstance(it).runWhenSmart {
+ try {
+ // open target file
+ val cutSourceVirtualFile =
+ LocalFileSystem.getInstance().findFileByPath(cutSourceFilePath.toString()) ?: run {
+ println("Couldn't open file $cutSourceFilePath")
+ exitProcess(1)
+ }
+
+ // get target PsiClass
+ val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile) as PsiJavaFile
+ val targetPsiClass = detectPsiClass(psiFile.classes, classUnderTestName) ?: run {
+ println("Couldn't find $classUnderTestName in $cutSourceFilePath")
+ exitProcess(1)
+ }
+
+ println("PsiClass ${targetPsiClass.qualifiedName} is detected! Start the test generation process.")
+
+ // Get project SDK
+ val projectSDKPath = getProjectSdkPath(project)
+ // update settings
+ val settingsState = project.getService(LLMSettingsService::class.java).state
+ settingsState.currentLLMPlatformName = LLMDefaultsBundle.get("grazieName")
+ settingsState.grazieToken = token
+ settingsState.grazieModel = model
+ settingsState.classPrompts =
+ JsonEncoding.encode(mutableListOf(File(promptTemplateFile).readText()))
+ settingsState.junitVersion = when (jUnitVersion.filter { it.isDigit() }) {
+ "4" -> JUnitVersion.JUnit4
+ "5" -> JUnitVersion.JUnit5
+ else -> {
+ throw IllegalArgumentException("JUnit version $jUnitVersion is not supported. Supported JUnit versions are '4' and '5'")
+ }
+ }
+ project.service().state.buildPath = classPath
+
+ // Prepare Project Context
+ // First, get CUT Module
+ val cutModule = ProjectFileIndex.getInstance(project)
+ .getModuleForFile(targetPsiClass.containingFile.virtualFile)
+ // Then, instantiate the project context
+ val projectContext = ProjectContext(
+ classPath,
+ output,
+ targetPsiClass.qualifiedName,
+ cutModule,
+ )
+ // Prepare the test generation data
+ val testGenerationData = TestGenerationData(
+ resultPath = output,
+ testResultName = "HeadlessGeneratedTests",
+ )
+ println("[TestSpark Starter] Indexing is done")
+
+ // get package name
+ val packageList = targetPsiClass.qualifiedName.toString().split(".").dropLast(1).toMutableList()
+ val packageName = packageList.joinToString(".")
+
+ // Get PsiHelper
+ val psiHelper = PsiHelperProvider.getPsiHelper(psiFile)
+ if (psiHelper == null) {
+ // TODO exception: the support for the current language does not exist
+ }
+ // Create a process Manager
+ val llmProcessManager = Llm()
+ .getLLMProcessManager(
+ project,
+ psiHelper!!,
+ targetPsiClass.textRange.startOffset,
+ testSamplesCode = "", // we don't provide samples to LLM
+ projectSDKPath = projectSDKPath,
+ )
+
+ println("[TestSpark Starter] Starting the test generation process")
+ // Start test generation
+ val indicator = HeadlessProgressIndicator()
+ val errorMonitor = DefaultErrorMonitor()
+ val testCompiler = TestCompilerFactory.create(
+ project,
+ settingsState.junitVersion,
+ psiHelper.language,
+ projectSDKPath.toString(),
+ )
+ val uiContext = llmProcessManager.runTestGenerator(
+ indicator,
+ FragmentToTestData(CodeType.CLASS),
+ packageName,
+ projectContext,
+ testGenerationData,
+ errorMonitor,
+ testsExecutionResultManager,
+ )
+
+ // Check test Generation Output
+ if (uiContext != null) {
+ println("[TestSpark Starter] Test generation completed successfully")
+ // Run test file
+ if (runCoverage) {
+ runTestsWithCoverageCollection(
+ project,
+ output,
+ packageList,
+ classPath,
+ projectContext,
+ projectSDKPath,
+ testCompiler,
+ )
+ }
+ } else {
+ println("[TestSpark Starter] Test generation failed")
+ }
+
+ ProjectManager.getInstance().closeAndDispose(project)
+
+ println("[TestSpark Starter] Exiting the headless mode")
+ exitProcess(0)
+ } catch (e: Throwable) {
+ println("[TestSpark Starter] Exiting the headless mode with an exception")
+
+ ProjectManager.getInstance().closeAndDispose(project)
+ e.printStackTrace(System.err)
+ exitProcess(0)
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Retrieves the project SDK based on the provided parameters.
+ *
+ * @param project the project inder test
+ * @return the project SDK for running and compiling tests generated in headless mode
+ */
+ private fun getProjectSdkPath(project: Project): Path {
+ return when (val projectSdk = ProjectRootManager.getInstance(project).projectSdk) {
+ null -> {
+ println("Did not resolve the project SDK, using default SDK")
+ Paths.get(System.getProperty("java.home"))
+ }
+
+ else -> Paths.get(projectSdk.homeDirectory!!.path)
+ }
+ }
+
+ private fun runTestsWithCoverageCollection(
+ project: Project,
+ out: String,
+ packageList: MutableList,
+ classPath: String,
+ projectContext: ProjectContext,
+ projectSDKPath: Path,
+ testCompiler: TestCompiler,
+ ) {
+ val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}"
+ println("Run tests in $targetDirectory")
+ File(targetDirectory).walk().forEach {
+ if (it.name.endsWith(".class")) {
+ println("Running test ${it.name}")
+ var testcaseName = it.nameWithoutExtension.removePrefix("Generated")
+ testcaseName = testcaseName[0].lowercaseChar() + testcaseName.substring(1)
+ // The current test is compiled and is ready to run jacoco
+
+ val testExecutionError = TestProcessor(project, projectSDKPath).createXmlFromJacoco(
+ it.nameWithoutExtension,
+ "$targetDirectory${File.separator}jacoco-${it.nameWithoutExtension}",
+ testcaseName,
+ classPath,
+ packageList.joinToString("."),
+ out,
+ projectContext,
+ testCompiler,
+ )
+ // Saving exception (if exists) thrown during the test execution
+ saveException(testcaseName, targetDirectory, testExecutionError)
+ }
+ }
+ }
+
+ private fun saveException(
+ testcaseName: String,
+ targetDirectory: String,
+ testExecutionError: String,
+ ) {
+ if (testExecutionError.isBlank() || !testExecutionError.contains("Exception", ignoreCase = false)) {
+ return
+ }
+ val targetPath = Paths.get(targetDirectory, "$testcaseName-exception.log")
+
+ // Save the exception
+ targetPath.toFile().writeText(testExecutionError.replace("\tat ", "\nat "))
+ }
+
+ private fun detectPsiClass(classes: Array, classUnderTestName: String): PsiClass? {
+ for (psiClass in classes) {
+ if (psiClass.qualifiedName == classUnderTestName) {
+ return psiClass
+ }
+ }
+ return null
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkBundle.kt
deleted file mode 100644
index 14da7445e..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkBundle.kt
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.jetbrains.research.testspark.bundles
-
-import com.intellij.DynamicBundle
-import org.jetbrains.annotations.Nls
-import org.jetbrains.annotations.PropertyKey
-
-const val BUNDLE = "messages.TestSpark"
-
-/**
- * Loads the EvoSuite messages from `messages/TestSpark.properties` file in the `resources` directory.
- */
-object TestSparkBundle : DynamicBundle(BUNDLE) {
-
- /**
- * Gets the requested message.
- */
- @Nls
- fun message(@PropertyKey(resourceBundle = BUNDLE) key: String, vararg params: Any): String =
- getMessage(key, *params)
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkDefaultsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkDefaultsBundle.kt
deleted file mode 100644
index bbece29e3..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkDefaultsBundle.kt
+++ /dev/null
@@ -1,19 +0,0 @@
-package org.jetbrains.research.testspark.bundles
-
-import com.intellij.DynamicBundle
-import org.jetbrains.annotations.Nls
-import org.jetbrains.annotations.PropertyKey
-
-const val DEFAULTS_BUNDLE = "defaults.TestSpark"
-
-/**
- * Loads the default values from `defaults/TestSpark.properties` file in the `resources` directory.
- */
-object TestSparkDefaultsBundle : DynamicBundle(DEFAULTS_BUNDLE) {
-
- /**
- * Gets the requested default value.
- */
- @Nls
- fun defaultValue(@PropertyKey(resourceBundle = DEFAULTS_BUNDLE) key: String): String = getMessage(key)
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkLabelsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkLabelsBundle.kt
deleted file mode 100644
index 14778815a..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkLabelsBundle.kt
+++ /dev/null
@@ -1,19 +0,0 @@
-package org.jetbrains.research.testspark.bundles
-
-import com.intellij.DynamicBundle
-import org.jetbrains.annotations.Nls
-import org.jetbrains.annotations.PropertyKey
-
-const val LABELS_BUNDLE = "defaults.Labels"
-
-/**
- * Loads the label texts from `defaults/Labels.properties` file in the `recourses` directory.
- */
-object TestSparkLabelsBundle : DynamicBundle(LABELS_BUNDLE) {
-
- /**
- * Gets the requested default value.
- */
- @Nls
- fun defaultValue(@PropertyKey(resourceBundle = LABELS_BUNDLE) key: String): String = getMessage(key)
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkToolTipsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkToolTipsBundle.kt
deleted file mode 100644
index 22ba81156..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/bundles/TestSparkToolTipsBundle.kt
+++ /dev/null
@@ -1,19 +0,0 @@
-package org.jetbrains.research.testspark.bundles
-
-import com.intellij.DynamicBundle
-import org.jetbrains.annotations.Nls
-import org.jetbrains.annotations.PropertyKey
-
-const val TOOLTIPS_BUNDLE = "defaults.Tooltips"
-
-/**
- * Loads the tooltip texts from `defaults/Tooltips.properties` file in the `resources` directory.
- */
-object TestSparkToolTipsBundle : DynamicBundle(TOOLTIPS_BUNDLE) {
-
- /**
- * Gets the requested default value.
- */
- @Nls
- fun defaultValue(@PropertyKey(resourceBundle = TOOLTIPS_BUNDLE) key: String): String = getMessage(key)
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteBundlePaths.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteBundlePaths.kt
new file mode 100644
index 000000000..a9365de69
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteBundlePaths.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.research.testspark.bundles.evosuite
+
+object EvoSuiteBundlePaths {
+ const val defaults: String = "properties.evosuite.EvoSuiteDefaults"
+ const val messages: String = "properties.evosuite.EvoSuiteMessages"
+ const val labels: String = "properties.evosuite.EvoSuiteLabels"
+ const val settings: String = "properties.evosuite.EvoSuiteSettings"
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteDefaultsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteDefaultsBundle.kt
new file mode 100644
index 000000000..f32b79165
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteDefaultsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.evosuite
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object EvoSuiteDefaultsBundle : DynamicBundle(EvoSuiteBundlePaths.defaults) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = EvoSuiteBundlePaths.defaults) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteLabelsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteLabelsBundle.kt
new file mode 100644
index 000000000..3758ab5a2
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteLabelsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.evosuite
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `recourses` directory.
+ */
+object EvoSuiteLabelsBundle : DynamicBundle(EvoSuiteBundlePaths.labels) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = EvoSuiteBundlePaths.labels) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteMessagesBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteMessagesBundle.kt
new file mode 100644
index 000000000..da606470b
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteMessagesBundle.kt
@@ -0,0 +1,18 @@
+package org.jetbrains.research.testspark.bundles.evosuite
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object EvoSuiteMessagesBundle : DynamicBundle(EvoSuiteBundlePaths.messages) {
+
+ /**
+ * Gets the requested message.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = EvoSuiteBundlePaths.messages) key: String, vararg params: Any): String =
+ getMessage(key, *params)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteSettingsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteSettingsBundle.kt
new file mode 100644
index 000000000..5cb0acb1e
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/evosuite/EvoSuiteSettingsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.evosuite
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object EvoSuiteSettingsBundle : DynamicBundle(EvoSuiteBundlePaths.settings) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = EvoSuiteBundlePaths.settings) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMBundlePaths.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMBundlePaths.kt
new file mode 100644
index 000000000..98037f681
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMBundlePaths.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.research.testspark.bundles.llm
+
+object LLMBundlePaths {
+ const val defaults: String = "properties.llm.LLMDefaults"
+ const val messages: String = "properties.llm.LLMMessages"
+ const val labels: String = "properties.llm.LLMLabels"
+ const val settings: String = "properties.llm.LLMSettings"
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMDefaultsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMDefaultsBundle.kt
new file mode 100644
index 000000000..f016e1bc8
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMDefaultsBundle.kt
@@ -0,0 +1,21 @@
+package org.jetbrains.research.testspark.bundles.llm
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object LLMDefaultsBundle : DynamicBundle(LLMBundlePaths.defaults) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = LLMBundlePaths.defaults) key: String): String = getMessage(key)
+ // In Intellij Platform version 2, the DynamicBundle returns the whole path and the value at the end in plugin verification.
+ // Each is separated by "|" (e.g., "|b|properties.llm.LLMDefaults|k|maxLLMRequest|3")
+ // if we do not split them here, the process will throw java.lang.NumberFormatException
+ .split("|").last()
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMLabelsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMLabelsBundle.kt
new file mode 100644
index 000000000..10c979290
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMLabelsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.llm
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `recourses` directory.
+ */
+object LLMLabelsBundle : DynamicBundle(LLMBundlePaths.labels) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = LLMBundlePaths.labels) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMMessagesBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMMessagesBundle.kt
new file mode 100644
index 000000000..aa3ad988a
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMMessagesBundle.kt
@@ -0,0 +1,18 @@
+package org.jetbrains.research.testspark.bundles.llm
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object LLMMessagesBundle : DynamicBundle(LLMBundlePaths.messages) {
+
+ /**
+ * Gets the requested message.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = LLMBundlePaths.messages) key: String, vararg params: Any): String =
+ getMessage(key, *params)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMSettingsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMSettingsBundle.kt
new file mode 100644
index 000000000..3ae64642c
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/llm/LLMSettingsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.llm
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object LLMSettingsBundle : DynamicBundle(LLMBundlePaths.settings) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = LLMBundlePaths.settings) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginBundlePaths.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginBundlePaths.kt
new file mode 100644
index 000000000..e5c3ecfcd
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginBundlePaths.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.research.testspark.bundles.plugin
+
+object PluginBundlePaths {
+ const val defaults: String = "properties.plugin.PluginDefaults"
+ const val messages: String = "properties.plugin.PluginMessages"
+ const val labels: String = "properties.plugin.PluginLabels"
+ const val settings: String = "properties.plugin.PluginSettings"
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginDefaultsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginDefaultsBundle.kt
new file mode 100644
index 000000000..ed8ac286d
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginDefaultsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.plugin
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object PluginDefaultsBundle : DynamicBundle(PluginBundlePaths.defaults) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = PluginBundlePaths.defaults) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginLabelsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginLabelsBundle.kt
new file mode 100644
index 000000000..56baf1a40
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginLabelsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.plugin
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `recourses` directory.
+ */
+object PluginLabelsBundle : DynamicBundle(PluginBundlePaths.labels) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = PluginBundlePaths.labels) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginMessagesBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginMessagesBundle.kt
new file mode 100644
index 000000000..1e31d4314
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginMessagesBundle.kt
@@ -0,0 +1,18 @@
+package org.jetbrains.research.testspark.bundles.plugin
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object PluginMessagesBundle : DynamicBundle(PluginBundlePaths.messages) {
+
+ /**
+ * Gets the requested message.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = PluginBundlePaths.messages) key: String, vararg params: Any): String =
+ getMessage(key, *params)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginSettingsBundle.kt b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginSettingsBundle.kt
new file mode 100644
index 000000000..a5745f0d4
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/bundles/plugin/PluginSettingsBundle.kt
@@ -0,0 +1,17 @@
+package org.jetbrains.research.testspark.bundles.plugin
+
+import com.intellij.DynamicBundle
+import org.jetbrains.annotations.Nls
+import org.jetbrains.annotations.PropertyKey
+
+/**
+ * Loads the `resources` directory.
+ */
+object PluginSettingsBundle : DynamicBundle(PluginBundlePaths.settings) {
+
+ /**
+ * Gets the requested default value.
+ */
+ @Nls
+ fun get(@PropertyKey(resourceBundle = PluginBundlePaths.settings) key: String): String = getMessage(key)
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt
index 0cf79dddb..3c289bb11 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt
@@ -1,5 +1,7 @@
package org.jetbrains.research.testspark.data
+import org.jetbrains.research.testspark.core.test.data.CodeType
+
/**
* Data about test objects that require test generators.
*/
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/JsonEncoding.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/JsonEncoding.kt
deleted file mode 100644
index 45aaa6e9a..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/data/JsonEncoding.kt
+++ /dev/null
@@ -1,46 +0,0 @@
-package org.jetbrains.research.testspark.data
-
-import kotlinx.serialization.builtins.ListSerializer
-import kotlinx.serialization.builtins.serializer
-import kotlinx.serialization.json.Json
-
-/**
- * JsonEncoding is required to work with an array of data in the form of strings.
- * You need to be able to concatenate this data into a single string to save it in SettingsApplicationState,
- * and convert it back to a list of strings correctly.
- */
-class JsonEncoding {
- companion object {
-
- /**
- * Decode a string into a list of strings
- */
- fun decode(jsonString: String): MutableList =
- Json.decodeFromString(ListSerializer(String.serializer()), jsonString) as MutableList
-
- /**
- * Encode a list of strings into a string
- */
- fun encode(values: MutableList): String {
- var jsonString = Json.encodeToString(
- ListSerializer(String.serializer()),
- values,
- )
- // These characters are incorrectly stored in json, so the following substitutions are required
- val replacements = mapOf(
- "\\n" to "\n",
- "\\t" to "\t",
- "\\r" to "\r",
- "\\\\" to "\\",
- "\\\"" to "\"",
- "\\'" to "\'",
- "\\b" to "\b",
- "\\f" to "\u000c",
- )
- replacements.forEach { (key, value) ->
- jsonString = jsonString.replace(key, value)
- }
- return jsonString
- }
- }
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/ProjectContext.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/ProjectContext.kt
index 1b8c1d7ef..ea34cfc32 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/data/ProjectContext.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/data/ProjectContext.kt
@@ -1,7 +1,6 @@
package org.jetbrains.research.testspark.data
import com.intellij.openapi.module.Module
-import com.intellij.psi.PsiClass
data class ProjectContext(
// The class path of the project.
@@ -10,9 +9,6 @@ data class ProjectContext(
// The URL of the file being tested.
var fileUrlAsString: String? = null,
- // The PsiClass of the class under test
- var cutPsiClass: PsiClass? = null,
-
// The full qualified name of the class under test
var classFQN: String? = null,
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt
index 255f4a692..1697fd66c 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt
@@ -2,9 +2,12 @@ package org.jetbrains.research.testspark.data
import org.jetbrains.research.testspark.core.data.TestGenerationData
import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager
+import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
+import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
data class UIContext(
val projectContext: ProjectContext,
val testGenerationOutput: TestGenerationData,
var requestManager: RequestManager? = null,
+ val errorMonitor: ErrorMonitor = DefaultErrorMonitor(),
)
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/ContentDigestAlgorithm.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/evosuite/ContentDigestAlgorithm.kt
similarity index 88%
rename from src/main/kotlin/org/jetbrains/research/testspark/data/ContentDigestAlgorithm.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/data/evosuite/ContentDigestAlgorithm.kt
index 309597e8e..5b7af3896 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/data/ContentDigestAlgorithm.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/data/evosuite/ContentDigestAlgorithm.kt
@@ -1,4 +1,4 @@
-package org.jetbrains.research.testspark.data
+package org.jetbrains.research.testspark.data.evosuite
enum class ContentDigestAlgorithm {
// random
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/llm/JsonEncoding.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/llm/JsonEncoding.kt
new file mode 100644
index 000000000..50019a6bf
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/data/llm/JsonEncoding.kt
@@ -0,0 +1,29 @@
+package org.jetbrains.research.testspark.data.llm
+
+import kotlinx.serialization.builtins.ListSerializer
+import kotlinx.serialization.builtins.serializer
+import kotlinx.serialization.json.Json
+
+/**
+ * JsonEncoding is required to work with an array of data in the form of strings.
+ * You need to be able to concatenate this data into a single string to save it in SettingsApplicationState,
+ * and convert it back to a list of strings correctly.
+ */
+class JsonEncoding {
+ companion object {
+
+ /**
+ * Decode a string into a list of strings
+ */
+ fun decode(jsonString: String): MutableList =
+ Json.decodeFromString(ListSerializer(String.serializer()), jsonString) as MutableList
+
+ /**
+ * Encode a list of strings into a string
+ */
+ fun encode(values: MutableList): String = Json.encodeToString(
+ ListSerializer(String.serializer()),
+ values,
+ )
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/PromptEditorType.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/llm/PromptEditorType.kt
similarity index 71%
rename from src/main/kotlin/org/jetbrains/research/testspark/settings/llm/PromptEditorType.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/data/llm/PromptEditorType.kt
index e5b239c3f..b516ef2a1 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/PromptEditorType.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/data/llm/PromptEditorType.kt
@@ -1,4 +1,4 @@
-package org.jetbrains.research.testspark.settings.llm
+package org.jetbrains.research.testspark.data.llm
enum class PromptEditorType(val text: String, val index: Int) {
CLASS("Class", 0),
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/IconButtonCreator.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/IconButtonCreator.kt
deleted file mode 100644
index 617ed8d5a..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/display/IconButtonCreator.kt
+++ /dev/null
@@ -1,22 +0,0 @@
-package org.jetbrains.research.testspark.display
-
-import java.awt.Dimension
-import javax.swing.Icon
-import javax.swing.JButton
-
-/**
- * Creates a button with the specified icon.
- *
- * @param icon the icon to be displayed on the button
- * @return the created button
- */
-fun createButton(icon: Icon, tip: String): JButton {
- val button = JButton(icon)
- button.isOpaque = false
- button.isContentAreaFilled = false
- button.isBorderPainted = false
- button.toolTipText = tip
- val size = button.preferredSize.height
- button.preferredSize = Dimension(size, size)
- return button
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/ModifiedLinesGetter.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/ModifiedLinesGetter.kt
deleted file mode 100644
index e077923c4..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/display/ModifiedLinesGetter.kt
+++ /dev/null
@@ -1,48 +0,0 @@
-package org.jetbrains.research.testspark.display
-
-/**
- * Returns the indexes of lines that are modified between two lists of strings.
- *
- * @param source The source list of strings.
- * @param target The target list of strings.
- * @return The indexes of modified lines.
- */
-fun getModifiedLines(source: List, target: List): List {
- val dp = Array(source.size + 1) { IntArray(target.size + 1) }
-
- for (i in 1..source.size) {
- for (j in 1..target.size) {
- if (source[i - 1] == target[j - 1]) {
- dp[i][j] = dp[i - 1][j - 1] + 1
- } else {
- dp[i][j] = maxOf(dp[i - 1][j], dp[i][j - 1])
- }
- }
- }
-
- var i = source.size
- var j = target.size
-
- val modifiedLineIndexes = mutableListOf()
-
- while (i > 0 && j > 0) {
- if (source[i - 1] == target[j - 1]) {
- i--
- j--
- } else if (dp[i][j] == dp[i - 1][j]) {
- i--
- } else if (dp[i][j] == dp[i][j - 1]) {
- modifiedLineIndexes.add(j - 1)
- j--
- }
- }
-
- while (j > 0) {
- modifiedLineIndexes.add(j - 1)
- j--
- }
-
- modifiedLineIndexes.reverse()
-
- return modifiedLineIndexes
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestSparkDisplayManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestSparkDisplayManager.kt
new file mode 100644
index 000000000..710e98f79
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestSparkDisplayManager.kt
@@ -0,0 +1,102 @@
+package org.jetbrains.research.testspark.display
+
+import com.intellij.openapi.editor.Editor
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.wm.ToolWindow
+import com.intellij.openapi.wm.ToolWindowManager
+import com.intellij.serviceContainer.AlreadyDisposedException
+import com.intellij.ui.content.ContentManager
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
+import org.jetbrains.research.testspark.core.data.Report
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.data.UIContext
+import org.jetbrains.research.testspark.display.coverage.CoverageVisualisationTabBuilder
+import org.jetbrains.research.testspark.display.generatedTests.GeneratedTestsTabBuilder
+import org.jetbrains.research.testspark.tools.TestsExecutionResultManager
+import java.awt.Component
+import javax.swing.JOptionPane
+
+/**
+ * The TestSparkDisplayBuilder class is responsible for displaying the generated test cases and related information in the TestSpark tool window.
+ * It provides methods to fill the panel with the generated test cases, remove all previously shown test cases, and clear the panel.
+ * The TestSparkDisplayBuilder class uses the CoverageVisualisationTabBuilder and GeneratedTestsTabBuilder classes to build and show the tabs containing the coverage visualisation and generated test data.
+ *
+ * @property toolWindow The ToolWindow object representing the TestSpark tool window.
+ * @property contentManager The ContentManager object responsible for managing the contents of the tool window.
+ * @property editor The Editor object used to display and edit the code.
+ * @property coverageVisualisationTabBuilder The CoverageVisualisationTabBuilder object used to build and show the coverage visualisation tab.
+ * @property generatedTestsTabBuilder The GeneratedTestsTabBuilder object used to build and show the generated tests tab.
+ */
+class TestSparkDisplayManager {
+ private var toolWindow: ToolWindow? = null
+
+ private var contentManager: ContentManager? = null
+
+ private var editor: Editor? = null
+
+ private var coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder? = null
+ private var generatedTestsTabBuilder: GeneratedTestsTabBuilder? = null
+
+ /**
+ * Fill the panel with the generated test cases.
+ */
+ fun display(report: Report, editor: Editor, uiContext: UIContext, language: SupportedLanguage, project: Project, testsExecutionResultManager: TestsExecutionResultManager) {
+ this.toolWindow = ToolWindowManager.getInstance(project).getToolWindow("TestSpark")
+ this.contentManager = toolWindow!!.contentManager
+
+ this.editor = editor
+
+ coverageVisualisationTabBuilder = CoverageVisualisationTabBuilder(project, editor)
+ generatedTestsTabBuilder = GeneratedTestsTabBuilder(project, report, editor, uiContext, coverageVisualisationTabBuilder!!, testsExecutionResultManager)
+
+ generatedTestsTabBuilder!!.show(contentManager!!, language)
+ coverageVisualisationTabBuilder!!.show(report, generatedTestsTabBuilder!!.generatedTestsTabData())
+
+ toolWindow!!.show()
+
+ // removing all tests
+ generatedTestsTabBuilder!!.getRemoveAllButton().addActionListener {
+ // in case of empty list -- just call clear method
+ if (generatedTestsTabBuilder!!.generatedTestsTabData().testCaseNameToPanel.isEmpty()) {
+ clear()
+ return@addActionListener
+ }
+
+ val parentComponent: Component? = null
+ val choice = JOptionPane.showConfirmDialog(
+ parentComponent,
+ PluginMessagesBundle.get("removeAllMessage"),
+ PluginMessagesBundle.get("confirmationTitle"),
+ JOptionPane.OK_CANCEL_OPTION,
+ JOptionPane.WARNING_MESSAGE,
+ )
+
+ if (choice == JOptionPane.OK_OPTION) {
+ clear()
+ }
+ }
+
+ generatedTestsTabBuilder!!.getApplyButton().addActionListener {
+ if (generatedTestsTabBuilder!!.applyTests()) clear()
+ }
+ }
+
+ fun clear() {
+ editor?.markupModel?.removeAllHighlighters()
+
+ coverageVisualisationTabBuilder?.clear()
+ generatedTestsTabBuilder?.clear()
+
+ if (contentManager != null) {
+ for (content in contentManager!!.contents) {
+ if (content.tabName != PluginLabelsBundle.get("descriptionWindow")) {
+ contentManager?.removeContent(content, true)
+ }
+ }
+ }
+ try {
+ toolWindow?.hide()
+ } catch (_: AlreadyDisposedException) {} // Make sure the process continues if the tool window is already closed
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt
deleted file mode 100644
index 4cfea2d75..000000000
--- a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt
+++ /dev/null
@@ -1,196 +0,0 @@
-package org.jetbrains.research.testspark.display
-
-import com.intellij.openapi.components.service
-import com.intellij.openapi.progress.ProgressIndicator
-import com.intellij.openapi.progress.ProgressManager
-import com.intellij.openapi.progress.Task
-import com.intellij.openapi.project.Project
-import org.jetbrains.research.testspark.bundles.TestSparkBundle
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
-import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
-import org.jetbrains.research.testspark.services.TestCaseDisplayService
-import java.awt.Dimension
-import java.util.LinkedList
-import java.util.Queue
-import javax.swing.Box
-import javax.swing.BoxLayout
-import javax.swing.JButton
-import javax.swing.JCheckBox
-import javax.swing.JLabel
-import javax.swing.JOptionPane
-import javax.swing.JPanel
-
-class TopButtonsPanelFactory(private val project: Project) {
- private var runAllButton: JButton = createRunAllTestButton()
- private var selectAllButton: JButton =
- createButton(TestSparkIcons.selectAll, TestSparkLabelsBundle.defaultValue("selectAllTip"))
- private var unselectAllButton: JButton =
- createButton(TestSparkIcons.unselectAll, TestSparkLabelsBundle.defaultValue("unselectAllTip"))
- private var removeAllButton: JButton =
- createButton(TestSparkIcons.removeAll, TestSparkLabelsBundle.defaultValue("removeAllTip"))
-
- private var testsSelectedText: String = "${TestSparkLabelsBundle.defaultValue("testsSelected")}: %d/%d"
- private var testsSelectedLabel: JLabel = JLabel(testsSelectedText)
-
- private val testsPassedText: String = "${TestSparkLabelsBundle.defaultValue("testsPassed")}: %d/%d"
- private var testsPassedLabel: JLabel = JLabel(testsPassedText)
-
- private val testCasePanelFactories = arrayListOf()
-
- fun getPanel(): JPanel {
- val panel = JPanel()
- panel.layout = BoxLayout(panel, BoxLayout.X_AXIS)
- panel.preferredSize = Dimension(0, 30)
- panel.add(Box.createRigidArea(Dimension(10, 0)))
- panel.add(testsPassedLabel)
- panel.add(Box.createRigidArea(Dimension(10, 0)))
- panel.add(testsSelectedLabel)
- panel.add(Box.createHorizontalGlue())
- panel.add(runAllButton)
- panel.add(selectAllButton)
- panel.add(unselectAllButton)
- panel.add(removeAllButton)
-
- selectAllButton.addActionListener { toggleAllCheckboxes(true) }
- unselectAllButton.addActionListener { toggleAllCheckboxes(false) }
- removeAllButton.addActionListener { removeAllTestCases() }
- runAllButton.addActionListener { runAllTestCases() }
-
- return panel
- }
-
- /**
- * Updates the labels.
- */
- fun updateTopLabels() {
- var numberOfPassedTests = 0
- for (testCasePanelFactory in testCasePanelFactories) {
- if (testCasePanelFactory.isRemoved()) continue
- val error = testCasePanelFactory.getError()
- if ((error is String) && error.isEmpty()) {
- numberOfPassedTests++
- }
- }
- testsSelectedLabel.text = String.format(
- testsSelectedText,
- project.service().getTestsSelected(),
- project.service().getTestCasePanels().size,
- )
- testsPassedLabel.text =
- String.format(
- testsPassedText,
- numberOfPassedTests,
- project.service().getTestCasePanels().size,
- )
- runAllButton.isEnabled = false
- for (testCasePanelFactory in testCasePanelFactories) {
- runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled()
- }
- }
-
- /**
- * Sets the array of TestCasePanelFactory objects.
- *
- * @param testCasePanelFactories The ArrayList containing the TestCasePanelFactory objects to be set.
- */
- fun setTestCasePanelFactoriesArray(testCasePanelFactories: ArrayList) {
- this.testCasePanelFactories.addAll(testCasePanelFactories)
- }
-
- /**
- * Toggles check boxes so that they are either all selected or all not selected,
- * depending on the provided parameter.
- *
- * @param selected whether the checkboxes have to be selected or not
- */
- private fun toggleAllCheckboxes(selected: Boolean) {
- project.service().getTestCasePanels().forEach { (_, jPanel) ->
- val checkBox = jPanel.getComponent(0) as JCheckBox
- checkBox.isSelected = selected
- }
- project.service()
- .setTestsSelected(if (selected) project.service().getTestCasePanels().size else 0)
- }
-
- /**
- * Removes all test cases from the cache and tool window UI.
- */
- private fun removeAllTestCases() {
- // Ask the user for the confirmation
- val choice = JOptionPane.showConfirmDialog(
- null,
- TestSparkBundle.message("removeAllMessage"),
- TestSparkBundle.message("confirmationTitle"),
- JOptionPane.YES_NO_OPTION,
- JOptionPane.QUESTION_MESSAGE,
- )
-
- // Cancel the operation if the user did not press "Yes"
- if (choice == JOptionPane.NO_OPTION) return
-
- project.service().clear()
- }
-
- /**
- * Executes all test cases.
- *
- * This method presents a caution message to the user and asks for confirmation before executing the test cases.
- * If the user confirms, it iterates through each test case panel factory and runs the corresponding test.
- */
- private fun runAllTestCases() {
- val choice = JOptionPane.showConfirmDialog(
- null,
- TestSparkBundle.message("runCautionMessage"),
- TestSparkBundle.message("confirmationTitle"),
- JOptionPane.OK_CANCEL_OPTION,
- JOptionPane.WARNING_MESSAGE,
- )
-
- if (choice == JOptionPane.CANCEL_OPTION) return
-
- runAllButton.isEnabled = false
-
- // add each test generation task to queue
- val tasks: Queue<(CustomProgressIndicator) -> Unit> = LinkedList()
-
- for (testCasePanelFactory in testCasePanelFactories) {
- testCasePanelFactory.addTask(tasks)
- }
- // run tasks one after each other
- executeTasks(tasks)
- }
-
- private fun executeTasks(tasks: Queue<(CustomProgressIndicator) -> Unit>) {
- val nextTask = tasks.poll()
-
- nextTask?.let { task ->
- ProgressManager.getInstance().run(object : Task.Backgroundable(project, "Test execution") {
- override fun run(indicator: ProgressIndicator) {
- task(IJProgressIndicator(indicator))
- }
-
- override fun onFinished() {
- super.onFinished()
- executeTasks(tasks)
- }
- })
- }
- }
-
- /**
- * Creates a JButton for running all tests.
- *
- * @return a JButton for running all tests
- */
- private fun createRunAllTestButton(): JButton {
- val runTestButton = JButton(TestSparkLabelsBundle.defaultValue("runAll"), TestSparkIcons.runTest)
- runTestButton.isOpaque = false
- runTestButton.isContentAreaFilled = false
- runTestButton.isBorderPainted = true
- return runTestButton
- }
-
- fun clear() {
- testCasePanelFactories.clear()
- }
-}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/coverage/CoverageRenderer.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/coverage/CoverageRenderer.kt
similarity index 53%
rename from src/main/kotlin/org/jetbrains/research/testspark/coverage/CoverageRenderer.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/display/coverage/CoverageRenderer.kt
index c9173629a..376b542a9 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/coverage/CoverageRenderer.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/coverage/CoverageRenderer.kt
@@ -1,8 +1,7 @@
-package org.jetbrains.research.testspark.coverage
+package org.jetbrains.research.testspark.display.coverage
import com.intellij.codeInsight.hint.HintManager
import com.intellij.codeInsight.hint.HintManagerImpl
-import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.editor.LogicalPosition
@@ -10,19 +9,26 @@ import com.intellij.openapi.editor.ex.EditorGutterComponentEx
import com.intellij.openapi.editor.markup.ActiveGutterRenderer
import com.intellij.openapi.editor.markup.LineMarkerRendererEx
import com.intellij.openapi.project.Project
+import com.intellij.openapi.wm.ToolWindowManager
+import com.intellij.ui.EditorTextField
import com.intellij.ui.HintHint
+import com.intellij.ui.JBColor
import com.intellij.ui.LightweightHint
import com.intellij.ui.components.ActionLink
import com.intellij.ui.components.JBLabel
import com.intellij.ui.components.JBScrollPane
import com.intellij.util.ui.FormBuilder
-import org.jetbrains.research.testspark.services.SettingsApplicationService
-import org.jetbrains.research.testspark.services.TestCaseDisplayService
+import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle
+import org.jetbrains.research.testspark.display.generatedTests.GeneratedTestsTabData
+import org.jetbrains.research.testspark.services.EvoSuiteSettingsService
+import org.jetbrains.research.testspark.services.PluginSettingsService
+import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState
import java.awt.Color
import java.awt.Dimension
import java.awt.Graphics
import java.awt.Rectangle
import java.awt.event.MouseEvent
+import javax.swing.JPanel
/**
* This class extends the line marker and gutter editor to allow more functionality.
@@ -43,8 +49,13 @@ class CoverageRenderer(
private val notCoveredMutation: List,
private val mapMutantsToTests: HashMap>,
private val project: Project,
+ private val generatedTestsTabData: GeneratedTestsTabData,
) :
ActiveGutterRenderer, LineMarkerRendererEx {
+ private val evoSuiteSettingsState: EvoSuiteSettingsState
+ get() = project.getService(EvoSuiteSettingsService::class.java).state
+
+ private var defaultEditorColor: Color? = null
/**
* Perform the action - show toolTip on mouse click.
@@ -61,13 +72,12 @@ class CoverageRenderer(
for (testName in tests) {
prePanel.addComponent(
ActionLink(testName) {
- highlightInToolwindow(testName)
+ highlightTestCase(testName)
},
)
}
- val state = ApplicationManager.getApplication().getService(SettingsApplicationService::class.java).state
- if (coveredMutation.isNotEmpty() && state.criterionWeakMutation) {
+ if (coveredMutation.isNotEmpty() && evoSuiteSettingsState.criterionWeakMutation) {
prePanel.addComponent(JBLabel(" Killed mutants:"), 10)
for (mutantName in coveredMutation) {
prePanel.addComponent(
@@ -78,7 +88,7 @@ class CoverageRenderer(
}
}
- if (notCoveredMutation.isNotEmpty() && state.criterionWeakMutation) {
+ if (notCoveredMutation.isNotEmpty() && evoSuiteSettingsState.criterionWeakMutation) {
prePanel.addComponent(JBLabel(" Survived mutants:"), 10)
for (mutantName in notCoveredMutation) {
prePanel.addComponent(
@@ -124,25 +134,95 @@ class CoverageRenderer(
}
/**
- * Use Display service's mini-editor highlighter function
+ * Use Display service's mutant highlighter function
+ * @param mutantName name of the mutant whose coverage to visualise
+ * @param map map of mutant operations -> List of names of tests which cover the mutants
+ */
+ private fun highlightMutantsInToolwindow(mutantName: String, map: HashMap>) {
+ highlightCoveredMutants(map.getOrPut(mutantName) { ArrayList() })
+ }
+
+ /**
+ * Highlight the mini-editor in the tool window whose name corresponds with the name of the test provided
*
- * @param name name of the test to highlight
+ * @param name name of the test whose editor should be highlighted
*/
- private fun highlightInToolwindow(name: String) {
- val testCaseDisplayService = project.service()
+ private fun highlightTestCase(name: String) {
+ val myPanel = generatedTestsTabData.testCaseNameToPanel[name] ?: return
+ openToolWindowTab()
+ scrollToPanel(myPanel)
+
+ val editorTextField = generatedTestsTabData.testCaseNameToEditorTextField[name] ?: return
+ val settingsProjectState = project.service().state
+ val highlightColor =
+ JBColor(
+ PluginSettingsBundle.get("colorName"),
+ Color(
+ settingsProjectState.colorRed,
+ settingsProjectState.colorGreen,
+ settingsProjectState.colorBlue,
+ 30,
+ ),
+ )
+ if (editorTextField.background.equals(highlightColor)) return
+ defaultEditorColor = editorTextField.background
+ editorTextField.background = highlightColor
+ returnOriginalEditorBackground(editorTextField)
+ }
- testCaseDisplayService.highlightTestCase(name)
+ /**
+ * Method to open the toolwindow tab with generated tests if not already open.
+ */
+ private fun openToolWindowTab() {
+ val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark")
+ generatedTestsTabData.contentManager = toolWindowManager!!.contentManager
+ if (generatedTestsTabData.content != null) {
+ toolWindowManager.show()
+ toolWindowManager.contentManager.setSelectedContent(generatedTestsTabData.content!!)
+ }
}
/**
- * Use Display service's mutant highlighter function
- * @param mutantName name of the mutant whose coverage to visualise
- * @param map map of mutant operations -> List of names of tests which cover the mutants
+ * Scrolls to the highlighted panel.
+ *
+ * @param myPanel the panel to scroll to
*/
- private fun highlightMutantsInToolwindow(mutantName: String, map: HashMap>) {
- val testCaseDisplayService = project.service()
+ private fun scrollToPanel(myPanel: JPanel) {
+ var sum = 0
+ for (component in generatedTestsTabData.allTestCasePanel.components) {
+ if (component == myPanel) {
+ break
+ }
+ sum += component.height
+ }
+ val scroll = generatedTestsTabData.scrollPane.verticalScrollBar
+
+ // Get the value of the "myPanel" height to enable scrolling to that position.
+ // The current scroll percentage relative to the "myPanel" height is calculated as (sum / generatedTestsTabData.allTestCasePanel.height).
+ // The total scroll height is the sum of the minimum and maximum scroll values (scroll.minimum + scroll.maximum).
+ scroll.value = (scroll.minimum + scroll.maximum) * sum / generatedTestsTabData.allTestCasePanel.height
+ }
- testCaseDisplayService.highlightCoveredMutants(map.getOrPut(mutantName) { ArrayList() })
+ /**
+ * Reset the provided editors color to the default (initial) one after 10 seconds
+ * @param editor the editor whose color to change
+ */
+ private fun returnOriginalEditorBackground(editor: EditorTextField) {
+ Thread {
+ val timeWithHighlightedBackground: Long = 10000
+ Thread.sleep(timeWithHighlightedBackground)
+ editor.background = defaultEditorColor
+ }.start()
+ }
+
+ /**
+ * Highlight a range of editors
+ * @param names list of test names to pass to highlight function
+ */
+ private fun highlightCoveredMutants(names: List) {
+ names.forEach {
+ highlightTestCase(it)
+ }
}
/**
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageVisualisationService.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/coverage/CoverageVisualisationTabBuilder.kt
similarity index 70%
rename from src/main/kotlin/org/jetbrains/research/testspark/services/CoverageVisualisationService.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/display/coverage/CoverageVisualisationTabBuilder.kt
index 3fb019002..7b2c2af69 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageVisualisationService.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/coverage/CoverageVisualisationTabBuilder.kt
@@ -1,6 +1,5 @@
-package org.jetbrains.research.testspark.services
+package org.jetbrains.research.testspark.display.coverage
-import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.editor.markup.HighlighterLayer
@@ -8,26 +7,33 @@ import com.intellij.openapi.editor.markup.TextAttributes
import com.intellij.openapi.project.Project
import com.intellij.openapi.wm.ToolWindowManager
import com.intellij.ui.JBColor
+import com.intellij.ui.ScrollPaneFactory
import com.intellij.ui.content.Content
import com.intellij.ui.content.ContentFactory
import com.intellij.ui.content.ContentManager
+import com.intellij.ui.table.JBTable
import org.evosuite.result.MutationInfo
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
-import org.jetbrains.research.testspark.bundles.TestSparkToolTipsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle
import org.jetbrains.research.testspark.core.data.Report
-import org.jetbrains.research.testspark.coverage.CoverageRenderer
import org.jetbrains.research.testspark.data.IJReport
import org.jetbrains.research.testspark.data.IJTestCase
+import org.jetbrains.research.testspark.display.generatedTests.GeneratedTestsTabData
+import org.jetbrains.research.testspark.services.PluginSettingsService
import java.awt.Color
+import java.awt.Dimension
+import javax.swing.JScrollPane
+import javax.swing.table.AbstractTableModel
import kotlin.math.roundToInt
/**
- * Service used to visualise the coverage and inject data in the toolWindow tab.
- *
- * @param project the project
+ * This class is responsible for building and managing the "Coverage" tab in TestSpark.
+ * It handles the GUI components, their interactions, and the application of test cases.
*/
-@Service(Service.Level.PROJECT)
-class CoverageVisualisationService(private val project: Project) {
+class CoverageVisualisationTabBuilder(
+ private val project: Project,
+ private val editor: Editor,
+) {
// Variable to keep reference to the coverage visualisation content
private var content: Content? = null
@@ -36,6 +42,8 @@ class CoverageVisualisationService(private val project: Project) {
private var currentHighlightedData: HighlightedData? = null
+ private var mainScrollPane: JScrollPane? = null
+
/**
* Represents highlighted data in the editor.
*
@@ -55,32 +63,24 @@ class CoverageVisualisationService(private val project: Project) {
* Clears all highlighters from the list of editors.
*/
fun clear() {
- currentHighlightedData ?: return
- currentHighlightedData!!.editor.markupModel ?: return
- currentHighlightedData!!.editor.markupModel.removeAllHighlighters()
+ currentHighlightedData?.editor?.markupModel?.removeAllHighlighters()
}
- /**
- * Retrieves the current highlighted data.
- *
- * @return The current highlighted data, or null if there is no highlighted data.
- */
- fun getCurrentHighlightedData(): HighlightedData? = currentHighlightedData
-
/**
* Instantiates tab for coverage table and calls function to update coverage.
*
* @param testReport the generated tests summary
*/
- fun showCoverage(testReport: Report) {
+ fun show(testReport: Report, generatedTestsTabData: GeneratedTestsTabData) {
// Show toolWindow statistics
fillToolWindowContents(testReport)
createToolWindowTab()
updateCoverage(
testReport.allCoveredLines,
- testReport.testCaseList.values.stream().map { it.id }.toList().toHashSet(),
+ testReport.testCaseList.values.map { it.id }.toHashSet(),
testReport,
+ generatedTestsTabData,
)
}
@@ -92,24 +92,25 @@ class CoverageVisualisationService(private val project: Project) {
* @param testReport report used for gutter information
* @param selectedTests hash set of selected test names
*/
- fun updateCoverage(
+ private fun updateCoverage(
linesToCover: Set,
selectedTests: HashSet,
testReport: Report,
+ generatedTestsTabData: GeneratedTestsTabData,
) {
currentHighlightedData =
- HighlightedData(linesToCover, selectedTests, testReport, project.service().editor!!)
+ HighlightedData(linesToCover, selectedTests, testReport, editor)
clear()
- val settingsProjectState = project.service().state
+ val settingsProjectState = project.service().state
if (settingsProjectState.showCoverageCheckboxSelected) {
val color = JBColor(
- TestSparkToolTipsBundle.defaultValue("colorName"),
+ PluginSettingsBundle.get("colorName"),
Color(settingsProjectState.colorRed, settingsProjectState.colorGreen, settingsProjectState.colorBlue),
)
val colorForLines = JBColor(
- TestSparkToolTipsBundle.defaultValue("colorName"),
+ PluginSettingsBundle.get("colorName"),
Color(
settingsProjectState.colorRed,
settingsProjectState.colorGreen,
@@ -146,7 +147,7 @@ class CoverageVisualisationService(private val project: Project) {
for (i in linesToCover) {
val line = i - 1
- val hl = project.service().editor!!.markupModel.addLineHighlighter(
+ val hl = editor.markupModel.addLineHighlighter(
line,
HighlighterLayer.ADDITIONAL_SYNTAX,
textAttribute,
@@ -166,6 +167,7 @@ class CoverageVisualisationService(private val project: Project) {
mutationNotCoveredLine,
mapMutantsToTests,
project,
+ generatedTestsTabData,
)
}
}
@@ -186,7 +188,8 @@ class CoverageVisualisationService(private val project: Project) {
}
private fun getCoveredMutants(testReport: Report, selectedTests: HashSet): Map> {
- return testReport.testCaseList.filter { x -> x.value.id in selectedTests }.map { x -> (x.value as IJTestCase).coveredMutants }
+ return testReport.testCaseList.filter { x -> x.value.id in selectedTests }
+ .map { x -> (x.value as IJTestCase).coveredMutants }
.flatten().groupBy { x -> x.lineNo }
}
@@ -221,19 +224,20 @@ class CoverageVisualisationService(private val project: Project) {
}
// Change the values in the table
- val coverageToolWindowDisplayService = project.service()
- coverageToolWindowDisplayService.data[0] = testReport.UUT
- coverageToolWindowDisplayService.data[1] = "$relativeLines% ($coveredLines/$allLines)"
- coverageToolWindowDisplayService.data[2] = "$relativeBranch% ($coveredBranches/$allBranches)"
- coverageToolWindowDisplayService.data[3] = "$relativeMutations% ($coveredMutations/$allMutations)"
+ mainScrollPane = getPanel(
+ arrayListOf(
+ testReport.UUT,
+ "$relativeLines% ($coveredLines/$allLines)",
+ "$relativeBranch% ($coveredBranches/$allBranches)",
+ "$relativeMutations% ($coveredMutations/$allMutations)",
+ ),
+ )
}
/**
* Creates a new toolWindow tab for the coverage visualisation.
*/
private fun createToolWindowTab() {
- val visualisationService = project.service()
-
// Remove coverage visualisation from content manager if necessary
val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark")
contentManager = toolWindowManager!!.contentManager
@@ -244,13 +248,62 @@ class CoverageVisualisationService(private val project: Project) {
// If there is no coverage visualisation tab, make it
val contentFactory: ContentFactory = ContentFactory.getInstance()
content = contentFactory.createContent(
- visualisationService.mainPanel,
- TestSparkLabelsBundle.defaultValue("coverageVisualisation"),
+ mainScrollPane,
+ PluginLabelsBundle.get("coverageVisualisation"),
true,
)
contentManager!!.addContent(content!!)
}
+ private fun getPanel(data: ArrayList): JScrollPane {
+ // Implementation of abstract table model
+ val tableModel = object : AbstractTableModel() {
+ /**
+ * Returns the number of rows.
+ *
+ * @return row count
+ */
+ override fun getRowCount(): Int {
+ return 1
+ }
+
+ /**
+ * Returns the number of columns.
+ *
+ * @return column count
+ */
+ override fun getColumnCount(): Int {
+ return 4
+ }
+
+ /**
+ * Returns the value at index.
+ *
+ * @param rowIndex index of row
+ * @param columnIndex index of column
+ * @return value at row
+ */
+ override fun getValueAt(rowIndex: Int, columnIndex: Int): Any {
+ return data[rowIndex * 4 + columnIndex]
+ }
+ }
+
+ val table = JBTable(tableModel)
+
+ val mainPanel = ScrollPaneFactory.createScrollPane(table)
+
+ val tableColumnModel = table.columnModel
+ tableColumnModel.getColumn(0).headerValue = PluginLabelsBundle.get("unitsUndertest")
+ tableColumnModel.getColumn(1).headerValue = PluginLabelsBundle.get("lineCoverage")
+ tableColumnModel.getColumn(2).headerValue = PluginLabelsBundle.get("branchCoverage")
+ tableColumnModel.getColumn(3).headerValue = PluginLabelsBundle.get("weakMutationCoverage")
+
+ table.columnModel = tableColumnModel
+ table.minimumSize = Dimension(700, 100)
+
+ return mainPanel
+ }
+
/**
* Closes the toolWindow tab for the coverage visualisation
*/
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/IJProgressIndicator.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/custom/IJProgressIndicator.kt
similarity index 88%
rename from src/main/kotlin/org/jetbrains/research/testspark/display/IJProgressIndicator.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/display/custom/IJProgressIndicator.kt
index 3130eb611..813dadeff 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/display/IJProgressIndicator.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/custom/IJProgressIndicator.kt
@@ -1,4 +1,4 @@
-package org.jetbrains.research.testspark.display
+package org.jetbrains.research.testspark.display.custom
import com.intellij.openapi.progress.ProgressIndicator
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
@@ -15,6 +15,7 @@ class IJProgressIndicator(private val indicator: ProgressIndicator) : CustomProg
}
override fun isIndeterminate(): Boolean = indicator.isIndeterminate
+
override fun setFraction(value: Double) {
indicator.fraction = value
}
@@ -23,6 +24,8 @@ class IJProgressIndicator(private val indicator: ProgressIndicator) : CustomProg
override fun isCanceled(): Boolean = indicator.isCanceled
+ override fun isRunning(): Boolean = indicator.isRunning
+
override fun start() {
indicator.start()
}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/JUnitCombobox.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/custom/JUnitCombobox.kt
similarity index 83%
rename from src/main/kotlin/org/jetbrains/research/testspark/display/JUnitCombobox.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/display/custom/JUnitCombobox.kt
index cc218f7c8..48ce7f24c 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/display/JUnitCombobox.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/custom/JUnitCombobox.kt
@@ -1,4 +1,4 @@
-package org.jetbrains.research.testspark.display
+package org.jetbrains.research.testspark.display.custom
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.actionSystem.CommonDataKeys
@@ -7,24 +7,24 @@ import com.intellij.openapi.roots.ModuleRootManager
import com.intellij.openapi.roots.ProjectRootManager
import com.intellij.openapi.ui.ComboBox
import org.jetbrains.research.testspark.core.data.JUnitVersion
-import org.jetbrains.research.testspark.services.SettingsApplicationService
-import org.jetbrains.research.testspark.settings.SettingsApplicationState
+import org.jetbrains.research.testspark.services.LLMSettingsService
+import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
import java.awt.Component
import javax.swing.DefaultListCellRenderer
import javax.swing.JList
class JUnitCombobox(val e: AnActionEvent) : ComboBox(JUnitVersion.entries.toTypedArray()) {
- private val settingsState: SettingsApplicationState
- get() = e.project!!.getService(SettingsApplicationService::class.java).state
+ private val llmSettingsState: LLMSettingsState
+ get() = e.project!!.getService(LLMSettingsService::class.java).state
init {
val detected = findJUnitDependency()
- if (settingsState.junitVersionPriorityCheckBoxSelected && detected.size == 1) {
+ if (llmSettingsState.junitVersionPriorityCheckBoxSelected && detected.size == 1) {
this.selectedItem = detected[0]
} else {
for (junitVersion in JUnitVersion.entries) {
- if (junitVersion == settingsState.junitVersion) {
+ if (junitVersion == llmSettingsState.junitVersion) {
this.selectedItem = junitVersion
}
}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GenerateTestsTabHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GenerateTestsTabHelper.kt
new file mode 100644
index 000000000..859e16c1a
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GenerateTestsTabHelper.kt
@@ -0,0 +1,40 @@
+package org.jetbrains.research.testspark.display.generatedTests
+
+object GenerateTestsTabHelper {
+ /**
+ * A helper method to remove a test case from the cache and from the UI.
+ *
+ * @param testCaseName the name of the test
+ */
+ fun removeTestCase(testCaseName: String, generatedTestsTabData: GeneratedTestsTabData) {
+ // Update the number of selected test cases if necessary
+ if (generatedTestsTabData.testCaseNameToSelectedCheckbox[testCaseName]!!.isSelected) {
+ generatedTestsTabData.testsSelected--
+ }
+
+ // Remove the test panel from the UI
+ generatedTestsTabData.allTestCasePanel.remove(generatedTestsTabData.testCaseNameToPanel[testCaseName])
+
+ // Remove the test panel
+ generatedTestsTabData.testCaseNameToPanel.remove(testCaseName)
+
+ // Remove the selected checkbox
+ generatedTestsTabData.testCaseNameToSelectedCheckbox.remove(testCaseName)
+
+ // Remove the editorTextField
+ generatedTestsTabData.testCaseNameToEditorTextField.remove(testCaseName)
+ }
+
+ /**
+ * Updates the user interface of the tool window.
+ *
+ * This method updates the UI of the tool window tab by calling the updateUI
+ * method of the allTestCasePanel object and the updateTopLabels method
+ * of the topButtonsPanel object. It also checks if there are no more tests remaining
+ * and closes the tool window if that is the case.
+ */
+ fun update(generatedTestsTabData: GeneratedTestsTabData) {
+ generatedTestsTabData.allTestCasePanel.updateUI()
+ generatedTestsTabData.topButtonsPanelBuilder.update(generatedTestsTabData)
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GeneratedTestsTabBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GeneratedTestsTabBuilder.kt
new file mode 100644
index 000000000..bdde60d9b
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GeneratedTestsTabBuilder.kt
@@ -0,0 +1,255 @@
+package org.jetbrains.research.testspark.display.generatedTests
+
+import com.intellij.openapi.editor.Editor
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.wm.ToolWindowManager
+import com.intellij.serviceContainer.AlreadyDisposedException
+import com.intellij.ui.content.ContentFactory
+import com.intellij.ui.content.ContentManager
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.core.data.Report
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.data.UIContext
+import org.jetbrains.research.testspark.display.coverage.CoverageVisualisationTabBuilder
+import org.jetbrains.research.testspark.display.utils.ReportUpdater
+import org.jetbrains.research.testspark.display.utils.java.JavaDisplayUtils
+import org.jetbrains.research.testspark.display.utils.kotlin.KotlinDisplayUtils
+import org.jetbrains.research.testspark.display.utils.template.DisplayUtils
+import org.jetbrains.research.testspark.tools.TestsExecutionResultManager
+import java.awt.BorderLayout
+import java.awt.Dimension
+import javax.swing.Box
+import javax.swing.BoxLayout
+import javax.swing.JCheckBox
+import javax.swing.JPanel
+import javax.swing.JSeparator
+import javax.swing.SwingConstants
+
+/**
+ * This class is responsible for building and managing the "Generated Tests" tab in TestSpark.
+ * It handles the GUI components, their interactions, and the application of test cases.
+ */
+class GeneratedTestsTabBuilder(
+ private val project: Project,
+ private val report: Report,
+ private val editor: Editor,
+ private val uiContext: UIContext,
+ private val coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder,
+ private val testsExecutionResultManager: TestsExecutionResultManager,
+) {
+ private val generatedTestsTabData: GeneratedTestsTabData = GeneratedTestsTabData()
+
+ private var mainPanel: JPanel = JPanel()
+
+ private var displayUtils: DisplayUtils? = null
+
+ fun generatedTestsTabData() = generatedTestsTabData
+
+ fun getRemoveAllButton() = generatedTestsTabData.topButtonsPanelBuilder.getRemoveAllButton()
+
+ fun getApplyButton() = generatedTestsTabData.applyButton
+
+ /**
+ * Displays the generated tests tab in the tool window.
+ * This method initializes necessary components based on the selected language and shows the tab.
+ */
+ fun show(contentManager: ContentManager, language: SupportedLanguage) {
+ generatedTestsTabData.allTestCasePanel.removeAll()
+ generatedTestsTabData.allTestCasePanel.layout =
+ BoxLayout(generatedTestsTabData.allTestCasePanel, BoxLayout.Y_AXIS)
+ generatedTestsTabData.testCaseNameToPanel.clear()
+
+ generatedTestsTabData.contentManager = contentManager
+
+ setDisplayUtils(language)
+
+ fillMainPanel()
+
+ fillAllTestCasePanel(language)
+
+ addSeparator()
+
+ createToolWindowTab()
+ }
+
+ /**
+ * Sets the display utility object based on the supported language.
+ *
+ * @param language The programming language.
+ */
+ private fun setDisplayUtils(language: SupportedLanguage) {
+ displayUtils = when (language) {
+ SupportedLanguage.Java -> {
+ JavaDisplayUtils()
+ }
+
+ SupportedLanguage.Kotlin -> {
+ KotlinDisplayUtils()
+ }
+ }
+ }
+
+ /**
+ * Initializes and fills the main panel with subcomponents.
+ */
+ private fun fillMainPanel() {
+ mainPanel.layout = BorderLayout()
+
+ mainPanel.add(
+ generatedTestsTabData.topButtonsPanelBuilder.getPanel(project, generatedTestsTabData),
+ BorderLayout.NORTH,
+ )
+ mainPanel.add(generatedTestsTabData.scrollPane, BorderLayout.CENTER)
+ mainPanel.add(generatedTestsTabData.applyButton, BorderLayout.SOUTH)
+
+ generatedTestsTabData.applyButton.isOpaque = false
+ generatedTestsTabData.applyButton.isContentAreaFilled = false
+ }
+
+ /**
+ * Initializes and fills the main panel with subcomponents.
+ */
+ private fun fillAllTestCasePanel(language: SupportedLanguage) {
+ // TestCasePanelFactories array
+ val testCasePanelFactories = arrayListOf()
+
+ report.testCaseList.values.forEach {
+ val testCase = it
+ val testCasePanel = JPanel()
+ testCasePanel.layout = BorderLayout()
+
+ // Add a checkbox to select the test
+ val checkbox = JCheckBox()
+ checkbox.isSelected = true
+ checkbox.addItemListener {
+ // Update the number of selected tests
+ generatedTestsTabData.testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false))
+
+ if (checkbox.isSelected) {
+ ReportUpdater.selectTestCase(
+ report,
+ generatedTestsTabData.unselectedTestCases,
+ testCase.id,
+ coverageVisualisationTabBuilder,
+ generatedTestsTabData,
+ )
+ } else {
+ ReportUpdater.unselectTestCase(
+ report,
+ generatedTestsTabData.unselectedTestCases,
+ testCase.id,
+ coverageVisualisationTabBuilder,
+ generatedTestsTabData,
+ )
+ }
+
+ GenerateTestsTabHelper.update(generatedTestsTabData)
+ }
+ testCasePanel.add(checkbox, BorderLayout.WEST)
+
+ val testCasePanelBuilder =
+ TestCasePanelBuilder(
+ project, language, testCase, editor, checkbox, uiContext, report,
+ coverageVisualisationTabBuilder, generatedTestsTabData, testsExecutionResultManager,
+ )
+ testCasePanel.add(testCasePanelBuilder.getUpperPanel(), BorderLayout.NORTH)
+ testCasePanel.add(testCasePanelBuilder.getMiddlePanel(), BorderLayout.CENTER)
+ testCasePanel.add(testCasePanelBuilder.getBottomPanel(), BorderLayout.SOUTH)
+
+ testCasePanelFactories.add(testCasePanelBuilder)
+
+ testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST)
+
+ // Add panel to parent panel
+ testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt())
+ generatedTestsTabData.allTestCasePanel.add(testCasePanel)
+ addSeparator()
+
+ generatedTestsTabData.testCaseNameToPanel[testCase.testName] = testCasePanel
+ generatedTestsTabData.testCaseNameToSelectedCheckbox[testCase.testName] = checkbox
+ generatedTestsTabData.testCaseNameToEditorTextField[testCase.testName] =
+ testCasePanelBuilder.getEditorTextField()
+ }
+ generatedTestsTabData.testsSelected = generatedTestsTabData.testCaseNameToPanel.size
+ generatedTestsTabData.testCasePanelFactories.addAll(testCasePanelFactories)
+ generatedTestsTabData.topButtonsPanelBuilder.update(generatedTestsTabData)
+ }
+
+ /**
+ * Adds a visual separator component to the panel to distinguish sections.
+ */
+ private fun addSeparator() {
+ generatedTestsTabData.allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10)))
+ generatedTestsTabData.allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL))
+ generatedTestsTabData.allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10)))
+ }
+
+ /**
+ * Applies the selected test cases by passing them to the display utility for execution.
+ */
+ fun applyTests(): Boolean {
+ // Filter the selected test cases
+ val selectedTestCasePanels =
+ generatedTestsTabData.testCaseNameToPanel.filter { (it.value.getComponent(0) as JCheckBox).isSelected }
+ val selectedTestCases = selectedTestCasePanels.map { it.key }
+
+ // Get the test case components (source code of the tests)
+ val testCaseComponents = selectedTestCases
+ .map { generatedTestsTabData.testCaseNameToEditorTextField[it]!! }
+ .map { it.document.text }
+
+ val applyingResult = displayUtils!!.applyTests(project, uiContext, testCaseComponents)
+
+ // Remove the selected test cases from the cache and the tool window UI
+ if (applyingResult) clear()
+
+ return applyingResult
+ }
+
+ /**
+ * Creates a new tab in the tool window for displaying the generated tests.
+ */
+ private fun createToolWindowTab() {
+ // Remove generated tests tab from content manager if necessary
+ val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark")
+ generatedTestsTabData.contentManager = toolWindowManager!!.contentManager
+ if (generatedTestsTabData.content != null) {
+ generatedTestsTabData.contentManager!!.removeContent(generatedTestsTabData.content!!, true)
+ }
+
+ // If there is no generated tests tab, make it
+ val contentFactory: ContentFactory = ContentFactory.getInstance()
+ generatedTestsTabData.content = contentFactory.createContent(
+ mainPanel,
+ PluginLabelsBundle.get("generatedTests"),
+ true,
+ )
+ generatedTestsTabData.contentManager!!.addContent(generatedTestsTabData.content!!)
+ generatedTestsTabData.contentManager!!.setSelectedContent(generatedTestsTabData.content!!)
+
+ toolWindowManager.show()
+ }
+
+ /**
+ * Closes the tool window by removing the content and hiding the window.
+ */
+ private fun closeToolWindow() {
+ try {
+ generatedTestsTabData.contentManager?.removeContent(generatedTestsTabData.content!!, true)
+ ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide()
+ coverageVisualisationTabBuilder.closeToolWindowTab()
+ } catch (_: AlreadyDisposedException) {} // Make sure the process continues if the tool window is already closed
+ }
+
+ /**
+ * Clears all the generated test cases from the UI and the internal cache.
+ */
+ fun clear() {
+ generatedTestsTabData.testCaseNameToPanel.toMap()
+ .forEach { GenerateTestsTabHelper.removeTestCase(it.key, generatedTestsTabData) }
+ generatedTestsTabData.testCasePanelFactories.clear()
+ generatedTestsTabData.topButtonsPanelBuilder.clear(generatedTestsTabData)
+
+ closeToolWindow()
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GeneratedTestsTabData.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GeneratedTestsTabData.kt
new file mode 100644
index 000000000..97da647b9
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/GeneratedTestsTabData.kt
@@ -0,0 +1,30 @@
+package org.jetbrains.research.testspark.display.generatedTests
+
+import com.intellij.ui.EditorTextField
+import com.intellij.ui.components.JBScrollPane
+import com.intellij.ui.content.Content
+import com.intellij.ui.content.ContentManager
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.core.data.TestCase
+import javax.swing.JButton
+import javax.swing.JCheckBox
+import javax.swing.JPanel
+
+class GeneratedTestsTabData {
+ val testCaseNameToPanel: HashMap = HashMap()
+ val testCaseNameToSelectedCheckbox: HashMap = HashMap()
+ val testCaseNameToEditorTextField: HashMap = HashMap()
+ var testsSelected: Int = 0
+ val unselectedTestCases: HashMap = HashMap()
+ val testCasePanelFactories: ArrayList = arrayListOf()
+ var allTestCasePanel: JPanel = JPanel()
+ val applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton"))
+ var scrollPane: JBScrollPane = JBScrollPane(
+ allTestCasePanel,
+ JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS,
+ JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER,
+ )
+ var topButtonsPanelBuilder = TopButtonsPanelBuilder()
+ var contentManager: ContentManager? = null
+ var content: Content? = null
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt
similarity index 71%
rename from src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt
rename to src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt
index 04c00b3be..09ad48f01 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt
@@ -1,10 +1,9 @@
-package org.jetbrains.research.testspark.display
+package org.jetbrains.research.testspark.display.generatedTests
import com.intellij.lang.Language
import com.intellij.notification.NotificationGroupManager
import com.intellij.notification.NotificationType
import com.intellij.openapi.command.WriteCommandAction
-import com.intellij.openapi.components.service
import com.intellij.openapi.diff.DiffColors
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.editor.event.DocumentEvent
@@ -15,29 +14,40 @@ import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.ComboBox
+import com.intellij.ui.EditorTextField
import com.intellij.ui.JBColor
import com.intellij.ui.LanguageTextField
import com.intellij.ui.components.JBScrollPane
import com.intellij.util.ui.JBUI
-import org.jetbrains.research.testspark.bundles.TestSparkBundle
-import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
+import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
+import org.jetbrains.research.testspark.core.data.Report
import org.jetbrains.research.testspark.core.data.TestCase
import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
-import org.jetbrains.research.testspark.data.JsonEncoding
import org.jetbrains.research.testspark.data.UIContext
-import org.jetbrains.research.testspark.services.ErrorService
-import org.jetbrains.research.testspark.services.JavaClassBuilderService
-import org.jetbrains.research.testspark.services.ReportLockingService
-import org.jetbrains.research.testspark.services.SettingsApplicationService
-import org.jetbrains.research.testspark.services.TestCaseDisplayService
-import org.jetbrains.research.testspark.services.TestsExecutionResultService
-import org.jetbrains.research.testspark.settings.SettingsApplicationState
-import org.jetbrains.research.testspark.tools.generatedTests.TestProcessor
-import org.jetbrains.research.testspark.tools.isProcessStopped
+import org.jetbrains.research.testspark.data.llm.JsonEncoding
+import org.jetbrains.research.testspark.display.TestCaseDocumentCreator
+import org.jetbrains.research.testspark.display.TestSparkIcons
+import org.jetbrains.research.testspark.display.coverage.CoverageVisualisationTabBuilder
+import org.jetbrains.research.testspark.display.custom.IJProgressIndicator
+import org.jetbrains.research.testspark.display.utils.ErrorMessageManager
+import org.jetbrains.research.testspark.display.utils.IconButtonCreator
+import org.jetbrains.research.testspark.display.utils.ModifiedLinesGetter
+import org.jetbrains.research.testspark.display.utils.ReportUpdater
+import org.jetbrains.research.testspark.helpers.LLMHelper
+import org.jetbrains.research.testspark.services.LLMSettingsService
+import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
+import org.jetbrains.research.testspark.testmanager.TestAnalyzerFactory
+import org.jetbrains.research.testspark.tools.TestProcessor
+import org.jetbrains.research.testspark.tools.TestsExecutionResultManager
+import org.jetbrains.research.testspark.tools.ToolUtils
+import org.jetbrains.research.testspark.tools.factories.TestCompilerFactory
+import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.test.JUnitTestSuitePresenter
-import org.jetbrains.research.testspark.tools.llm.testModificationRequest
import java.awt.Dimension
import java.awt.Toolkit
import java.awt.datatransfer.Clipboard
@@ -55,26 +65,32 @@ import javax.swing.SwingUtilities
import javax.swing.border.Border
import javax.swing.border.MatteBorder
-class TestCasePanelFactory(
+class TestCasePanelBuilder(
private val project: Project,
+ private val language: SupportedLanguage,
private val testCase: TestCase,
editor: Editor,
private val checkbox: JCheckBox,
val uiContext: UIContext?,
+ val report: Report,
+ private val coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder,
+ private val generatedTestsTabData: GeneratedTestsTabData,
+ private val testsExecutionResultManager: TestsExecutionResultManager,
) {
- private val settingsState: SettingsApplicationState
- get() = project.getService(SettingsApplicationService::class.java).state
+ private val llmSettingsState: LLMSettingsState
+ get() = project.getService(LLMSettingsService::class.java).state
private val panel = JPanel()
private val previousButton =
- createButton(TestSparkIcons.previous, TestSparkLabelsBundle.defaultValue("previousRequest"))
+ IconButtonCreator.getButton(TestSparkIcons.previous, PluginLabelsBundle.get("previousRequest"))
private var requestNumber: String = "%d / %d"
private var requestLabel: JLabel = JLabel(requestNumber)
- private val nextButton = createButton(TestSparkIcons.next, TestSparkLabelsBundle.defaultValue("nextRequest"))
+ private val nextButton = IconButtonCreator.getButton(TestSparkIcons.next, PluginLabelsBundle.get("nextRequest"))
private val errorLabel = JLabel(TestSparkIcons.showError)
- private val copyButton = createButton(TestSparkIcons.copy, TestSparkLabelsBundle.defaultValue("copyTip"))
- private val likeButton = createButton(TestSparkIcons.like, TestSparkLabelsBundle.defaultValue("likeTip"))
- private val dislikeButton = createButton(TestSparkIcons.dislike, TestSparkLabelsBundle.defaultValue("dislikeTip"))
+ private val copyButton = IconButtonCreator.getButton(TestSparkIcons.copy, PluginLabelsBundle.get("copyTip"))
+ private val likeButton = IconButtonCreator.getButton(TestSparkIcons.like, PluginLabelsBundle.get("likeTip"))
+ private val dislikeButton =
+ IconButtonCreator.getButton(TestSparkIcons.dislike, PluginLabelsBundle.get("dislikeTip"))
private var allRequestsNumber = 1
private var currentRequestNumber = 1
@@ -87,7 +103,7 @@ class TestCasePanelFactory(
// Add an editor to modify the test source code
private val languageTextField = LanguageTextField(
- Language.findLanguageByID("JAVA"),
+ Language.findLanguageByID(language.languageId),
editor.project,
testCase.testCode,
TestCaseDocumentCreator(
@@ -103,22 +119,23 @@ class TestCasePanelFactory(
)
// Create "Remove" button to remove the test from cache
- private val removeButton = createButton(TestSparkIcons.remove, TestSparkLabelsBundle.defaultValue("removeTip"))
+ private val removeButton =
+ IconButtonCreator.getButton(TestSparkIcons.remove, PluginLabelsBundle.get("removeTip"))
// Create "Reset" button to reset the changes in the source code of the test
- private val resetButton = createButton(TestSparkIcons.reset, TestSparkLabelsBundle.defaultValue("resetTip"))
+ private val resetButton = IconButtonCreator.getButton(TestSparkIcons.reset, PluginLabelsBundle.get("resetTip"))
// Create "Reset" button to reset the changes to last run in the source code of the test
private val resetToLastRunButton =
- createButton(TestSparkIcons.resetToLastRun, TestSparkLabelsBundle.defaultValue("resetToLastRunTip"))
+ IconButtonCreator.getButton(TestSparkIcons.resetToLastRun, PluginLabelsBundle.get("resetToLastRunTip"))
// Create "Run tests" button to remove the test from cache
private val runTestButton = createRunTestButton()
- private val requestJLabel = JLabel(TestSparkLabelsBundle.defaultValue("requestJLabel"))
- private val requestComboBox = ComboBox(arrayOf("") + JsonEncoding.decode(settingsState.defaultLLMRequests))
+ private val requestJLabel = JLabel(PluginLabelsBundle.get("requestJLabel"))
+ private val requestComboBox = ComboBox(arrayOf("") + JsonEncoding.decode(llmSettingsState.defaultLLMRequests))
- private val sendButton = createButton(TestSparkIcons.send, TestSparkLabelsBundle.defaultValue("send"))
+ private val sendButton = IconButtonCreator.getButton(TestSparkIcons.send, PluginLabelsBundle.get("send"))
private val loadingLabel: JLabel = JLabel(TestSparkIcons.loading)
@@ -188,7 +205,7 @@ class TestCasePanelFactory(
val clipboard: Clipboard = Toolkit.getDefaultToolkit().systemClipboard
clipboard.setContents(
StringSelection(
- project.service().getEditor(testCase.testName)!!.document.text,
+ generatedTestsTabData.testCaseNameToEditorTextField[testCase.testName]!!.document.text,
),
null,
)
@@ -196,7 +213,7 @@ class TestCasePanelFactory(
.getNotificationGroup("Test case copied")
.createNotification(
"",
- TestSparkBundle.message("testCaseCopied"),
+ PluginMessagesBundle.get("testCaseCopied"),
NotificationType.INFORMATION,
)
.notify(project)
@@ -275,8 +292,8 @@ class TestCasePanelFactory(
runTestButton.addActionListener {
val choice = JOptionPane.showConfirmDialog(
null,
- TestSparkBundle.message("runCautionMessage"),
- TestSparkBundle.message("confirmationTitle"),
+ PluginMessagesBundle.get("runCautionMessage"),
+ PluginMessagesBundle.get("confirmationTitle"),
JOptionPane.OK_CANCEL_OPTION,
JOptionPane.WARNING_MESSAGE,
)
@@ -289,6 +306,11 @@ class TestCasePanelFactory(
sendButton.addActionListener { sendRequest() }
+ /**
+ * The following code is a workaround for the issue with the ComboBox preferred size.
+ * See: https://github.com/JetBrains-Research/TestSpark/pull/343#discussion_r1781430743
+ */
+ requestComboBox.preferredSize = Dimension(0, 0)
requestComboBox.isEditable = true
return panel
@@ -310,12 +332,12 @@ class TestCasePanelFactory(
* Updates the error label with a new message.
*/
private fun updateErrorLabel() {
- val error = project.service().getCurrentError(testCase.id)
+ val error = testsExecutionResultManager.getCurrentError(testCase.id)
if (error.isBlank()) {
errorLabel.isVisible = false
} else {
errorLabel.isVisible = true
- errorLabel.toolTipText = error
+ errorLabel.toolTipText = ErrorMessageManager.normalize(error)
}
}
@@ -328,7 +350,7 @@ class TestCasePanelFactory(
private fun addLanguageTextFieldListener(languageTextField: LanguageTextField) {
languageTextField.document.addDocumentListener(object : DocumentListener {
override fun documentChanged(event: DocumentEvent) {
- updateUI()
+ update()
}
})
}
@@ -336,7 +358,7 @@ class TestCasePanelFactory(
/**
* Updates the user interface based on the provided code.
*/
- private fun updateUI() {
+ private fun update() {
updateTestCaseInformation()
val lastRunCode = lastRunCodes[currentRequestNumber - 1]
@@ -347,16 +369,16 @@ class TestCasePanelFactory(
val error = getError()
if (error.isNullOrBlank()) {
- project.service().addCurrentPassedTest(testCase.id)
+ testsExecutionResultManager.addCurrentPassedTest(testCase.id)
} else {
- project.service().addCurrentFailedTest(testCase.id, error)
+ testsExecutionResultManager.addCurrentFailedTest(testCase.id, error)
}
updateErrorLabel()
runTestButton.isEnabled = (error == null)
updateBorder()
- val modifiedLineIndexes = getModifiedLines(
+ val modifiedLineIndexes = ModifiedLinesGetter.getLines(
lastRunCode.split("\n"),
testCase.testCode.split("\n"),
)
@@ -380,8 +402,8 @@ class TestCasePanelFactory(
testCase.coveredLines = setOf()
}
- project.service().updateTestCase(testCase)
- project.service().updateUI()
+ ReportUpdater.updateTestCase(report, testCase, coverageVisualisationTabBuilder, generatedTestsTabData)
+ GenerateTestsTabHelper.update(generatedTestsTabData)
}
/**
@@ -392,34 +414,39 @@ class TestCasePanelFactory(
*/
private fun sendRequest() {
loadingLabel.isVisible = true
- enableComponents(false)
+ enableGlobalComponents(false)
+ enableLocalComponents(false)
ProgressManager.getInstance()
- .run(object : Task.Backgroundable(project, TestSparkBundle.message("sendingFeedback")) {
+ .run(object : Task.Backgroundable(project, PluginMessagesBundle.get("sendingFeedback")) {
override fun run(indicator: ProgressIndicator) {
val ijIndicator = IJProgressIndicator(indicator)
- if (isProcessStopped(project, ijIndicator)) {
+ if (ToolUtils.isProcessStopped(uiContext!!.errorMonitor, ijIndicator)) {
finishProcess()
return
}
- val modifiedTest = testModificationRequest(
+ val modifiedTest = LLMHelper.testModificationRequest(
+ language,
initialCodes[currentRequestNumber - 1],
requestComboBox.editor.item.toString(),
ijIndicator,
- uiContext!!.requestManager!!,
+ uiContext.requestManager!!,
project,
uiContext.testGenerationOutput,
+ uiContext.errorMonitor,
)
- if (modifiedTest != null) {
+ if (modifiedTest == null || modifiedTest.testCases.isEmpty()) {
+ LLMErrorManager().warningProcess(LLMMessagesBundle.get("modifyWithLLMError"), project)
+ } else {
modifiedTest.setTestFileName(
getClassWithTestCaseName(testCase.testName),
)
addTest(modifiedTest)
}
- if (isProcessStopped(project, ijIndicator)) {
+ if (ToolUtils.isProcessStopped(uiContext.errorMonitor, ijIndicator)) {
finishProcess()
return
}
@@ -430,13 +457,19 @@ class TestCasePanelFactory(
})
}
- private fun finishProcess() {
- project.service().clear()
+ private fun finishProcess(enableGlobal: Boolean = true) {
+ uiContext!!.errorMonitor.clear()
loadingLabel.isVisible = false
- enableComponents(true)
+ if (enableGlobal) enableGlobalComponents(true)
+ enableLocalComponents(true)
}
- private fun enableComponents(isEnabled: Boolean) {
+ private fun enableGlobalComponents(isEnabled: Boolean) {
+ generatedTestsTabData.topButtonsPanelBuilder.getRemoveAllButton().isEnabled = isEnabled
+ generatedTestsTabData.applyButton.isEnabled = isEnabled
+ }
+
+ private fun enableLocalComponents(isEnabled: Boolean) {
nextButton.isEnabled = isEnabled
previousButton.isEnabled = isEnabled
runTestButton.isEnabled = isEnabled
@@ -447,14 +480,12 @@ class TestCasePanelFactory(
}
private fun addTest(testSuite: TestSuiteGeneratedByLLM) {
- val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput)
+ val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput, language)
WriteCommandAction.runWriteCommandAction(project) {
- project.service().clear()
+ uiContext.errorMonitor.clear()
val code = testSuitePresenter.toString(testSuite)
- testCase.testName =
- project.service()
- .getTestMethodNameFromClassWithTestCase(testCase.testName, code)
+ testCase.testName = TestAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, code)
testCase.testCode = code
// update numbers
@@ -487,12 +518,13 @@ class TestCasePanelFactory(
if (!runTestButton.isEnabled) return
loadingLabel.isVisible = true
- enableComponents(false)
+ enableGlobalComponents(false)
+ enableLocalComponents(false)
ProgressManager.getInstance()
- .run(object : Task.Backgroundable(project, TestSparkBundle.message("sendingFeedback")) {
+ .run(object : Task.Backgroundable(project, PluginMessagesBundle.get("sendingFeedback")) {
override fun run(indicator: ProgressIndicator) {
- runTest(IJProgressIndicator(indicator))
+ runTest(IJProgressIndicator(indicator), true)
}
})
}
@@ -502,25 +534,41 @@ class TestCasePanelFactory(
if (!runTestButton.isEnabled) return
loadingLabel.isVisible = true
- enableComponents(false)
+ enableGlobalComponents(false)
+ enableLocalComponents(false)
tasks.add { indicator ->
- runTest(indicator)
+ runTest(indicator, false)
}
}
- private fun runTest(indicator: CustomProgressIndicator) {
+ fun removeTask() {
+ finishProcess()
+ update()
+ }
+
+ private fun runTest(indicator: CustomProgressIndicator, enableGlobal: Boolean) {
indicator.setText("Executing ${testCase.testName}")
+ val fileName = TestAnalyzerFactory.create(language).getFileNameFromTestCaseCode(testCase.testCode)
+
+ val testCompiler = TestCompilerFactory.create(
+ project,
+ llmSettingsState.junitVersion,
+ language,
+ )
+
val newTestCase = TestProcessor(project)
.processNewTestCase(
- "${project.service().getClassFromTestCaseCode(testCase.testCode)}.java",
+ fileName,
testCase.id,
testCase.testName,
testCase.testCode,
- uiContext!!.testGenerationOutput.packageLine,
+ uiContext!!.testGenerationOutput.packageName,
uiContext.testGenerationOutput.resultPath,
uiContext.projectContext,
+ testCompiler,
+ testsExecutionResultManager,
)
testCase.coveredLines = newTestCase.coveredLines
@@ -530,10 +578,10 @@ class TestCasePanelFactory(
lastRunCodes[currentRequestNumber - 1] = testCase.testCode
SwingUtilities.invokeLater {
- updateUI()
+ update()
}
- finishProcess()
+ finishProcess(enableGlobal)
indicator.stop()
}
@@ -554,7 +602,7 @@ class TestCasePanelFactory(
currentCodes[currentRequestNumber - 1] = testCase.testCode
lastRunCodes[currentRequestNumber - 1] = testCase.testCode
- updateUI()
+ update()
}
}
@@ -566,7 +614,7 @@ class TestCasePanelFactory(
languageTextField.document.setText(lastRunCodes[currentRequestNumber - 1])
currentCodes[currentRequestNumber - 1] = testCase.testCode
- updateUI()
+ update()
}
}
@@ -580,13 +628,14 @@ class TestCasePanelFactory(
*/
private fun remove() {
// Remove the test case from the cache
- project.service().removeTestCase(testCase.testName)
+ GenerateTestsTabHelper.removeTestCase(testCase.testName, generatedTestsTabData)
runTestButton.isEnabled = false
isRemoved = true
- project.service().removeTestCase(testCase)
- project.service().updateUI()
+ ReportUpdater.removeTestCase(report, testCase, coverageVisualisationTabBuilder, generatedTestsTabData)
+
+ GenerateTestsTabHelper.update(generatedTestsTabData)
}
/**
@@ -608,7 +657,7 @@ class TestCasePanelFactory(
*
* @return the error message for the test case
*/
- fun getError() = project.service().getError(testCase.id, testCase.testCode)
+ fun getError() = testsExecutionResultManager.getError(testCase.id, testCase.testCode)
/**
* Returns the border for a given test case.
@@ -630,7 +679,7 @@ class TestCasePanelFactory(
* @return the created button
*/
private fun createRunTestButton(): JButton {
- val runTestButton = JButton(TestSparkLabelsBundle.defaultValue("run"), TestSparkIcons.runTest)
+ val runTestButton = JButton(PluginLabelsBundle.get("run"), TestSparkIcons.runTest)
runTestButton.isOpaque = false
runTestButton.isContentAreaFilled = false
runTestButton.isBorderPainted = true
@@ -644,7 +693,7 @@ class TestCasePanelFactory(
*/
private fun switchToAnotherCode() {
languageTextField.document.setText(currentCodes[currentRequestNumber - 1])
- updateUI()
+ update()
}
/**
@@ -658,9 +707,14 @@ class TestCasePanelFactory(
* Updates the current test case with the specified test name and test code.
*/
private fun updateTestCaseInformation() {
- testCase.testName =
- project.service()
- .getTestMethodNameFromClassWithTestCase(testCase.testName, languageTextField.document.text)
+ testCase.testName = TestAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, languageTextField.document.text)
testCase.testCode = languageTextField.document.text
}
+
+ /**
+ * Retrieves the editor text field from the current UI context.
+ *
+ * @return the editor text field
+ */
+ fun getEditorTextField(): EditorTextField = languageTextField
}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TopButtonsPanelBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TopButtonsPanelBuilder.kt
new file mode 100644
index 000000000..c5eb68b1f
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TopButtonsPanelBuilder.kt
@@ -0,0 +1,186 @@
+package org.jetbrains.research.testspark.display.generatedTests
+
+import com.intellij.openapi.progress.ProgressIndicator
+import com.intellij.openapi.progress.ProgressManager
+import com.intellij.openapi.progress.Task
+import com.intellij.openapi.project.Project
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
+import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+import org.jetbrains.research.testspark.display.TestSparkIcons
+import org.jetbrains.research.testspark.display.custom.IJProgressIndicator
+import org.jetbrains.research.testspark.display.utils.IconButtonCreator
+import java.awt.Dimension
+import java.util.LinkedList
+import java.util.Queue
+import javax.swing.Box
+import javax.swing.BoxLayout
+import javax.swing.JButton
+import javax.swing.JCheckBox
+import javax.swing.JLabel
+import javax.swing.JOptionPane
+import javax.swing.JPanel
+
+class TopButtonsPanelBuilder {
+ private var runAllButton: JButton = createRunAllTestButton()
+ private var selectAllButton: JButton =
+ IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip"))
+ private var unselectAllButton: JButton =
+ IconButtonCreator.getButton(TestSparkIcons.unselectAll, PluginLabelsBundle.get("unselectAllTip"))
+ private var removeAllButton: JButton =
+ IconButtonCreator.getButton(TestSparkIcons.removeAll, PluginLabelsBundle.get("removeAllTip"))
+
+ private var testsSelectedText: String = "${PluginLabelsBundle.get("testsSelected")}: %d/%d"
+ private var testsSelectedLabel: JLabel = JLabel(testsSelectedText)
+
+ private val testsPassedText: String = "${PluginLabelsBundle.get("testsPassed")}: %d/%d"
+ private var testsPassedLabel: JLabel = JLabel(testsPassedText)
+
+ /**
+ * Updates the labels.
+ */
+ fun update(generatedTestsTabData: GeneratedTestsTabData) {
+ val passedTestsCount = generatedTestsTabData.testCasePanelFactories
+ .filter { !it.isRemoved() }
+ .count { it.getError()?.isEmpty() == true }
+
+ val removedTestsCount = generatedTestsTabData.testCasePanelFactories.count { it.isRemoved() }
+
+ if (generatedTestsTabData.testCasePanelFactories.size == removedTestsCount) {
+ removeAllButton.doClick()
+ return
+ }
+ testsSelectedLabel.text = String.format(
+ testsSelectedText,
+ generatedTestsTabData.testsSelected,
+ generatedTestsTabData.testCaseNameToPanel.size,
+ )
+ testsPassedLabel.text =
+ String.format(
+ testsPassedText,
+ passedTestsCount,
+ generatedTestsTabData.testCaseNameToPanel.size,
+ )
+ runAllButton.isEnabled = false
+ for (testCasePanelFactory in generatedTestsTabData.testCasePanelFactories) {
+ runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled()
+ }
+ }
+
+ fun getRemoveAllButton() = removeAllButton
+
+ /**
+ * Toggles check boxes so that they are either all selected or all not selected,
+ * depending on the provided parameter.
+ *
+ * @param selected whether the checkboxes have to be selected or not
+ */
+ private fun toggleAllCheckboxes(selected: Boolean, generatedTestsTabData: GeneratedTestsTabData) {
+ generatedTestsTabData.testCaseNameToPanel.forEach { (_, jPanel) ->
+ val checkBox = jPanel.getComponent(0) as JCheckBox
+ checkBox.isSelected = selected
+ }
+ generatedTestsTabData.testsSelected = if (selected) generatedTestsTabData.testCaseNameToPanel.size else 0
+ }
+
+ /**
+ * Executes all test cases.
+ *
+ * This method presents a caution message to the user and asks for confirmation before executing the test cases.
+ * If the user confirms, it iterates through each test case panel factory and runs the corresponding test.
+ */
+ private fun runAllTestCases(project: Project, generatedTestsTabData: GeneratedTestsTabData) {
+ val choice = JOptionPane.showConfirmDialog(
+ null,
+ PluginMessagesBundle.get("runCautionMessage"),
+ PluginMessagesBundle.get("confirmationTitle"),
+ JOptionPane.OK_CANCEL_OPTION,
+ JOptionPane.WARNING_MESSAGE,
+ )
+
+ if (choice == JOptionPane.CANCEL_OPTION) return
+
+ runAllButton.isEnabled = false
+
+ // add each test generation task to queue
+ val tasks: Queue<(CustomProgressIndicator) -> Unit> = LinkedList()
+
+ for (testCasePanelFactory in generatedTestsTabData.testCasePanelFactories) {
+ testCasePanelFactory.addTask(tasks)
+ }
+ // run tasks one after each other
+ executeTasks(project, tasks, generatedTestsTabData)
+ }
+
+ private fun executeTasks(
+ project: Project,
+ tasks: Queue<(CustomProgressIndicator) -> Unit>,
+ generatedTestsTabData: GeneratedTestsTabData,
+ ) {
+ val nextTask = tasks.poll()
+
+ nextTask?.let { task ->
+ ProgressManager.getInstance().run(object : Task.Backgroundable(project, "Test execution") {
+ var globalIndicator: ProgressIndicator? = null
+
+ override fun run(indicator: ProgressIndicator) {
+ globalIndicator = indicator
+ task(IJProgressIndicator(indicator))
+ }
+
+ override fun onFinished() {
+ super.onFinished()
+ if (globalIndicator != null && !globalIndicator!!.isCanceled) {
+ executeTasks(project, tasks, generatedTestsTabData)
+ } else {
+ if (tasks.isNotEmpty()) {
+ runAllButton.isEnabled = true
+ val firstTestPanelFactoryIndex = generatedTestsTabData.testCasePanelFactories.size - tasks.size - 1
+ val lastTestPanelFactoryIndex = generatedTestsTabData.testCasePanelFactories.size
+ for (index in firstTestPanelFactoryIndex until lastTestPanelFactoryIndex) {
+ generatedTestsTabData.testCasePanelFactories[index].removeTask()
+ }
+ }
+ }
+ }
+ })
+ }
+ if (nextTask == null) {
+ generatedTestsTabData.topButtonsPanelBuilder.getRemoveAllButton().isEnabled = true
+ generatedTestsTabData.applyButton.isEnabled = true
+ }
+ }
+
+ fun getPanel(project: Project, generatedTestsTabData: GeneratedTestsTabData): JPanel {
+ val panel = JPanel()
+ panel.layout = BoxLayout(panel, BoxLayout.X_AXIS)
+ panel.preferredSize = Dimension(0, 30)
+ panel.add(Box.createRigidArea(Dimension(10, 0)))
+ panel.add(testsPassedLabel)
+ panel.add(Box.createRigidArea(Dimension(10, 0)))
+ panel.add(testsSelectedLabel)
+ panel.add(Box.createHorizontalGlue())
+ panel.add(runAllButton)
+ panel.add(selectAllButton)
+ panel.add(unselectAllButton)
+ panel.add(removeAllButton)
+
+ selectAllButton.addActionListener { toggleAllCheckboxes(true, generatedTestsTabData) }
+ unselectAllButton.addActionListener { toggleAllCheckboxes(false, generatedTestsTabData) }
+ runAllButton.addActionListener { runAllTestCases(project, generatedTestsTabData) }
+
+ return panel
+ }
+
+ fun clear(generatedTestsTabData: GeneratedTestsTabData) {
+ generatedTestsTabData.testCasePanelFactories.clear()
+ }
+
+ private fun createRunAllTestButton(): JButton {
+ val runTestButton = JButton(PluginLabelsBundle.get("runAll"), TestSparkIcons.runTest)
+ runTestButton.isOpaque = false
+ runTestButton.isContentAreaFilled = false
+ runTestButton.isBorderPainted = true
+ return runTestButton
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ErrorMessageManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ErrorMessageManager.kt
new file mode 100644
index 000000000..b17cecda6
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ErrorMessageManager.kt
@@ -0,0 +1,58 @@
+package org.jetbrains.research.testspark.display.utils
+
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import javax.swing.JOptionPane
+
+/**
+ * The ErrorMessageNormalizer class is responsible for normalizing error messages by inserting "
" tags after every block size characters.
+ */
+object ErrorMessageManager {
+ const val BLOCK_SIZE = 100
+
+ const val SEPARATOR = "
"
+
+ /**
+ * Normalizes an error message by inserting "
" tags after every block size characters,
+ * except if there is already a "
" tag within the blockSize characters.
+ * If the string length is a multiple of blockSize and the last character is a "
" tag,
+ * it is removed from the result.
+ *
+ * @param error The error message to be normalized.
+ * @return The normalized error message.
+ */
+ fun normalize(error: String): String {
+ // init variables
+ val builder = StringBuilder()
+ var lastIndex = 0
+
+ // string separating
+ while (lastIndex < error.length) {
+ val nextIndex = (lastIndex + BLOCK_SIZE).coerceAtMost(error.length)
+ val substring = error.substring(lastIndex, nextIndex)
+
+ if (!substring.contains(SEPARATOR)) {
+ builder.append(substring).append(SEPARATOR)
+ } else {
+ builder.append(substring)
+ }
+
+ lastIndex = nextIndex
+ }
+
+ // remove the last
if the string length is a multiple of the block size, and it didn't have
+ if (builder.endsWith(SEPARATOR) && (error.length % BLOCK_SIZE == 0)) {
+ builder.deleteRange(builder.length - SEPARATOR.length, builder.length)
+ }
+
+ return builder.toString()
+ }
+
+ fun showErrorWindow(message: String) {
+ JOptionPane.showMessageDialog(
+ null,
+ message,
+ PluginLabelsBundle.get("errorWindowTitle"),
+ JOptionPane.ERROR_MESSAGE,
+ )
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/IconButtonCreator.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/IconButtonCreator.kt
new file mode 100644
index 000000000..fcecbcda5
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/IconButtonCreator.kt
@@ -0,0 +1,24 @@
+package org.jetbrains.research.testspark.display.utils
+
+import java.awt.Dimension
+import javax.swing.Icon
+import javax.swing.JButton
+
+object IconButtonCreator {
+ /**
+ * Creates a button with the specified icon.
+ *
+ * @param icon the icon to be displayed on the button
+ * @return the created button
+ */
+ fun getButton(icon: Icon, tip: String): JButton {
+ val button = JButton(icon)
+ button.isOpaque = false
+ button.isContentAreaFilled = false
+ button.isBorderPainted = false
+ button.toolTipText = tip
+ val size = button.preferredSize.height
+ button.preferredSize = Dimension(size, size)
+ return button
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ModifiedLinesGetter.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ModifiedLinesGetter.kt
new file mode 100644
index 000000000..8d58e54a6
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ModifiedLinesGetter.kt
@@ -0,0 +1,50 @@
+package org.jetbrains.research.testspark.display.utils
+
+object ModifiedLinesGetter {
+ /**
+ * Returns the indexes of lines that are modified between two lists of strings.
+ *
+ * @param source The source list of strings.
+ * @param target The target list of strings.
+ * @return The indexes of modified lines.
+ */
+ fun getLines(source: List, target: List): List {
+ val dp = Array(source.size + 1) { IntArray(target.size + 1) }
+
+ for (i in 1..source.size) {
+ for (j in 1..target.size) {
+ if (source[i - 1] == target[j - 1]) {
+ dp[i][j] = dp[i - 1][j - 1] + 1
+ } else {
+ dp[i][j] = maxOf(dp[i - 1][j], dp[i][j - 1])
+ }
+ }
+ }
+
+ var i = source.size
+ var j = target.size
+
+ val modifiedLineIndexes = mutableListOf()
+
+ while (i > 0 && j > 0) {
+ if (source[i - 1] == target[j - 1]) {
+ i--
+ j--
+ } else if (dp[i][j] == dp[i - 1][j]) {
+ i--
+ } else if (dp[i][j] == dp[i][j - 1]) {
+ modifiedLineIndexes.add(j - 1)
+ j--
+ }
+ }
+
+ while (j > 0) {
+ modifiedLineIndexes.add(j - 1)
+ j--
+ }
+
+ modifiedLineIndexes.reverse()
+
+ return modifiedLineIndexes
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ReportUpdater.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ReportUpdater.kt
new file mode 100644
index 000000000..5eea34f19
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/ReportUpdater.kt
@@ -0,0 +1,55 @@
+package org.jetbrains.research.testspark.display.utils
+
+import org.jetbrains.research.testspark.core.data.Report
+import org.jetbrains.research.testspark.core.data.TestCase
+import org.jetbrains.research.testspark.display.coverage.CoverageVisualisationTabBuilder
+import org.jetbrains.research.testspark.display.generatedTests.GeneratedTestsTabData
+
+object ReportUpdater {
+ fun updateTestCase(
+ report: Report,
+ testCase: TestCase,
+ coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder,
+ generatedTestsTabData: GeneratedTestsTabData,
+ ) {
+ report.testCaseList.remove(testCase.id)
+ report.testCaseList[testCase.id] = testCase
+ report.normalized()
+ coverageVisualisationTabBuilder.show(report, generatedTestsTabData)
+ }
+
+ fun removeTestCase(
+ report: Report,
+ testCase: TestCase,
+ coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder,
+ generatedTestsTabData: GeneratedTestsTabData,
+ ) {
+ report.testCaseList.remove(testCase.id)
+ report.normalized()
+ coverageVisualisationTabBuilder.show(report, generatedTestsTabData)
+ }
+
+ fun unselectTestCase(
+ report: Report,
+ unselectedTestCases: HashMap,
+ testCaseId: Int,
+ coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder,
+ generatedTestsTabData: GeneratedTestsTabData,
+ ) {
+ unselectedTestCases[testCaseId] = report.testCaseList[testCaseId]!!
+ removeTestCase(report, report.testCaseList[testCaseId]!!, coverageVisualisationTabBuilder, generatedTestsTabData)
+ }
+
+ fun selectTestCase(
+ report: Report,
+ unselectedTestCases: HashMap,
+ testCaseId: Int,
+ coverageVisualisationTabBuilder: CoverageVisualisationTabBuilder,
+ generatedTestsTabData: GeneratedTestsTabData,
+ ) {
+ report.testCaseList[testCaseId] = unselectedTestCases[testCaseId]!!
+ unselectedTestCases.remove(testCaseId)
+ report.normalized()
+ coverageVisualisationTabBuilder.show(report, generatedTestsTabData)
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt
new file mode 100644
index 000000000..0320308d7
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt
@@ -0,0 +1,220 @@
+package org.jetbrains.research.testspark.display.utils.java
+
+import com.intellij.openapi.command.WriteCommandAction
+import com.intellij.openapi.fileChooser.FileChooser
+import com.intellij.openapi.fileChooser.FileChooserDescriptor
+import com.intellij.openapi.fileEditor.FileEditorManager
+import com.intellij.openapi.fileEditor.OpenFileDescriptor
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.vfs.LocalFileSystem
+import com.intellij.openapi.vfs.VirtualFile
+import com.intellij.openapi.vfs.VirtualFileManager
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiDocumentManager
+import com.intellij.psi.PsiElementFactory
+import com.intellij.psi.PsiFile
+import com.intellij.psi.PsiJavaFile
+import com.intellij.psi.PsiManager
+import com.intellij.refactoring.suggested.startOffset
+import com.intellij.util.containers.stream
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.data.UIContext
+import org.jetbrains.research.testspark.display.utils.ErrorMessageManager
+import org.jetbrains.research.testspark.display.utils.template.DisplayUtils
+import org.jetbrains.research.testspark.java.JavaPsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.testmanager.java.JavaTestAnalyzer
+import org.jetbrains.research.testspark.testmanager.java.JavaTestGenerator
+import java.io.File
+import java.util.Locale
+import javax.swing.JOptionPane
+
+class JavaDisplayUtils : DisplayUtils {
+ override fun applyTests(project: Project, uiContext: UIContext?, testCaseComponents: List): Boolean {
+ // Descriptor for choosing folders and java files
+ val descriptor = FileChooserDescriptor(true, true, false, false, false, false)
+
+ // Apply filter with folders and java files with main class
+ WriteCommandAction.runWriteCommandAction(project) {
+ descriptor.withFileFilter { file ->
+ file.isDirectory || (
+ file.extension?.lowercase(Locale.getDefault()) == "java" && (
+ PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile
+ ).classes.stream().map { it.name }
+ .toArray()
+ .contains(
+ (
+ PsiManager.getInstance(project)
+ .findFile(file) as PsiJavaFile
+ ).name.removeSuffix(".java"),
+ )
+ )
+ }
+ }
+
+ val fileChooser = FileChooser.chooseFiles(
+ descriptor,
+ project,
+ LocalFileSystem.getInstance().findFileByPath(project.basePath!!),
+ )
+
+ /**
+ * Cancel button pressed
+ */
+ if (fileChooser.isEmpty()) return false
+
+ /**
+ * Chosen files by user
+ */
+ val chosenFile = fileChooser[0]
+
+ /**
+ * Virtual file of a final java file
+ */
+ var virtualFile: VirtualFile? = null
+
+ /**
+ * PsiClass of a final java file
+ */
+ var psiClass: PsiClass? = null
+
+ /**
+ * PsiJavaFile of a final java file
+ */
+ var psiJavaFile: PsiJavaFile? = null
+
+ if (chosenFile.isDirectory) {
+ // Input new file data
+ var className: String
+ var fileName: String
+ var filePath: String
+ // Waiting for correct file name input
+ while (true) {
+ val jOptionPane =
+ JOptionPane.showInputDialog(
+ null,
+ PluginLabelsBundle.get("optionPaneMessage"),
+ PluginLabelsBundle.get("optionPaneTitle"),
+ JOptionPane.PLAIN_MESSAGE,
+ null,
+ null,
+ null,
+ )
+
+ // Cancel button pressed
+ jOptionPane ?: return false
+
+ // Get class name from user
+ className = jOptionPane as String
+
+ // Set file name and file path
+ fileName = "${className.split('.')[0]}.java"
+ filePath = "${chosenFile.path}/$fileName"
+
+ // Check the correctness of a class name
+ if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) {
+ ErrorMessageManager.showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage"))
+ continue
+ }
+
+ // Check the existence of a file with this name
+ if (File(filePath).exists()) {
+ ErrorMessageManager.showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage"))
+ continue
+ }
+ break
+ }
+
+ // Create new file and set services of this file
+ WriteCommandAction.runWriteCommandAction(project) {
+ chosenFile.createChildData(null, fileName)
+ virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!!
+ psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile)
+ psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0])
+
+ if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) {
+ psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext.testGenerationOutput.runWith})")
+ }
+
+ psiJavaFile!!.add(psiClass!!)
+ }
+ } else {
+ // Set services of the chosen file
+ virtualFile = chosenFile
+ psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile)
+ psiClass = psiJavaFile!!.classes[
+ psiJavaFile!!.classes.stream().map { it.name }.toArray()
+ .indexOf(psiJavaFile!!.name.removeSuffix(".java")),
+ ]
+ }
+
+ // Add tests to the file
+ WriteCommandAction.runWriteCommandAction(project) {
+ appendTestsToClass(project, uiContext, testCaseComponents, JavaPsiClassWrapper(psiClass!!), psiJavaFile!!)
+ }
+
+ // Open the file after adding
+ FileEditorManager.getInstance(project).openTextEditor(
+ OpenFileDescriptor(project, virtualFile!!),
+ true,
+ )
+
+ return true
+ }
+
+ override fun appendTestsToClass(
+ project: Project,
+ uiContext: UIContext?,
+ testCaseComponents: List,
+ selectedClass: PsiClassWrapper,
+ outputFile: PsiFile,
+ ) {
+ // block document
+ PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument(
+ PsiDocumentManager.getInstance(project).getDocument(outputFile as PsiJavaFile)!!,
+ )
+
+ // insert tests to a code
+ testCaseComponents.reversed().forEach {
+ val testMethodCode =
+ JavaTestAnalyzer.extractFirstTestMethodCode(
+ JavaTestGenerator.formatCode(
+ project,
+ it.replace("\r\n", "\n")
+ .replace("verifyException(", "// verifyException("),
+ uiContext!!.testGenerationOutput,
+ ),
+ )
+ // Fix Windows line separators
+ .replace("\r\n", "\n")
+
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString(
+ selectedClass.rBrace!!,
+ testMethodCode,
+ )
+ }
+
+ // insert other info to a code
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString(
+ selectedClass.rBrace!!,
+ uiContext!!.testGenerationOutput.otherInfo + "\n",
+ )
+
+ // insert imports to a code
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString(
+ outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0,
+ uiContext.testGenerationOutput.importsCode.joinToString("\n") + "\n\n",
+ )
+
+ // insert package to a code
+ outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!!
+ .insertString(
+ 0,
+ if (uiContext.testGenerationOutput.packageName.isEmpty()) {
+ ""
+ } else {
+ "package ${uiContext.testGenerationOutput.packageName};\n\n"
+ },
+ )
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt
new file mode 100644
index 000000000..f2d61231a
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt
@@ -0,0 +1,235 @@
+package org.jetbrains.research.testspark.display.utils.kotlin
+
+import com.intellij.openapi.command.WriteCommandAction
+import com.intellij.openapi.fileChooser.FileChooser
+import com.intellij.openapi.fileChooser.FileChooserDescriptor
+import com.intellij.openapi.fileEditor.FileEditorManager
+import com.intellij.openapi.fileEditor.OpenFileDescriptor
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.vfs.LocalFileSystem
+import com.intellij.openapi.vfs.VirtualFile
+import com.intellij.openapi.vfs.VirtualFileManager
+import com.intellij.psi.PsiDocumentManager
+import com.intellij.psi.PsiFile
+import com.intellij.psi.PsiJavaFile
+import com.intellij.psi.PsiManager
+import com.intellij.refactoring.suggested.endOffset
+import com.intellij.refactoring.suggested.startOffset
+import com.intellij.util.containers.stream
+import org.jetbrains.kotlin.psi.KtClass
+import org.jetbrains.kotlin.psi.KtFile
+import org.jetbrains.kotlin.psi.KtPsiFactory
+import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle
+import org.jetbrains.research.testspark.data.UIContext
+import org.jetbrains.research.testspark.display.utils.ErrorMessageManager
+import org.jetbrains.research.testspark.display.utils.template.DisplayUtils
+import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+import org.jetbrains.research.testspark.testmanager.kotlin.KotlinTestAnalyzer
+import org.jetbrains.research.testspark.testmanager.kotlin.KotlinTestGenerator
+import java.io.File
+import java.util.Locale
+import javax.swing.JOptionPane
+
+class KotlinDisplayUtils : DisplayUtils {
+ override fun applyTests(project: Project, uiContext: UIContext?, testCaseComponents: List): Boolean {
+ val descriptor = FileChooserDescriptor(true, true, false, false, false, false)
+
+ // Apply filter with folders and java files with main class
+ WriteCommandAction.runWriteCommandAction(project) {
+ descriptor.withFileFilter { file ->
+ file.isDirectory || (
+ file.extension?.lowercase(Locale.getDefault()) == "kotlin" && (
+ PsiManager.getInstance(project).findFile(file!!) as KtFile
+ ).classes.stream().map { it.name }
+ .toArray()
+ .contains(
+ (
+ PsiManager.getInstance(project)
+ .findFile(file) as PsiJavaFile
+ ).name.removeSuffix(".kt"),
+ )
+ )
+ }
+ }
+
+ val fileChooser = FileChooser.chooseFiles(
+ descriptor,
+ project,
+ LocalFileSystem.getInstance().findFileByPath(project.basePath!!),
+ )
+
+ /**
+ * Cancel button pressed
+ */
+ if (fileChooser.isEmpty()) return false
+
+ /**
+ * Chosen files by user
+ */
+ val chosenFile = fileChooser[0]
+
+ /**
+ * Virtual file of a final java file
+ */
+ var virtualFile: VirtualFile? = null
+
+ /**
+ * PsiClass of a final java file
+ */
+ var ktClass: KtClass? = null
+
+ /**
+ * PsiJavaFile of a final java file
+ */
+ var psiKotlinFile: KtFile? = null
+
+ if (chosenFile.isDirectory) {
+ // Input new file data
+ var className: String
+ var fileName: String
+ var filePath: String
+ // Waiting for correct file name input
+ while (true) {
+ val jOptionPane =
+ JOptionPane.showInputDialog(
+ null,
+ PluginLabelsBundle.get("optionPaneMessage"),
+ PluginLabelsBundle.get("optionPaneTitle"),
+ JOptionPane.PLAIN_MESSAGE,
+ null,
+ null,
+ null,
+ )
+
+ // Cancel button pressed
+ jOptionPane ?: return false
+
+ // Get class name from user
+ className = jOptionPane as String
+
+ // Set file name and file path
+ fileName = "${className.split('.')[0]}.kt"
+ filePath = "${chosenFile.path}/$fileName"
+
+ // Check the correctness of a class name
+ if (!Regex("[A-Z][a-zA-Z0-9]*(.kt)?").matches(className)) {
+ ErrorMessageManager.showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage"))
+ continue
+ }
+
+ // Check the existence of a file with this name
+ if (File(filePath).exists()) {
+ ErrorMessageManager.showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage"))
+ continue
+ }
+ break
+ }
+
+ // Create new file and set services of this file
+ WriteCommandAction.runWriteCommandAction(project) {
+ chosenFile.createChildData(null, fileName)
+ virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!!
+ psiKotlinFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as KtFile)
+
+ val ktPsiFactory = KtPsiFactory(project)
+ ktClass = ktPsiFactory.createClass("class ${className.split(".")[0]} {}")
+
+ if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) {
+ val annotationEntry =
+ ktPsiFactory.createAnnotationEntry("@RunWith(${uiContext.testGenerationOutput.runWith})")
+ ktClass!!.addBefore(annotationEntry, ktClass!!.body)
+ }
+
+ psiKotlinFile!!.add(ktClass!!)
+ }
+ } else {
+ // Set services of the chosen file
+ virtualFile = chosenFile
+ psiKotlinFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as KtFile)
+ val classNameNoSuffix = psiKotlinFile!!.name.removeSuffix(".kt")
+ ktClass = psiKotlinFile?.declarations?.filterIsInstance()?.find { it.name == classNameNoSuffix }
+ }
+
+ // Add tests to the file
+ WriteCommandAction.runWriteCommandAction(project) {
+ appendTestsToClass(
+ project,
+ uiContext,
+ testCaseComponents,
+ KotlinPsiClassWrapper(ktClass as KtClass),
+ psiKotlinFile!!,
+ )
+ }
+
+ // Open the file after adding
+ FileEditorManager.getInstance(project).openTextEditor(
+ OpenFileDescriptor(project, virtualFile!!),
+ true,
+ )
+
+ return true
+ }
+
+ override fun appendTestsToClass(
+ project: Project,
+ uiContext: UIContext?,
+ testCaseComponents: List,
+ selectedClass: PsiClassWrapper,
+ outputFile: PsiFile,
+ ) {
+ // block document
+ PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument(
+ PsiDocumentManager.getInstance(project).getDocument(outputFile as KtFile)!!,
+ )
+
+ // insert tests to a code
+ testCaseComponents.reversed().forEach {
+ val testMethodCode =
+ KotlinTestAnalyzer.extractFirstTestMethodCode(
+ KotlinTestGenerator.formatCode(
+ project,
+ it.replace("\r\n", "\n")
+ .replace("verifyException(", "// verifyException("),
+ uiContext!!.testGenerationOutput,
+ ),
+ )
+ // Fix Windows line separators
+ .replace("\r\n", "\n")
+
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString(
+ selectedClass.rBrace!!,
+ testMethodCode,
+ )
+ }
+
+ // insert other info to a code
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString(
+ selectedClass.rBrace!!,
+ uiContext!!.testGenerationOutput.otherInfo + "\n",
+ )
+
+ // Create the imports string
+ val importsString = uiContext.testGenerationOutput.importsCode.joinToString("\n") + "\n\n"
+
+ // Find the insertion offset
+ val insertionOffset = outputFile.importList?.startOffset
+ ?: outputFile.packageDirective?.endOffset
+ ?: 0
+
+ // Insert the imports into the document
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)?.let { document ->
+ document.insertString(insertionOffset, importsString)
+ PsiDocumentManager.getInstance(project).commitDocument(document)
+ }
+
+ val packageName = uiContext.testGenerationOutput.packageName
+ val packageStatement = if (packageName.isEmpty()) "" else "package $packageName\n\n"
+
+ // Insert the package statement at the beginning of the document
+ PsiDocumentManager.getInstance(project).getDocument(outputFile)?.let { document ->
+ document.insertString(0, packageStatement)
+ PsiDocumentManager.getInstance(project).commitDocument(document)
+ }
+ }
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/template/DisplayUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/template/DisplayUtils.kt
new file mode 100644
index 000000000..5beb63a73
--- /dev/null
+++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/template/DisplayUtils.kt
@@ -0,0 +1,30 @@
+package org.jetbrains.research.testspark.display.utils.template
+
+import com.intellij.openapi.project.Project
+import com.intellij.psi.PsiFile
+import org.jetbrains.research.testspark.data.UIContext
+import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
+
+/**
+ * Interface for utility functions related to tests applying to the project.
+ * Each language utils class implements DisplayUtils interface.
+ */
+interface DisplayUtils {
+ /**
+ * Applies specified tests to a given project.
+ *
+ * @returns true, if tests applying is successful, otherwise false
+ */
+ fun applyTests(project: Project, uiContext: UIContext?, testCaseComponents: List): Boolean
+
+ /**
+ * Appends specified tests to a class within the given project.
+ */
+ fun appendTestsToClass(
+ project: Project,
+ uiContext: UIContext?,
+ testCaseComponents: List,
+ selectedClass: PsiClassWrapper,
+ outputFile: PsiFile,
+ )
+}
diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt
index 14928c818..916da9537 100644
--- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt
+++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt
@@ -2,13 +2,29 @@ package org.jetbrains.research.testspark.helpers
import com.google.gson.JsonParser
import com.intellij.openapi.application.ApplicationManager
+import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.ComboBox
import com.intellij.util.io.HttpRequests
-import org.jetbrains.research.testspark.bundles.TestSparkToolTipsBundle
-import org.jetbrains.research.testspark.settings.SettingsApplicationState
+import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle
+import org.jetbrains.research.testspark.bundles.llm.LLMSettingsBundle
+import org.jetbrains.research.testspark.core.data.TestGenerationData
+import org.jetbrains.research.testspark.core.generation.llm.executeTestCaseModificationRequest
+import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager
+import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
+import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
+import org.jetbrains.research.testspark.core.test.SupportedLanguage
+import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM
+import org.jetbrains.research.testspark.services.LLMSettingsService
+import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
+import org.jetbrains.research.testspark.tools.factories.TestsAssemblerFactory
+import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments
+import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform
+import org.jetbrains.research.testspark.tools.llm.generation.TestBodyPrinterFactory
+import org.jetbrains.research.testspark.tools.llm.generation.TestSuiteParserFactory
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform
+import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform
import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIPlatform
import java.net.HttpURLConnection
import javax.swing.DefaultComboBoxModel
@@ -16,239 +32,341 @@ import javax.swing.JTextField
import javax.swing.event.DocumentEvent
import javax.swing.event.DocumentListener
-/**
- * Checks if the Grazie class is loaded.
- * @return true if the Grazie class is loaded, false otherwise.
- */
-private fun isGrazieClassLoaded(): Boolean {
- val className = "org.jetbrains.research.grazie.Request"
- return try {
- Class.forName(className)
- true
- } catch (e: ClassNotFoundException) {
- false
+// Implementation of the common LLM functions
+object LLMHelper {
+ /**
+ * Checks if the Grazie class is loaded.
+ * @return true if the Grazie class is loaded, false otherwise.
+ */
+ private fun isGrazieClassLoaded(): Boolean {
+ val className = "org.jetbrains.research.grazie.Request"
+ return try {
+ Class.forName(className)
+ true
+ } catch (e: ClassNotFoundException) {
+ false
+ }
}
-}
-/**
- * Updates the model selector based on the selected platform in the platform selector.
- * If the selected platform is "Grazie", the model selector is disabled and set to display only "GPT-4".
- * If the selected platform is not "Grazie", the model selector is updated with the available modules fetched asynchronously using llmUserTokenField and enables the okLlmButton.
- * If the modules fetch fails, the model selector is set to display the default modules and is disabled.
- *
- * This method runs on a separate thread using ApplicationManager.getApplication().executeOnPooledThread{}.
- */
-private fun updateModelSelector(
- platformSelector: ComboBox,
- modelSelector: ComboBox,
- llmUserTokenField: JTextField,
- llmPlatforms: List,
- settingsState: SettingsApplicationState,
-) {
- ApplicationManager.getApplication().executeOnPooledThread {
- var modules = arrayOf("")
- if (platformSelector.selectedItem!!.toString() == settingsState.openAIName) {
- modules = getOpenAIModels(llmUserTokenField.text)
- }
- if (platformSelector.selectedItem!!.toString() == settingsState.grazieName) {
- modules = getGrazieModels()
+ /**
+ * Updates the model selector based on the selected platform in the platform selector.
+ * If the selected platform is "Grazie", the model selector is disabled and set to display only "GPT-4".
+ * If the selected platform is not "Grazie", the model selector is updated with the available models fetched asynchronously using llmUserTokenField and enables the okLlmButton.
+ * If the models fetch fails, the model selector is set to display the default models and is disabled.
+ *
+ * This method runs on a separate thread using ApplicationManager.getApplication().executeOnPooledThread{}.
+ */
+ private fun updateModelSelector(
+ platformSelector: ComboBox,
+ modelSelector: ComboBox,
+ llmUserTokenField: JTextField,
+ llmPlatforms: List,
+ settingsState: LLMSettingsState,
+ ) {
+ ApplicationManager.getApplication().executeOnPooledThread {
+ var models = arrayOf("")
+ if (platformSelector.selectedItem!!.toString() == settingsState.openAIName) {
+ models = getOpenAIModels(llmUserTokenField.text)
+ }
+ if (platformSelector.selectedItem!!.toString() == settingsState.grazieName) {
+ models = getGrazieModels()
+ }
+ if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) {
+ models = getHuggingFaceModels()
+ }
+ modelSelector.model = DefaultComboBoxModel(models)
+ for (index in llmPlatforms.indices) {
+ if (llmPlatforms[index].name == settingsState.openAIName &&
+ llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
+ ) {
+ modelSelector.selectedItem = settingsState.openAIModel
+ llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
+ }
+ if (llmPlatforms[index].name == settingsState.grazieName &&
+ llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
+ ) {
+ modelSelector.selectedItem = settingsState.grazieModel
+ llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
+ }
+ if (llmPlatforms[index].name == settingsState.huggingFaceName &&
+ llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
+ ) {
+ modelSelector.selectedItem = settingsState.huggingFaceModel
+ llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
+ }
+ }
+ modelSelector.isEnabled = true
+ if (models.contentEquals(arrayOf(""))) modelSelector.isEnabled = false
}
- modelSelector.model = DefaultComboBoxModel(modules)
+ }
+
+ /**
+ * Updates LlmUserTokenField based on the selected platform in the platformSelector ComboBox.
+ *
+ * @param platformSelector The ComboBox that allows the user to select a platform.
+ * @param llmUserTokenField The JTextField that displays the user token for the selected platform.
+ */
+ private fun updateLlmUserTokenField(
+ platformSelector: ComboBox,
+ llmUserTokenField: JTextField,
+ llmPlatforms: List,
+ settingsState: LLMSettingsState,
+ ) {
for (index in llmPlatforms.indices) {
if (llmPlatforms[index].name == settingsState.openAIName &&
llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
) {
- modelSelector.selectedItem = settingsState.openAIModel
- llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
+ llmUserTokenField.text = settingsState.openAIToken
+ llmPlatforms[index].token = settingsState.openAIToken
}
if (llmPlatforms[index].name == settingsState.grazieName &&
llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
) {
- modelSelector.selectedItem = settingsState.grazieModel
- llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
+ llmUserTokenField.text = settingsState.grazieToken
+ llmPlatforms[index].token = settingsState.grazieToken
+ }
+ if (llmPlatforms[index].name == settingsState.huggingFaceName &&
+ llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
+ ) {
+ llmUserTokenField.text = settingsState.huggingFaceToken
+ llmPlatforms[index].token = settingsState.huggingFaceToken
}
}
- modelSelector.isEnabled = true
- if (modules.contentEquals(arrayOf(""))) modelSelector.isEnabled = false
}
-}
-/**
- * Updates LlmUserTokenField based on the selected platform in the platformSelector ComboBox.
- *
- * @param platformSelector The ComboBox that allows the user to select a platform.
- * @param llmUserTokenField The JTextField that displays the user token for the selected platform.
- */
-private fun updateLlmUserTokenField(
- platformSelector: ComboBox,
- llmUserTokenField: JTextField,
- llmPlatforms: List,
- settingsState: SettingsApplicationState,
-) {
- for (index in llmPlatforms.indices) {
- if (llmPlatforms[index].name == settingsState.openAIName &&
- llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
- ) {
- llmUserTokenField.text = settingsState.openAIToken
- llmPlatforms[index].token = settingsState.openAIToken
- }
- if (llmPlatforms[index].name == settingsState.grazieName &&
- llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
- ) {
- llmUserTokenField.text = settingsState.grazieToken
- llmPlatforms[index].token = settingsState.grazieToken
- }
- }
-}
+ /**
+ * Adds listeners to various components in the LLM panel.
+ *
+ * @param platformSelector The combo box for selecting the LLM platform.
+ * @param modelSelector The combo box for selecting the LLM model.
+ * @param llmUserTokenField The text field for entering the LLM user token.
+ * @param llmPlatforms The list of LLM platforms.
+ */
+ fun addLLMPanelListeners(
+ platformSelector: ComboBox,
+ modelSelector: ComboBox,
+ llmUserTokenField: JTextField,
+ llmPlatforms: List,
+ settingsState: LLMSettingsState,
+ ) {
+ llmUserTokenField.document.addDocumentListener(object : DocumentListener {
+ override fun insertUpdate(e: DocumentEvent?) {
+ updateToken()
+ }
-/**
- * Adds listeners to various components in the LLM panel.
- *
- * @param platformSelector The combo box for selecting the LLM platform.
- * @param modelSelector The combo box for selecting the LLM model.
- * @param llmUserTokenField The text field for entering the LLM user token.
- * @param llmPlatforms The list of LLM platforms.
- */
-fun addLLMPanelListeners(
- platformSelector: ComboBox,
- modelSelector: ComboBox,
- llmUserTokenField: JTextField,
- llmPlatforms: List,
- settingsState: SettingsApplicationState,
-) {
- llmUserTokenField.document.addDocumentListener(object : DocumentListener {
- override fun insertUpdate(e: DocumentEvent?) {
- updateToken()
- }
+ override fun removeUpdate(e: DocumentEvent?) {
+ updateToken()
+ }
- override fun removeUpdate(e: DocumentEvent?) {
- updateToken()
- }
+ override fun changedUpdate(e: DocumentEvent?) {
+ updateToken()
+ }
+
+ private fun updateToken() {
+ for (llmPlatform in llmPlatforms) {
+ if (platformSelector.selectedItem!!.toString() == llmPlatform.name) {
+ llmPlatform.token = llmUserTokenField.text
+ }
+ }
+ updateModelSelector(platformSelector, modelSelector, llmUserTokenField, llmPlatforms, settingsState)
+ }
+ })
- override fun changedUpdate(e: DocumentEvent?) {
- updateToken()
+ platformSelector.addItemListener {
+ updateLlmUserTokenField(platformSelector, llmUserTokenField, llmPlatforms, settingsState)
+ updateModelSelector(platformSelector, modelSelector, llmUserTokenField, llmPlatforms, settingsState)
}
- private fun updateToken() {
+ modelSelector.addItemListener {
for (llmPlatform in llmPlatforms) {
if (platformSelector.selectedItem!!.toString() == llmPlatform.name) {
- llmPlatform.token = llmUserTokenField.text
+ llmPlatform.model = modelSelector.item
}
}
- updateModelSelector(platformSelector, modelSelector, llmUserTokenField, llmPlatforms, settingsState)
}
- })
+ }
- platformSelector.addItemListener {
+ /**
+ * Stylizes the main components of the application.
+ *
+ * @param llmUserTokenField the text field for the LLM user token
+ * @param modelSelector the combo box for selecting the model
+ * @param platformSelector the combo box for selecting the platform
+ */
+ fun stylizeMainComponents(
+ platformSelector: ComboBox,
+ modelSelector: ComboBox,
+ llmUserTokenField: JTextField,
+ llmPlatforms: List,
+ settingsState: LLMSettingsState,
+ ) {
+ // Check if the Grazie platform access is available in the current build
+ if (isGrazieClassLoaded()) {
+ platformSelector.model = DefaultComboBoxModel(llmPlatforms.map { it.name }.toTypedArray())
+ platformSelector.selectedItem = settingsState.currentLLMPlatformName
+ }
+
+ llmUserTokenField.toolTipText = LLMSettingsBundle.get("llmToken")
updateLlmUserTokenField(platformSelector, llmUserTokenField, llmPlatforms, settingsState)
+
+ modelSelector.toolTipText = LLMSettingsBundle.get("model")
updateModelSelector(platformSelector, modelSelector, llmUserTokenField, llmPlatforms, settingsState)
}
- modelSelector.addItemListener {
- for (llmPlatform in llmPlatforms) {
- if (platformSelector.selectedItem!!.toString() == llmPlatform.name) {
- llmPlatform.model = modelSelector.item
- }
- }
+ /**
+ * Retrieves the list of LLMPlatforms.
+ *
+ * @return The list of LLMPlatforms.
+ */
+ fun getLLLMPlatforms(): List