Skip to content

Commit

Permalink
feat: throw custom exception if SDK is not configured
Browse files Browse the repository at this point in the history
  • Loading branch information
Vladislav Artiukhov committed Oct 9, 2024
1 parent 177ac88 commit 3dffcb8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 47 deletions.
96 changes: 52 additions & 44 deletions src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.project.Project
import org.jetbrains.research.testspark.actions.controllers.TestGenerationController
import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
import org.jetbrains.research.testspark.core.exception.TestSparkException
import org.jetbrains.research.testspark.core.test.data.CodeType
import org.jetbrains.research.testspark.data.FragmentToTestData
import org.jetbrains.research.testspark.display.TestSparkDisplayManager
Expand All @@ -14,6 +15,7 @@ import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
import org.jetbrains.research.testspark.langwrappers.PsiHelper
import org.jetbrains.research.testspark.tools.Pipeline
import org.jetbrains.research.testspark.tools.TestsExecutionResultManager
import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.generation.LLMProcessManager
import org.jetbrains.research.testspark.tools.llm.generation.PromptManager
import org.jetbrains.research.testspark.tools.template.Tool
Expand Down Expand Up @@ -87,21 +89,16 @@ class Llm(override val name: String = "LLM") : Tool {
return
}
val codeType = FragmentToTestData(CodeType.CLASS)
createLLMPipeline(

generateTests(
project,
psiHelper,
caretOffset,
fileUrl,
testSamplesCode,
testGenerationController,
testSparkDisplayManager,
testsExecutionResultManager,
).runTestGeneration(
LLMProcessManager(
project,
psiHelper.language,
PromptManager(project, psiHelper, caretOffset),
testSamplesCode,
),
codeType,
)
}
Expand Down Expand Up @@ -132,21 +129,16 @@ class Llm(override val name: String = "LLM") : Tool {
}
val psiMethod = psiHelper.getSurroundingMethod(caretOffset)!!
val codeType = FragmentToTestData(CodeType.METHOD, psiHelper.generateMethodDescriptor(psiMethod))
createLLMPipeline(

generateTests(
project,
psiHelper,
caretOffset,
fileUrl,
testSamplesCode,
testGenerationController,
testSparkDisplayManager,
testsExecutionResultManager,
).runTestGeneration(
LLMProcessManager(
project,
psiHelper.language,
PromptManager(project, psiHelper, caretOffset),
testSamplesCode,
),
codeType,
)
}
Expand All @@ -155,7 +147,7 @@ class Llm(override val name: String = "LLM") : Tool {
* Generates tests for a specific line of code.
*
* @param project The current project.
* @param psiFile The PSI file containing the code.
* @param psiHelper The PSI file containing the code.
* @param caretOffset The offset position of the caret.
* @param fileUrl The URL of the file.
* @param testSamplesCode The code for the test samples.
Expand All @@ -177,53 +169,69 @@ class Llm(override val name: String = "LLM") : Tool {
}
val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!!
val codeType = FragmentToTestData(CodeType.LINE, selectedLine)
createLLMPipeline(

generateTests(
project,
psiHelper,
caretOffset,
fileUrl,
testSamplesCode,
testGenerationController,
testSparkDisplayManager,
testsExecutionResultManager,
).runTestGeneration(
LLMProcessManager(
project,
psiHelper.language,
PromptManager(project, psiHelper, caretOffset),
testSamplesCode,
),
codeType,
)
}

/**
* Creates a LLMPipeline instance.
* Generates tests for a given code type.
*
* @param project the project of the pipeline.
* @param psiHelper the PsiHelper associated with the pipeline.
* @param caretOffset the offset of the caret position within the PSI file.
* @param fileUrl the URL of the file to be processed by the pipeline.
* @return a LLMPipeline instance.
* @param project The project in which the method is located.
* @param psiHelper The PsiHelper associated with the pipeline.
* @param caretOffset The offset of the caret position in the PSI file.
* @param fileUrl The URL of the file to generate tests for (optional).
* @param testSamplesCode The code of the test samples to use for test generation.
* @param testGenerationController The controller for test generation operations.
* @param testSparkDisplayManager The manager for displaying test-related information.
* @param testsExecutionResultManager The manager for handling test execution results.
* @param codeType The type of data fragment to generate tests for.
*/
private fun createLLMPipeline(
private fun generateTests(
project: Project,
psiHelper: PsiHelper,
caretOffset: Int,
fileUrl: String?,
testSamplesCode: String,
testGenerationController: TestGenerationController,
testSparkDisplayManager: TestSparkDisplayManager,
testsExecutionResultManager: TestsExecutionResultManager,
): Pipeline {
val packageName = psiHelper.getPackageName()
return Pipeline(
project,
psiHelper,
caretOffset,
fileUrl,
packageName,
testGenerationController,
testSparkDisplayManager,
testsExecutionResultManager,
)
codeType: FragmentToTestData,
) {
try {
val packageName = psiHelper.getPackageName()
val pipeline = Pipeline(
project,
psiHelper,
caretOffset,
fileUrl,
packageName,
testGenerationController,
testSparkDisplayManager,
testsExecutionResultManager,
)

val manager = LLMProcessManager(
project,
psiHelper.language,
PromptManager(project, psiHelper, caretOffset),
testSamplesCode,
)

pipeline.runTestGeneration(manager, codeType)
}
catch (err: TestSparkException) {
testGenerationController.finished()
LLMErrorManager().errorProcess(err.message!!, project, testGenerationController.errorMonitor)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.intellij.openapi.roots.ProjectRootManager
import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle
import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle
import org.jetbrains.research.testspark.core.data.TestGenerationData
import org.jetbrains.research.testspark.core.exception.JavaSDKMissingException
import org.jetbrains.research.testspark.core.generation.llm.FeedbackCycleExecutionResult
import org.jetbrains.research.testspark.core.generation.llm.LLMWithFeedbackCycle
import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode
Expand Down Expand Up @@ -50,11 +51,15 @@ class LLMProcessManager(
private val language: SupportedLanguage,
private val promptManager: PromptManager,
private val testSamplesCode: String,
private val projectSDKPath: Path? = null,
projectSDKPath: Path? = null,
) : ProcessManager {

private val homeDirectory =
projectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path
private val homeDirectory = projectSDKPath?.toString() ?: run {
val sdk = ProjectRootManager.getInstance(project).projectSdk?.homeDirectory?.path
?: throw JavaSDKMissingException(LLMMessagesBundle.get("javaSdkNotConfigured"))

return@run sdk
}

private val testFileName: String = when (language) {
SupportedLanguage.Java -> "GeneratedTest.java"
Expand Down

0 comments on commit 3dffcb8

Please sign in to comment.